tensorflow `tf.gather_nd`和`tf.gather`在批处理维度(batch_dims)的类型检查上不一致,

pxy2qtax  于 3个月前  发布在  其他
关注(0)|答案(1)|浏览(25)

系统信息

  • 是否编写了自定义代码(与在TensorFlow中使用的库存示例脚本相反):Y
  • 操作系统平台和发行版(例如,Linux Ubuntu 16.04):Ubuntu 18.04
  • 移动设备(例如iPhone 8,Pixel 2,三星Galaxy)如果问题发生在移动设备上:N
  • 从哪里安装的TensorFlow(源代码或二进制文件):二进制
  • TensorFlow版本(请使用以下命令):2.8.0
  • Python版本:3.8
  • Bazel版本(如果从源代码编译):N/A
  • GCC/编译器版本(如果从源代码编译):N/A
  • CUDA/cuDNN版本: N/A
  • GPU型号和内存: N/A
    重现问题的独立代码
import tensorflow as tf
params = tf.random.uniform([3, 1, 12, 64], dtype=tf.float32)
indices = tf.random.uniform([35, 2], minval=0, maxval=1, dtype=tf.int64)
batch_dims = False
tf.gather_nd(params, indices, batch_dims=batch_dims) # Pass
tf.gather(params, indices, batch_dims=batch_dims) # InvalidArgumentError

详细错误信息:

InvalidArgumentError: Value for attr 'Taxis' of bool is not in the list of allowed values: int32, int64
	; NodeDef: {{node GatherV2}}; Op<name=GatherV2; signature=params:Tparams, indices:Tindices, axis:Taxis -> output:Tparams; attr=batch_dims:int,default=0; attr=Tparams:type; attr=Tindices:type,allowed=[DT_INT32, DT_INT64]; attr=Taxis:type,allowed=[DT_INT32, DT_INT64]> [Op:GatherV2]

描述当前行为

在上面的代码中, batch_dims 是一个 bool ,而不是一个 inttf.gather 对这种类型不匹配抱怨并抛出 InvalidArgumentError 。然而, tf.gather_nd 会进行隐式转换并将 False 转换为 0 。类型检查存在不一致性。

描述预期行为

在所有情况下,要么允许隐式的 bool-int 转换,要么在所有情况下抛出错误。

相关问题