numpy 将Numba njit与np.array一起使用

svmlkihl  于 2023-01-17  发布在  其他
关注(0)|答案(2)|浏览(188)

我有两个Python函数,我试图用njit来加速它们,因为它们影响了我的程序的性能。下面是一个MWE,当我们将@njit(fastmath=True)装饰器添加到f时,它重现了下面的错误。否则它会工作。我相信这个错误是因为数组A有dtype对象。除了g之外,我还可以使用Numba修饰f吗?如果不能,将gMap到A的元素的最快方法是什么?大致来说,A的长度= B~5000。这些函数被调用了大约500 MM次,尽管是hpc工作流的一部分。

@njit(fastmath=True)
def g(a, B):
    # some function of a and B
    return 19.12 / (len(a) + len(B))

def f(A, B):
    total = 0.0
    for i in range(len(B)):
        total += g(A[i], B)
    return total

A = [[2, 5], [4, 5, 6, 7], [0, 8], [6, 7], [1, 8], [0, 1], [1, 3], [1, 3], [2, 4]]
B = [1, 1, 1, 1, 1, 1, 1, 1, 1]

A = np.array([np.array(a, dtype=int) for a in A], dtype=object)
B = np.array(B, dtype=int)
    
f(A, B)

键入错误:nopython模式管道失败(步骤:nopython前端)非精确类型数组(pyobject,1d,C)期间:在/var/文件夹/9x/hnb8fg0x2p1c9p69p_70jnn40000gq/T/ipykernel_59724/www.example.com中键入参数(8)1681580915.py (8)
文件"../../../../var/folders/9x/hnb8fg0x2p1c9p69p_70jnn40000gq/T/ipykernel_59724/www.example.com",第8行:缺少源代码,REPL/exec是否正在使用?1681580915.py", line 8: <source missing, REPL/exec in use?>

mwyxok5s

mwyxok5s1#

我可以用Numba修饰f和g吗?
不可以。不能在@njit修饰的Numba函数中使用CPython对象。Numba之所以快主要是因为原生类型(支持生成快速编译代码,而不是解释动态代码)。
如果不是,将gMap到A的元素的最快方法是什么?
交错数组的效率很低。一般来说,快速解决这个问题的方法是使用2个数组:一个包含所有值,另一个包含每行的值的起始-结束范围(有点像稀疏矩阵,但使用范围)。存储每个段的长度也可以(而且更紧凑),尽管起始-结束范围需要累积和,这有时会使事情更复杂(例如,依赖性妨碍直接并行化)。

hpxqektj

hpxqektj2#

要创建@Jérôme Richard提到的非交错数组,我们可以这样做:

# Imports.
import numpy as np
from numba import njit, prange

# Data.
A_list = [[2, 5], [4, 5, 6, 7], [0, 8], [6, 7], [1, 8], [0, 1], [1, 3], [1, 3], [2, 4]]
B_list = [1, 1, 1, 1, 1, 1, 1, 1, 1]

A_lenghts = np.array([len(sublist) for sublist in A_list])
dim1 = len(A_list)
dim2 = A_lenghts.max()
A = np.empty(shape=(dim1, dim2), dtype=int) # 9x4.
for i, (sublist, length) in enumerate(zip(A_list, A_lenghts)):
    A[i, :length][:] = sublist

B = np.array(B_list, dtype=int)
assert A.shape[0] == B.size

数组A看起来如下所示:

array([[      2,       5, xxxxxx, xxxxxx],
       [      4,       5,      6,      7],
       [      0,       8, xxxxxx, xxxxxx],
       [      6,       7, xxxxxx, xxxxxx],
       [      1,       8, xxxxxx, xxxxxx],
       [      0,       1, xxxxxx, xxxxxx],
       [      1,       3, xxxxxx, xxxxxx],
       [      1,       3, xxxxxx, xxxxxx],
       [      2,       4, xxxxxx, xxxxxx]])

xxxxxx是我们得到的随机值,因为我们用np.empty创建了数组,这就是为什么要保留A_lengths,以确定每行数据在哪里变得不相关。
回到计算,我刚刚添加了对f的优化:@njit(parallel=True)装饰器和numba.prange,而不是Python的range

# Calculations.
@njit(fastmath=True)
def g(a, b):
    return 19.12 / (len(a) + len(b))

@njit(parallel=True)
def f(A, B):
    total = 0.0
    for i in prange(len(B)):
        total += g(A[i], B)
    return total

# Test.
print(f(A, B))

相关问题