numpy 快速替代条件设置数组元素

ds97pgxw  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(160)

我有两个给定的3d数组x_disty_dist,每个数组的形状都是(36,50,50)x_disty_dist中的元素属于np.float32类型,可以是正数、负数或零。我需要创建一个新数组res_array,在所有索引处将其值设置为(1-y_dist)*(x_dist),除了条件((x_dist <= 0) | ((x_dist > 0) & (y_dist > (1 + x_dist))))True的索引。我目前的实现如下。

res_array  = (1-y_dist)*(x_dist)
res_array[((x_dist <= 0) | ((x_dist > 0) & (y_dist > (1 + x_dist))))] = 0.0

然而,我需要运行包含此代码片段的代码数千次,我相信有一个更聪明,更快的方法来做同样的事情。你能帮我得到一个性能更好的代码或一行程序吗?

t9eec4r0

t9eec4r01#

Numba JIT可以有效地实现这一点。下面是一个实现:

@njit
def fastImpl(x_dist, y_dist):
    res_array = np.empty(x_dist.shape)
    for z in range(res_array.shape[0]):
        for y in range(res_array.shape[1]):
            for x in range(res_array.shape[2]):
                xDist = x_dist[z,y,x]
                yDist = y_dist[z,y,x]
                if xDist > 0.0 and yDist <= (1.0 + xDist):
                    res_array[z,y,x] = (1.0 - yDist) * xDist
    return res_array

以下是随机输入矩阵的性能结果:

Original implementation: 494 µs ± 6.23 µs per loop (mean ± std. dev. of 7 runs, 500 loops each)
New implementation: 37.8 µs ± 236 ns per loop (mean ± std. dev. of 7 runs, 500 loops each)

新的实现大约快13倍(不考虑编译/预热时间)。

相关问题