pytorchTensor逐对/逐行比较

7ajki6be  于 2022-12-13  发布在  其他
关注(0)|答案(1)|浏览(171)

我有一个2DTensor,它表示网格上的整数坐标。我想检查我的Tensor是否出现了特定的坐标(x,y)
伪代码示例:

positions = torch.arange(20).repeat(2).view(-1,2)
xy_dst1 = torch.tensor((5,7))
xy_dst2 = torch.tensor((4,5))
positions == xy_dst1 # should give none
positions == xy_dst2 # should give index 2 and 12

到目前为止,我唯一的解决方案是将Tensor转换成列表或元组,然后迭代地处理它们,但来回转换和迭代不可能是一个很好的解决方案。有人知道一个更好的解决方案,停留在Tensor框架内吗?

xiozqbni

xiozqbni1#

尝试

def check(positions, xy):
    return (positions == xy.view(1, 2)).all(dim=1).nonzero()

print(check(positions, xy_dst1))
# Output: tensor([], size=(0, 1), dtype=torch.int64)

print(check(positions, xy_dst2))
# Output:
# tensor([[ 2],
#         [12]])

相关问题