tensorflow 我可以将tf.map_fn(...)应用于多个输入/输出吗?

eh57zj3b  于 2023-03-03  发布在  其他
关注(0)|答案(3)|浏览(167)
a = tf.constant([[1,2,3],[4,5,6]])
b = tf.constant([True, False], dtype=tf.bool)

a.eval()
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
b.eval()
array([ True, False], dtype=bool)

我想使用tf.map_fn将一个函数应用于上面的输入ab。它将输入[1,2,3]True,并输出相似的值。
我们假设out函数就是恒等式:lambda(x,y): x,y所以,给定输入[1,2,3], True,它将输出那些相同的Tensor。
我知道如何将tf.map_fn(...)与一个变量一起使用,但不知道如何将其与两个变量一起使用,并且在本例中,我混合了数据类型(int 32和bool),所以我不能简单地连接Tensor并在调用后拆分它们。
tf.map_fn(...)是否可以用于不同数据类型的多个输入/输出?

ljsrvy3e

ljsrvy3e1#

明白了,你必须在dtype中定义每个Tensor的数据类型,然后你可以把Tensor作为一个元组传递,你的map函数接收一个元组的输入,map_fn返回一个元组。
有效的示例:

a = tf.constant([[1,2,3],[4,5,6]])
b = tf.constant([True, False], dtype=tf.bool)

c = tf.map_fn(lambda x: (x[0], x[1]), (a,b), dtype=(tf.int32, tf.bool))

c[0].eval()
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
c[1].eval()
array([ True, False], dtype=bool)
jecbmhm3

jecbmhm32#

按照下面所述执行操作,但tuple解包tf.function内的值,而不使用lambda,因为计算成本高,并且lambda无法与TensorFlow一起作为tf函数正常工作。

@tf.function
def x(x):
  x,y = tuple(x) 
  return 

c = tf.map_fn(, (a,b), dtype=(tf.int32, tf.bool))
8i9zcol2

8i9zcol23#

这两种方法都有点混乱,你可以用一个定制的Extension Type显式定义栈的“type”。下面是你的例子,希望更易读:

class BoolAndVector(tf.experimental.BatchableExtensionType):
    """A collection bools and vectors."""
    bool: tf.Tensor
    vector: tf.Tensor

@tf.function
def fn(bool_and_vector):
   return BoolAndVector(bool=bool_and_vector.bool,
                         vector=[bool_and_vector.vector[0]])

c = tf.map_fn(fn, BoolAndVector(bool=b, vector=a))

您可能必须显式地说明签名类型,但是错误会引导您找到TF想要的类型。
注意,您需要使用BatchableExtensionType来使其可用于map。

相关问题