numpy 批量Tensor切片,切片B x N x M,其中B x 1

x9ybnkn6  于 2023-10-19  发布在  其他
关注(0)|答案(2)|浏览(122)

我有一个B x M x NTensor,X,还有一个B x 1Tensor,Y,它对应于TensorX在维度=1处的索引,我想保留。这个切片的简写是什么,这样我就可以避免循环了?
基本上,我想这样做:

Z = torch.zeros(B,N)

for i in range(B):
    Z[i] = X[i][Y[i]]
enxuqcxy

enxuqcxy1#

下面的代码类似于循环中的代码。不同之处在于,不是顺序索引数组ZXY,而是使用数组i并行索引它们

B, M, N = 13, 7, 19

X = np.random.randint(100, size= [B,M,N])
Y = np.random.randint(M  , size= [B,1])
Z = np.random.randint(100, size= [B,N])

i = np.arange(B)
Y = Y.ravel()    # reducing array to rank-1, for easy indexing

Z[i] = X[i,Y[i],:]

该代码可以进一步简化为

>> Z[i] = X[i,Y[i],:]
>> Z[i] = X[i,Y[i]]
>> Z[i] = X[i,Y]
>> Z    = X[i,Y]

Pytorch等效码

B, M, N = 5, 7, 3

X = torch.randint(100, size= [B,M,N])
Y = torch.randint(M  , size= [B,1])
Z = torch.randint(100, size= [B,N])

i = torch.arange(B)
Y = Y.ravel()

Z = X[i,Y]
ux6nzvsh

ux6nzvsh2#

@Hammad提供的答案简短而完美。如果你有兴趣使用一些不太知名的Pytorch内置程序,这里有一个替代解决方案。我们将使用torch.gather(类似地,您可以使用numpy.take实现此功能)。
torch.gather背后的想法是构建一个新的Tensor-基于两个相同形状的Tensor,包含索引(这里~ Y)和值(这里~ X)。
执行的操作是Z[i][j][k] = X[i][Y[i][j][k]][k]
由于X的形状是(B, M, N),而Y的形状是(B, 1),因此我们希望填充Y内部的空白,使Y的形状变为(B, 1, N)
这可以通过一些轴操作来实现:

>>> Y.expand(-1, N)[:, None] # expand to dim=1 to N and unsqueeze dim=1

torch.gather的实际调用将是:

>>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])

您可以通过添加[:, 0]将其整形为(B, N)
此功能在棘手的场景中非常有效.

相关问题