我有一个2DTensor,它表示网格上的整数坐标。我想检查我的Tensor是否出现了特定的坐标(x,y)
伪代码示例:
positions = torch.arange(20).repeat(2).view(-1,2)
xy_dst1 = torch.tensor((5,7))
xy_dst2 = torch.tensor((4,5))
positions == xy_dst1 # should give none
positions == xy_dst2 # should give index 2 and 12
到目前为止,我唯一的解决方案是将Tensor转换成列表或元组,然后迭代地处理它们,但来回转换和迭代不可能是一个很好的解决方案。有人知道一个更好的解决方案,停留在Tensor框架内吗?
1条答案
按热度按时间xiozqbni1#
尝试