考虑下面的代码。我想知道如何从列表中收集非Tensor类型。
import tensorflow as tf
class Point(tf.experimental.ExtensionType):
xx: tf.Tensor
def __init__(self,xx):
self.xx = xx
super().__init__()
list1 = [ 1, 2, 3, 4]
list2 = [ Point(1), Point(2), Point(3), Point(4) ]
# this works
out1 = tf.gather(list1,[0,2])
print('First gather ',out1)
# this throws: ValueError: Attempt to convert a value (Point(xx=<tf.Tensor:
# shape=(), dtype=int32, numpy=1>)) with an unsupported type
# (<class '__main__.Point'>) to a Tensor.
out2 = tf.gather(list2,[0,2])
print('Second gather ',out2)
字符串
3条答案
按热度按时间to94eoyn1#
我的回答
这很容易.转换
Point
列表到Tensor
列表.我的示例
字符串
fnx2tebb2#
tf.gather
专门设计用于处理可以转换为Tensor的Tensor或列表。tf.gather(list1, [0, 2])
可以工作,因为list1
包含的整数可以被tensorflow解释为Tensor。然而,
list2
包含tensorflow无法理解的Point
类型的自定义对象。因此,直接在list2
上使用tf.gather
会引发观察到的错误。要从非Tensor类型的列表中收集元素,如自定义对象,可以使用列表解析。
例如
字符串
如果出于某种原因必须坚持使用TensorFlow函数,Point类的当前设计,直接TensorFlow操作(如
tf.gather
)不是理想的选择。引入
tf.experimental.ExtensionType
是为了允许用户定义的类像tensorflow原生类型一样工作,但这仍然是实验性的,可能不适用于所有操作。您看到的错误表明
tf.gather
操作错误无法将Point
识别为类似Tensor的对象。一个不将整个列表转换为Tensor的潜在解决方案可能是直接在
Point
类上实现所需的行为。例如
型
Point
上的gather
方法基本上复制了列表解析,但没有利用tensorflow的计算图形功能。随着时间的推移,
ExtensionType
可以被进一步细化,从而允许更直接的解决方案。pgpifvop3#
根据tf.gather,它需要一个tensor作为输入。如果我们传递list(即
list1
,它也会被转换为tensor。但在您的情况下,list2
中的每个项目都是Point
类的对象,tensorflow无法解释。在这种情况下,你可以将每个项目转换为列表或Tensor,以便与tf.gather
一起使用。字符串