我想对Tensor的每一行分别进行tf.gather操作,假设我对每一行都有所需的索引。例如,如果一个Tensor:
A = tf.constant([[2., 5., 12., 9., 0., 0., 3.],
[0., 12., 2., 0., 0., 0., 5.],
[0., 0., 10., 0., 4., 4., 3.]], dtype=tf.float32)
散列索引:
idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
[1, 1, 2, 6, 6, 6, 6],
[2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)
我希望根据相应的索引行收集每一行:
output:
[[2. 5. 9. 3. 2. 2. 2.]
[12. 12. 2. 5. 5. 5. 5.]
[10. 10. 4. 4. 3. 3. 3.]]
我想过也许使用tf.scan,但还没有成功。
1条答案
按热度按时间6uxekuva1#
需要将
idxs
转换为full indices
,然后使用tf.gather_nd
:使用,