NumPy Sum(带轴)如何工作?

rmbxnbpk  于 2024-01-08  发布在  其他
关注(0)|答案(3)|浏览(146)

出于好奇心,我决定自己学习NumPy的工作原理。
看起来最简单的函数是最难翻译成代码的(我是通过代码理解的),很容易对每种情况下的每个轴进行硬编码,但我想找到一个动态算法,可以在任何n维轴上求和。官方网站上的文档没有帮助(它只显示结果而不是过程),很难在Python/C代码中导航。

**注意:**我确实发现,当一个数组求和时,指定的轴被“删除”,即形状为(4,3,2)的数组与轴1的总和产生形状为(4,2)的数组的答案。

7gcisfzg

7gcisfzg1#

设置

考虑numpy数组a

a = np.arange(30).reshape(2, 3, 5)
print(a)

[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]]

 [[15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]]

字符串

维度在哪里?

尺寸和位置通过以下方式突出显示

p  p  p  p  p
            o  o  o  o  o
            s  s  s  s  s

     dim 2  0  1  2  3  4

            |  |  |  |  |
  dim 0     ↓  ↓  ↓  ↓  ↓
  ----> [[[ 0  1  2  3  4]   <---- dim 1, pos 0
  pos 0   [ 5  6  7  8  9]   <---- dim 1, pos 1
          [10 11 12 13 14]]  <---- dim 1, pos 2
  dim 0
  ---->  [[15 16 17 18 19]   <---- dim 1, pos 0
  pos 1   [20 21 22 23 24]   <---- dim 1, pos 1
          [25 26 27 28 29]]] <---- dim 1, pos 2
            ↑  ↑  ↑  ↑  ↑
            |  |  |  |  |

     dim 2  p  p  p  p  p
            o  o  o  o  o
            s  s  s  s  s

            0  1  2  3  4

维度示例:

这一点通过几个例子变得更加清楚

a[0, :, :] # dim 0, pos 0

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]

x

a[:, 1, :] # dim 1, pos 1

[[ 5  6  7  8  9]
 [20 21 22 23 24]]
a[:, :, 3] # dim 2, pos 3

[[ 3  8 13]
 [18 23 28]]

的一种或多种

sum

  • sumaxis的解释 *

a.sum(0)是沿沿着dim 0的所有切片的总和

a.sum(0)

[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]


相同

a[0, :, :] + \
a[1, :, :]

[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]


a.sum(1)是沿沿着dim 1的所有切片的总和

a.sum(1)

[[15 18 21 24 27]
 [60 63 66 69 72]]


相同

a[:, 0, :] + \
a[:, 1, :] + \
a[:, 2, :]

[[15 18 21 24 27]
 [60 63 66 69 72]]


a.sum(2)是沿沿着dim 2的所有切片的总和

a.sum(2)

[[ 10  35  60]
 [ 85 110 135]]


相同

a[:, :, 0] + \
a[:, :, 1] + \
a[:, :, 2] + \
a[:, :, 3] + \
a[:, :, 4]

[[ 10  35  60]
 [ 85 110 135]]


默认轴为-1
这意味着所有的轴。或总和所有的数字。

a.sum()

435

lc8prwob

lc8prwob2#

我使用一个嵌套循环操作来解释它。

import numpy as np

n = np.array(
[[[1, 2, 3],
 [4, 5, 6],
 [7, 8, 9]],

 [[2, 4, 6],
 [8, 10, 12],
 [14, 16, 18]],

 [[1, 3, 5],
 [7, 9, 11],
 [13, 15, 17]]])

print(n)

print("============ sum axis=None=============")

sum = 0
for i in range(3):
  for j in range(3): 
    for k in range(3):
      sum += n[k][i][j]
print(sum) # 216

print('------------------')
print(np.sum(n))  # 216
print("============ sum axis=0 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[axis][i][j]
    print(sum,end=' ')
  print()

print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[1][0][0] + n[2][0][0]))
print("sum[1][1] = %d" % (n[0][1][1] + n[1][1][1] + n[2][1][1]))
print("sum[2][2] = %d" % (n[0][2][2] + n[1][2][2] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=0)) 
print("============ sum axis=1 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[i][axis][j]
    print(sum,end=' ')
  print()
print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[0][1][0] + n[0][2][0]))
print("sum[1][1] = %d" % (n[1][0][1] + n[1][1][1] + n[1][2][1]))
print("sum[2][2] = %d" % (n[2][0][2] + n[2][1][2] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=1))  
print("============ sum axis=2 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[i][j][axis]
    print(sum,end=' ')
  print()
print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[0][0][1] + n[0][0][2]))
print("sum[1][1] = %d" % (n[1][1][0] + n[1][1][1] + n[1][1][2]))
print("sum[2][2] = %d" % (n[2][2][0] + n[2][2][1] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=2))
print("============ sum axis=(0,1)) =============") 
for i in range(3):
  sum = 0
  for axis1 in range(3):   
    for axis2 in range(3):
      sum += n[axis1][axis2][i]
  print(sum,end=' ')

print()
print('------------------')
print("sum[1] = %d" % (n[0][0][1] + n[0][1][1] + n[0][2][1] +
              n[1][0][1] + n[1][1][1] + n[1][2][1] +
              n[2][0][1] + n[2][1][1] + n[2][2][1] ))
print('------------------')
print(np.sum(n, axis=(0,1)))

字符串
结果:

[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]

 [[ 2  4  6]
  [ 8 10 12]
  [14 16 18]]

 [[ 1  3  5]
  [ 7  9 11]
  [13 15 17]]]
============ sum axis=None=============
216
------------------
216
============ sum axis=0 =============
4 9 14 
19 24 29 
34 39 44 
------------------
sum[0][0] = 4
sum[1][1] = 24
sum[2][2] = 44
------------------
[[ 4  9 14]
 [19 24 29]
 [34 39 44]]
============ sum axis=1 =============
12 15 18 
24 30 36 
21 27 33 
------------------
sum[0][0] = 12
sum[1][1] = 30
sum[2][2] = 33
------------------
[[12 15 18]
 [24 30 36]
 [21 27 33]]
============ sum axis=2 =============
6 15 24 
12 30 48 
9 27 45 
------------------
sum[0][0] = 6
sum[1][1] = 30
sum[2][2] = 45
------------------
[[ 6 15 24]
 [12 30 48]
 [ 9 27 45]]
============ sum axis=(0,1)) =============
57 72 87 
------------------
sum[1] = 72
------------------
[57 72 87]

fdx2calv

fdx2calv3#

假设我们的数组有2行3列

import numpy as np
a = np.array([[1,2,3],[3,4,6]])

print(a.shape)
#prints:(2, 3) This array has 2 rows and 3 columns

字符串
以下是三种不同的可能性:

print(np.sum(a)) #computes sum of all the elements; prints: 19
print(np.sum(a, axis= 0)) #computes sum of all the column; prints: [4 6 9]
print(np.sum(a, axis= 1)) #computes sum of all the rows; prints: [6 13]

相关问题