系统信息
- 是否编写了自定义代码(与在TensorFlow中使用的库存示例脚本相反):Y
- 操作系统平台和发行版(例如,Linux Ubuntu 16.04):Ubuntu 18.04
- 移动设备(例如iPhone 8,Pixel 2,三星Galaxy)如果问题发生在移动设备上:N
- 从哪里安装的TensorFlow(源代码或二进制文件):二进制
- TensorFlow版本(请使用以下命令):2.8.0
- Python版本:3.8
- Bazel版本(如果从源代码编译):N/A
- GCC/编译器版本(如果从源代码编译):N/A
- CUDA/cuDNN版本: N/A
- GPU型号和内存: N/A
重现问题的独立代码
import tensorflow as tf
params = tf.random.uniform([3, 1, 12, 64], dtype=tf.float32)
indices = tf.random.uniform([35, 2], minval=0, maxval=1, dtype=tf.int64)
batch_dims = False
tf.gather_nd(params, indices, batch_dims=batch_dims) # Pass
tf.gather(params, indices, batch_dims=batch_dims) # InvalidArgumentError
详细错误信息:
InvalidArgumentError: Value for attr 'Taxis' of bool is not in the list of allowed values: int32, int64
; NodeDef: {{node GatherV2}}; Op<name=GatherV2; signature=params:Tparams, indices:Tindices, axis:Taxis -> output:Tparams; attr=batch_dims:int,default=0; attr=Tparams:type; attr=Tindices:type,allowed=[DT_INT32, DT_INT64]; attr=Taxis:type,allowed=[DT_INT32, DT_INT64]> [Op:GatherV2]
描述当前行为
在上面的代码中, batch_dims
是一个 bool
,而不是一个 int
。 tf.gather
对这种类型不匹配抱怨并抛出 InvalidArgumentError
。然而, tf.gather_nd
会进行隐式转换并将 False
转换为 0
。类型检查存在不一致性。
描述预期行为
在所有情况下,要么允许隐式的 bool
-int
转换,要么在所有情况下抛出错误。
1条答案
按热度按时间f45qwnt81#
已为修复添加PR #55210。