我需要向LSTM解码器输入一个pooling模块,并且我使用一个自定义层来构造这个模块,该层以编码器LSTM状态和Keras Input层作为输入。在这个自定义层中,我需要将更新分散到索引中:
updates: <tf.Tensor --- shape=(None, 225, 5, 32) dtype=float32>
indices: <tf.Tensor --- shape=(None, 225) dtype=int32>
tf.scatter_nd
创建一个形状为(None,960,5,32)的Tensor,如下所示:
tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[None, 960, 5, 32])
但问题是,这样做会因形状中的NoneType而导致错误,我不想在其中声明batch_size
,因为它是Keras层,只有在学习过程中才确定。在这种情况下,代码的工作版本如下:
tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[960, 5, 32])
>>> <tf.Tensor 'ScatterNd_4:0' shape=(960, 5, 32) dtype=float32>
是否有其他方法来构造所需的输出Tensor,而不是tf.scatter_nd
,或者有其他方法使其正常工作?
1条答案
按热度按时间fdx2calv1#
我在
tf.scatter_nd
操作中也遇到过类似的问题。我通过在运行时使用tf.shape(input)[0]
推断批处理大小来解决这个问题。所以在您的情况下,下面的代码应该可以工作: