numpy 对之前的元素求和并将元素替换为总和

cmssoen2  于 2023-01-20  发布在  其他
关注(0)|答案(3)|浏览(120)

我有numpy数组

arr = np.array([[0, 0, 2, 5, 0, 0, 1, 8, 0, 3, 0],
                [1, 2, 0, 0, 0, 0, 5, 7, 0, 0, 0],
                [8, 5, 3, 9, 0, 1, 0, 0, 0, 0, 1]])

我需要这样的结果数组:

[[0, 0, 0, 0, 7, 0, 0, 0, 9, 0, 3]
 [0, 0, 3, 0, 0, 0, 0, 0, 12, 0, 0]
 [0, 0, 0, 0, 25, 0, 1, 0, 0, 0, 0]]

发生什么事了?
我们沿着行前进,如果行中的元素是0,那么我们前进到下一个元素,如果不是0,那么我们对元素求和直到满足0,一旦满足0,那么我们用结果和替换它(也用0替换初始非零数字
我已经知道如何使用循环来实现这一点,但是对于大量的行,它在时间上效果不好,所以我需要numpy方法中的时间效率解决方案

更新

这是一个尝试

zero_el = arr == 0
>>> np.where(zero_el, arr.cumsum(axis=1), 0)

[[ 0  0  0  0  7  7  0  0 16  0 19]
 [ 0  0  3  3  3  3  0  0 15 15 15]
 [ 0  0  0  0 25  0 26 26 26 26  0]]
e4yzc0pl

e4yzc0pl1#

首先,我们要找到数组中零与非零相邻的位置。

rr, cc = np.where((arr[:, 1:] == 0) & (arr[:, :-1] != 0))

现在,我们可以使用np.add.reduceat来添加元素。不幸的是,reduceat需要一个一维索引的列表,所以我们将不得不稍微使用一些形状。计算扁平数组中rr, cc的等效索引很容易:

reduce_indices = rr * arr.shape[1] + cc + 1
# array([ 4,  8, 10, 13, 19, 26, 28])

我们希望从 * 每一行的开头 * 开始进行约简,因此我们将创建一个row_starts来混合上面计算的索引:

row_starts = np.arange(arr.shape[0]) * arr.shape[1]
# array([ 0, 11, 22])

reduce_indices = np.hstack((row_starts, reduce_indices))
reduce_indices.sort()
# array([ 0,  4,  8, 10, 11, 13, 19, 22, 26, 28])

现在,在展平的输入数组上调用np.add.reduceat,减少到reduce_indices

totals = np.add.reduceat(arr.flatten(), reduce_indices)
# array([ 7,  9,  3,  0,  3, 12,  0, 25,  1,  1])

现在我们有了总数,我们需要将它们赋给一个零数组,注意totals的第0个元素需要转到reduce_indices的第1个索引,而totals的最后一个元素要被丢弃:

result_f = np.zeros((arr.size,))
result_f[reduce_indices[1:]] = totals[:-1]
result = result_f.reshape(arr.shape)

这给出了预期结果:

array([[ 0.,  0.,  0.,  0.,  7.,  0.,  0.,  0.,  9.,  0.,  3.],
       [ 0.,  0.,  3.,  0.,  0.,  0.,  0.,  0., 12.,  0.,  0.],
       [ 0.,  0.,  0.,  0., 25.,  0.,  1.,  0.,  0.,  0.,  0.]])
6jygbczu

6jygbczu2#

我们可以使用2个for循环来求解,在每一行我们定义current_sum,如果number为零,我们将current_sum赋值给number并重置current_sum;如果number不为零,则将0赋给number,并且递增current_sum。
编辑:首先对不起,我没有意识到你想要一个高效的解决方案。我们可以使用numba来加速for循环。它真的很简单,功能也很强大。下面是代码:

import numpy as np
import numba
arr = np.array([[0, 0, 2, 5, 0, 0, 1, 8, 0, 3, 0],
                [1, 2, 0, 0, 0, 0, 5, 7, 0, 0, 0],
                [8, 5, 3, 9, 0, 1, 0, 0, 0, 0, 1]])

@numba.jit(nopython=True)
def mySum(array):
    for i in range(array.shape[0]):
        current_sum = 0
        for j in range(array.shape[1]):
            if array[i,j] == 0:
                array[i,j] = current_sum
                current_sum = 0
            else:
                current_sum += array[i,j]
                array[i,j] = 0
    return array

print(mySum(arr))

函数在第一次运行时很慢,因为它理解输入和函数并创建机器码,但在那之后它真的很快。我希望它对你的情况来说足够快。

acruukt9

acruukt93#

可能比in循环长......但让我用单数组来演示:

a = np.array([0, 0, 2, 5, 0, 0, 1, 8, 0, 3, 0])
zero_index = np.where(a == 0)[0]
# Split zeros, sum each slice, drop the last one
replace_arr = np.array(list(map(sum, np.split(a, zero_index))))[:-1]
output = np.zeros(11)
# Put sum data into zeros array
np.put_along_axis(output, zero_index, replace_arr, axis=0)
output

相关问题