使用tf.scatter_nd使Keras 'None'批处理大小保持不变

ejk8hzay  于 2022-11-24  发布在  其他
关注(0)|答案(1)|浏览(130)

我需要向LSTM解码器输入一个pooling模块,并且我使用一个自定义层来构造这个模块,该层以编码器LSTM状态和Keras Input层作为输入。在这个自定义层中,我需要将更新分散到索引中:

updates: <tf.Tensor --- shape=(None, 225, 5, 32) dtype=float32>
indices: <tf.Tensor --- shape=(None, 225) dtype=int32>

tf.scatter_nd创建一个形状为(None,960,5,32)的Tensor,如下所示:

tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[None, 960, 5, 32])

但问题是,这样做会因形状中的NoneType而导致错误,我不想在其中声明batch_size,因为它是Keras层,只有在学习过程中才确定。在这种情况下,代码的工作版本如下:

tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[960, 5, 32])
        >>> <tf.Tensor 'ScatterNd_4:0' shape=(960, 5, 32) dtype=float32>

是否有其他方法来构造所需的输出Tensor,而不是tf.scatter_nd,或者有其他方法使其正常工作?

fdx2calv

fdx2calv1#

我在tf.scatter_nd操作中也遇到过类似的问题。我通过在运行时使用tf.shape(input)[0]推断批处理大小来解决这个问题。所以在您的情况下,下面的代码应该可以工作:

bs = tf.shape(indices)[0]
tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[bs, 960, 5, 32])

相关问题