我正在处理PyTorch的argmax
函数,其定义为:
torch.argmax(input, dim=None, keepdim=False)
考虑一个示例
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
这里,当我使用dim=1而不是搜索列向量时,函数将搜索行向量,如下所示。
print(a) :
tensor([[-1.7739, 0.8073, 0.0472, -0.4084],
[ 0.6378, 0.6575, -1.2970, -0.0625],
[ 1.7970, -1.3463, 0.9011, -0.8704],
[ 1.5639, 0.7123, 0.0385, 1.8410]])
print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])
就我的假设而言,dim = 0表示行,dim =1表示列。
2条答案
按热度按时间ryhaxcpt1#
现在是时候 * 正确理解 *
axis
或dim
参数在PyTorch中是如何工作的了:一旦你理解了上面的图片,下面的例子就应该有意义了:
dim
('dimension' 的缩写)是NumPy中 'axis' 的 Torch 等价物。sxpgvts32#
维度的定义如上面优秀的答案所示。我已经在Torch和Numpy(分别为dim和axis)中突出了我理解维度的方式,希望这对其他人有帮助。
注意,在argmax操作期间,只有指定维的索引发生变化,并且一旦操作完成,指定维的索引范围将减少为单个索引。设TensorA有M行和N列,为简单起见,考虑求和操作。A的形状为(M,N)。如果指定dim=0,则矢量
A[0,:]
,A[1,:]
,...,A[M-1,:]
按元素求和,结果是另一个具有1行和N列的Tensor。请注意,在整个M-1中,只有第0维的索引从0开始变化。类似地,如果指定dim=1,则矢量A[:,0]
,A[:,1]
,...,A[:,N-1]
按元素求和,结果是另一个具有M行和1列的Tensor。下面是一个例子:
在上面的示例代码中,第一个求和运算指定dim=0,因此
A[0,:]
和A[1,:]
(即[1,2,3]
和[4,5,6]
)相加并得到[5, 7, 9]
。当指定dim=1时,矢量A[:,0]
、A[:,1]
和A[:2]
(即矢量[1, 4]
、[2, 5]
和A[:2]
)和[3, 6]
被逐元素地相加以得到[6, 15]
。另请注意,指定的维度折叠。再次假设A的形状为
(M, N)
。如果dim=0,则结果的形状为(1, N)
,其中维度0从M减少到1。类似地,如果dim=1,则结果的形状为(M, 1)
,其中N减少到1。另请注意,形状(1,N)和(M,1)分别由具有N和M个元素的一维Tensor表示。