numpy 更快的麻木:连续数字替换

6ss1mwsb  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(143)

有人能帮我快点吗?

import numpy as np

# an array to split
a = np.array([0,0,1,0,1,1,1,0,1,1,0,0,0,1])

# idx where the number changes
idx = np.where(np.roll(a,1)!=a)[0][1:]

# split of array into groups
aout = np.split(a,idx)

# sum of each group
sumseg = [aa.sum() for aa in aout]

#fill criteria
idx2 = np.where( (np.array(sumseg)>0) & (np.array(sumseg)<2) )

#fill targets
[aout[ai].fill(0) for ai in idx2[0]]

# a is now updated? didn't follow how a gets updated
# return a

我注意到a通过这个过程得到更新,但不明白这些对象如何保持链接,认为分裂等...
如果这很重要,或者有帮助的话,a实际上是一个二维数组,我正在对每一行/列循环执行此操作。

bwitn5fc

bwitn5fc1#

更好的解决方案1D:

我们可以使用卷积:

aout = ((np.convolve(a,[1,1,1],mode='same')>1)&(a>0)).astype(a.dtype)
# aout = array([0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0])

更好的二维解决方案:

from scipy.signal import convolve2d

a = np.array([[1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
              [1, 0, 1, 0, 1, 0, 0, 0, 1, 1]])

aout = ((convolve2d(a,np.ones((1,3)),mode='same')>1)&(a>0)).astype(a.dtype)

#aout = array([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
#              [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]])

为什么a发生了变化?

为了理解为什么a在您的过程中被更新,您需要理解副本和视图之间的区别。
从文档中:

检视

可以通过改变某些元数据(如stride和dtype)而不改变数据缓冲区来访问数组。这创建了一种查看数据的新方式,这些新数组称为视图。数据缓冲区保持不变,因此对视图所做的任何更改都会反映在原始副本中。可以通过ndarray.view方法强制创建视图。

复制

通过复制数据缓冲区和元数据来创建新数组时,该数组称为副本。对副本所做的更改不会反映在原始数组上。创建副本的速度较慢且占用内存,但有时是必要的。可以使用ndarray.copy强制创建副本。
或者np.split()返回一个视图,而不是a的副本,因此aout仍然指向与a相同的数据缓冲区,如果更改aout,则会更改a

性能指标评测

import numpy as np

a = np.random.randint(0,2,(1000000,))

def continuous_split(a):
    idx = np.where(np.roll(a,1)!=a)[0][1:]
    aout = np.split(a,idx)
    sumseg = [aa.sum() for aa in aout]
    idx2 = np.where( (np.array(sumseg)>0) & (np.array(sumseg)<2) )
    [aout[ai].fill(0) for ai in idx2[0]]
    return aout
    
def continuous_conv(a):
    return ((np.convolve(a,[1,1,1],mode='same')>1)&(a>0)).astype(a.dtype)

%timeit continuous_split(a)
%timeit continuous_conv(a)
  • np.split()解决方案:*
668 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • np.卷积()解:*
7.63 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

相关问题