python 用tf函数对未知形状Tensor进行整形

wkftcu5l  于 2023-03-21  发布在  Python
关注(0)|答案(1)|浏览(224)

假设在我的函数中,我必须处理形状为[4,3,2,1]和[5,4,3,2,1]的输入Tensor。我想以交换最后两个维度的方式重塑它们,例如,交换为[4,3,1,2]。在渴望模式下,这很容易,但当我尝试使用@tf.function Package 函数时,会抛出以下错误:
OperatorNotAllowedInGraphError: Iterating over a symbolic tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
代码如下:

tensor = tf.random.uniform(shape=[4, 3, 2, 1])

    @tf.function
    def my_func():
        reshaped = tf.reshape(tensor, shape=[*tf.shape(tensor)[:-2], tf.shape(tensor)[-1], tf.shape(tensor)[-2]])
        return reshaped

    logging.info(my_func())

看起来tensorflow不喜欢[:-2]符号,但我真的不知道我应该如何以一种优雅和可读性良好的方式解决这个问题。

91zkwejq

91zkwejq1#

不幸的是,在Graph模式下像这样切片Tensor不起作用,但我认为你可以使用tf.transpose

import tensorflow as tf

tensor = tf.random.uniform(shape=[4, 3, 2, 1])

@tf.function
def my_func():
  rank = tf.rank(tensor)
  some_magic = tf.concat([tf.zeros((rank - 2,), dtype=tf.int32), [1, -1]], axis=-1)
  reshaped = tf.transpose(tensor, perm = tf.range(rank) + some_magic)
  return reshaped

print(my_func().shape)
# (4, 3, 1, 2)

相关问题