Tensorflow教程包括使用tf.expand_dims
向Tensor添加“批量维度”。我已经阅读了这个函数的文档,但它对我来说仍然很神秘。有人知道在什么情况下必须使用它吗?
代码在下面。我的目的是根据预测和实际箱之间的距离计算损失。(例如,如果predictedBin = 10
和truthBin = 7
,则binDistanceLoss = 3
)。
batch_size = tf.size(truthValues_placeholder)
labels = tf.expand_dims(truthValues_placeholder, 1)
predictedBin = tf.argmax(logits)
binDistanceLoss = tf.abs(tf.sub(labels, logits))
在这种情况下,我需要将tf.expand_dims
应用到predictedBin
和binDistanceLoss
吗?先谢谢你了。
2条答案
按热度按时间e0uiprwp1#
expand_dims
不会增加或减少Tensor中的元素,它只是通过将1
添加到维度来改变形状。例如,具有10个元素的向量可以被视为10x1矩阵。我遇到的使用
expand_dims
的情况是当我试图构建一个ConvNet来分类灰度图像时。灰度图像将作为大小为[320, 320]
的矩阵加载。然而,tf.nn.conv2d
要求输入为[batch, in_height, in_width, in_channels]
,其中in_channels
维度在我的数据中丢失,在本例中应该是1
。所以我用expand_dims
再添加一个维度。在您的情况下,我认为您不需要
expand_dims
。3zwtqj6y2#
为了补充大同的答案,你可能想同时扩展多个维度。例如,如果您对秩为1的向量执行TensorFlow的
conv1d
操作,则需要为它们提供秩为3的向量。多次执行
expand_dims
是可读的,但可能会在计算图中引入一些开销。您可以使用reshape
在一行程序中获得相同的功能:**注意:**如果您得到错误
TypeError: Failed to convert object of type <type 'list'> to Tensor.
,请尝试传递tf.shape(x)[0]
而不是x.get_shape()[0]
,如建议的here。希望能帮上忙!
干杯
安德烈斯