keras 将tf.gather按行应用于Tensor

pn9klfpd  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(185)

我想对Tensor的每一行分别进行tf.gather操作,假设我对每一行都有所需的索引。例如,如果一个Tensor:

A = tf.constant([[2., 5., 12., 9., 0., 0., 3.],
                 [0., 12., 2., 0., 0., 0., 5.],
                 [0., 0., 10., 0., 4., 4., 3.]], dtype=tf.float32)

散列索引:

idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
                    [1, 1, 2, 6, 6, 6, 6],
                    [2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)

我希望根据相应的索引行收集每一行:

output:
[[2. 5. 9. 3. 2. 2. 2.]
 [12. 12. 2. 5. 5. 5. 5.]
 [10. 10. 4. 4. 3. 3. 3.]]

我想过也许使用tf.scan,但还没有成功。

6uxekuva

6uxekuva1#

需要将idxs转换为full indices,然后使用tf.gather_nd

ii = tf.cast(tf.range(idxs.shape[0])[...,None], tf.float32)*tf.ones(idxs.shape[1], dtype=tf.float32)
indices = tf.stack([tf.cast(ii, tf.int32), idxs], axis=-1)

使用,

tf.gather_nd(A, indices)

[[ 2.,  5.,  9.,  3.,  2.,  2.,  2.],
 [12., 12.,  2.,  5.,  5.,  5.,  5.],
 [10., 10.,  4.,  4.,  3.,  3.,  3.]]

相关问题