pytorch 为什么dim=1返回torch.argmax中的行索引?

nlejzf6q  于 2023-01-13  发布在  其他
关注(0)|答案(2)|浏览(199)

我正在处理PyTorch的argmax函数,其定义为:

torch.argmax(input, dim=None, keepdim=False)

考虑一个示例

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))

这里,当我使用dim=1而不是搜索列向量时,函数将搜索行向量,如下所示。

print(a) :   
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],  
        [ 0.6378,  0.6575, -1.2970, -0.0625],  
        [ 1.7970, -1.3463,  0.9011, -0.8704],  
        [ 1.5639,  0.7123,  0.0385,  1.8410]])  

print(torch.argmax(a, dim=1))  
tensor([1, 1, 0, 3])

就我的假设而言,dim = 0表示行,dim =1表示列。

ryhaxcpt

ryhaxcpt1#

现在是时候 * 正确理解 * axisdim参数在PyTorch中是如何工作的了:

一旦你理解了上面的图片,下面的例子就应该有意义了:

|
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
    • 注意**:dim'dimension' 的缩写)是NumPy中 'axis' 的 Torch 等价物。
sxpgvts3

sxpgvts32#

维度的定义如上面优秀的答案所示。我已经在Torch和Numpy(分别为dim和axis)中突出了我理解维度的方式,希望这对其他人有帮助。
注意,在argmax操作期间,只有指定维的索引发生变化,并且一旦操作完成,指定维的索引范围将减少为单个索引。设TensorA有M行和N列,为简单起见,考虑求和操作。A的形状为(M,N)。如果指定dim=0,则矢量A[0,:]A[1,:],...,A[M-1,:]按元素求和,结果是另一个具有1行和N列的Tensor。请注意,在整个M-1中,只有第0维的索引从0开始变化。类似地,如果指定dim=1,则矢量A[:,0]A[:,1],...,A[:,N-1]按元素求和,结果是另一个具有M行和1列的Tensor。
下面是一个例子:

>>> A = torch.tensor([[1,2,3], [4,5,6]])
>>> A
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> S0 = torch.sum(A, dim = 0)
>>> S0
tensor([5, 7, 9])
>>> S1 = torch.sum(A, dim = 1)
>>> S1
tensor([ 6, 15])

在上面的示例代码中,第一个求和运算指定dim=0,因此A[0,:]A[1,:](即[1,2,3][4,5,6])相加并得到[5, 7, 9]。当指定dim=1时,矢量A[:,0]A[:,1]A[:2](即矢量[1, 4][2, 5]A[:2])和[3, 6]被逐元素地相加以得到[6, 15]
另请注意,指定的维度折叠。再次假设A的形状为(M, N)。如果dim=0,则结果的形状为(1, N),其中维度0从M减少到1。类似地,如果dim=1,则结果的形状为(M, 1),其中N减少到1。另请注意,形状(1,N)和(M,1)分别由具有N和M个元素的一维Tensor表示。

相关问题