B, M, N = 13, 7, 19
X = np.random.randint(100, size= [B,M,N])
Y = np.random.randint(M , size= [B,1])
Z = np.random.randint(100, size= [B,N])
i = np.arange(B)
Y = Y.ravel() # reducing array to rank-1, for easy indexing
Z[i] = X[i,Y[i],:]
B, M, N = 5, 7, 3
X = torch.randint(100, size= [B,M,N])
Y = torch.randint(M , size= [B,1])
Z = torch.randint(100, size= [B,N])
i = torch.arange(B)
Y = Y.ravel()
Z = X[i,Y]
2条答案
按热度按时间enxuqcxy1#
下面的代码类似于循环中的代码。不同之处在于,不是顺序索引数组
Z
、X
和Y
,而是使用数组i
并行索引它们该代码可以进一步简化为
Pytorch等效码
ux6nzvsh2#
@Hammad提供的答案简短而完美。如果你有兴趣使用一些不太知名的Pytorch内置程序,这里有一个替代解决方案。我们将使用
torch.gather
(类似地,您可以使用numpy.take
实现此功能)。torch.gather
背后的想法是构建一个新的Tensor-基于两个相同形状的Tensor,包含索引(这里~Y
)和值(这里~X
)。执行的操作是
Z[i][j][k] = X[i][Y[i][j][k]][k]
。由于
X
的形状是(B, M, N)
,而Y
的形状是(B, 1)
,因此我们希望填充Y
内部的空白,使Y
的形状变为(B, 1, N)
。这可以通过一些轴操作来实现:
对
torch.gather
的实际调用将是:您可以通过添加
[:, 0]
将其整形为(B, N)
。此功能在棘手的场景中非常有效.