Torch中的高级索引?(是否有与np.ix_等效的PyTorch?)

cx6n0qe3  于 2023-03-30  发布在  其他
关注(0)|答案(2)|浏览(191)

在Numpy中,我可以使用布尔和数值索引的任意组合来索引数组。

>>> x = np.arange(12).reshape(3, 4)
>>> x
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> x[np.ix_([False, True, True], [0, 3])] = [[1, 2], [3, 4]]
array([[ 0,  1,  2,  3],
       [ 1,  5,  6,  2],
       [ 3,  9, 10,  4]])

假设x和索引数组(布尔和数值)都是torchTensor(在GPU上,如果有区别的话)。是否有一种方法可以像Numpy解决方案中那样索引和变异x

iyfjxgzm

iyfjxgzm1#

np.ix_所做的就是将布尔数组输入转换为整数索引数组(通过.nonzero),然后引入单例维度以方便广播。例如:

x[ np.ix_(idx1, idx2, idx3) ]

相当于

x[ idx1[:, None, None], idx2[None, :, None], idx3[None, None, :] ]

只要每个索引包含整数。
我不知道pytorch中有任何ix_等价物,但巧合的是,在用.nonzero转换布尔索引后,得到的形状已经非常适合广播:

>>> import torch
>>>
>>> x = torch.arange(12).reshape(3, 4).cuda()
>>>
>>> rows = torch.tensor([False, True, True]).cuda()
>>> cols = cols = torch.tensor([0, 3]).cuda()
>>> val = torch.tensor([[1, 2], [3, 4]]).cuda()
>>>
>>> x[rows.nonzero(), cols] = val
>>> x
tensor([[ 0,  1,  2,  3],
        [ 1,  5,  6,  2],
        [ 3,  9, 10,  4]], device='cuda:0')

然而,在更一般的上下文中,你可以像np.ix_一样,在索引Tensor中插入用于广播的单例维度。

enxuqcxy

enxuqcxy2#

您可以使用torch.meshgrid

>>> x = torch.arange(12).view(3, 4)

>>> x[torch.meshgrid(torch.tensor([1,2]), torch.tensor([0,3]))]
tensor([[ 4,  7],
        [ 8, 11]])

相关问题