tensorflow 使用tf.gather索引包含非Tensor类型的列表

9jyewag0  于 2023-11-21  发布在  其他
关注(0)|答案(3)|浏览(152)

考虑下面的代码。我想知道如何从列表中收集非Tensor类型。

import tensorflow as tf

class Point(tf.experimental.ExtensionType):
    xx: tf.Tensor
    def __init__(self,xx):
        self.xx = xx
        super().__init__()

list1 = [ 1, 2, 3, 4] 
list2 = [ Point(1), Point(2), Point(3), Point(4) ]

# this works
out1 = tf.gather(list1,[0,2])
print('First gather ',out1)

# this throws: ValueError: Attempt to convert a value (Point(xx=<tf.Tensor:
# shape=(), dtype=int32, numpy=1>)) with an unsupported type
# (<class '__main__.Point'>) to a Tensor.

out2 = tf.gather(list2,[0,2])
print('Second gather ',out2)

字符串

to94eoyn

to94eoyn1#

我的回答

这很容易.转换Point列表到Tensor列表.

我的示例

import tensorflow as tf

class Point(tf.experimental.ExtensionType):
    xx: tf.Tensor

    def __init__(self, xx):
        self.xx = xx
        super().__init__()

list2 = [Point(1), Point(2), Point(3), Point(4)]

out2 = tf.gather([x.xx for x in list2], [0, 2])
print('Second gather ', out2)
# > Second gather  tf.Tensor([1 3], shape=(2,), dtype=int32)

字符串

fnx2tebb

fnx2tebb2#

tf.gather专门设计用于处理可以转换为Tensor的Tensor或列表。
tf.gather(list1, [0, 2])可以工作,因为list1包含的整数可以被tensorflow解释为Tensor。
然而,list2包含tensorflow无法理解的Point类型的自定义对象。因此,直接在list2上使用tf.gather会引发观察到的错误。
要从非Tensor类型的列表中收集元素,如自定义对象,可以使用列表解析。
例如

indices = [0,2]
out2 - [list2[i] for i in indices]

字符串
如果出于某种原因必须坚持使用TensorFlow函数,Point类的当前设计,直接TensorFlow操作(如tf.gather)不是理想的选择。
引入tf.experimental.ExtensionType是为了允许用户定义的类像tensorflow原生类型一样工作,但这仍然是实验性的,可能不适用于所有操作。
您看到的错误表明tf.gather操作错误无法将Point识别为类似Tensor的对象。
一个不将整个列表转换为Tensor的潜在解决方案可能是直接在Point类上实现所需的行为。
例如

class Point(tf.experimental.ExtensionType):
    xx: tf.Tensor
    
    def __init__(self, xx):
        self.xx = xx
        super().__init__()

    @classmethod
    def gather(cls, points, indices):
        return [points[i] for i in indices]

list2 = [Point(1), Point(2), Point(3), Point(4)]

# Using the gather method of Point class
out2 = Point.gather(list2, [0,2])
print('Second gather ', out2)


Point上的gather方法基本上复制了列表解析,但没有利用tensorflow的计算图形功能。
随着时间的推移,ExtensionType可以被进一步细化,从而允许更直接的解决方案。

pgpifvop

pgpifvop3#

根据tf.gather,它需要一个tensor作为输入。如果我们传递list(即list1,它也会被转换为tensor。但在您的情况下,list2中的每个项目都是Point类的对象,tensorflow无法解释。在这种情况下,你可以将每个项目转换为列表或Tensor,以便与tf.gather一起使用。

inputs = tf.convert_to_tensor(
    [point.xx.numpy() for point in list2]
)
ouput = tf.gather(tensor2, [0, 2])
ouput 
tf.Tensor([1 3], shape=(2,), dtype=int32)

字符串

相关问题