Tensorflow:何时使用tf.expand_dims?

bnl4lu3b  于 2023-06-24  发布在  其他
关注(0)|答案(2)|浏览(116)

Tensorflow教程包括使用tf.expand_dims向Tensor添加“批量维度”。我已经阅读了这个函数的文档,但它对我来说仍然很神秘。有人知道在什么情况下必须使用它吗?
代码在下面。我的目的是根据预测和实际箱之间的距离计算损失。(例如,如果predictedBin = 10truthBin = 7,则binDistanceLoss = 3)。

batch_size = tf.size(truthValues_placeholder)
labels = tf.expand_dims(truthValues_placeholder, 1)
predictedBin = tf.argmax(logits)
binDistanceLoss = tf.abs(tf.sub(labels, logits))

在这种情况下,我需要将tf.expand_dims应用到predictedBinbinDistanceLoss吗?先谢谢你了。

e0uiprwp

e0uiprwp1#

expand_dims不会增加或减少Tensor中的元素,它只是通过将1添加到维度来改变形状。例如,具有10个元素的向量可以被视为10x1矩阵。
我遇到的使用expand_dims的情况是当我试图构建一个ConvNet来分类灰度图像时。灰度图像将作为大小为[320, 320]的矩阵加载。然而,tf.nn.conv2d要求输入为[batch, in_height, in_width, in_channels],其中in_channels维度在我的数据中丢失,在本例中应该是1。所以我用expand_dims再添加一个维度。
在您的情况下,我认为您不需要expand_dims

3zwtqj6y

3zwtqj6y2#

为了补充大同的答案,你可能想同时扩展多个维度。例如,如果您对秩为1的向量执行TensorFlow的conv1d操作,则需要为它们提供秩为3的向量。
多次执行expand_dims是可读的,但可能会在计算图中引入一些开销。您可以使用reshape在一行程序中获得相同的功能:

import tensorflow as tf

# having some tensor of rank 1, it could be an audio signal, a word vector...
tensor = tf.ones(100)
print(tensor.get_shape()) # => (100,)

# expand its dimensionality to fit into conv2d
tensor_expand = tf.expand_dims(tensor, 0)
tensor_expand = tf.expand_dims(tensor_expand, 0)
tensor_expand = tf.expand_dims(tensor_expand, -1)
print(tensor_expand.get_shape()) # => (1, 1, 100, 1)

# do the same in one line with reshape
tensor_reshape = tf.reshape(tensor, [1, 1, tensor.get_shape().as_list()[0],1])
print(tensor_reshape.get_shape()) # => (1, 1, 100, 1)

**注意:**如果您得到错误TypeError: Failed to convert object of type <type 'list'> to Tensor.,请尝试传递tf.shape(x)[0]而不是x.get_shape()[0],如建议的here

希望能帮上忙!
干杯
安德烈斯

相关问题