numpy 如果使用np.array作为numba中的输入,如何设置签名

kcwpcxri  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(169)

我想给我的numba函数设置一个签名来规范它的类型。但是,我做了之后,发现这个函数不起作用。我该怎么设置签名呢?

mat = np.random.normal(0, 1, size=(1000000, 10))

@nb.jit(nopython=True)
def f(mat):
    max_min = 0
    for i in range(mat.shape[0]):
        max_min += mat[i].max() - mat[i].min()
    return max_min / mat.shape[0]

start = time.time()
print(f(mat))
end = time.time()
end - start

如果我这样做,它工作得很好。
但如果我这样做:

@nb.jit(nb.float64(nb.float64), nopython=True)
def f(mat):
    max_min = 0
    for i in range(mat.shape[0]):
        max_min += mat[i].max() - mat[i].min()
    return max_min / mat.shape[0]

start = time.time()
print(f(mat))
end = time.time()
end - start

它会报告一个错误:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'shape' of type float64

我该怎么解决这个问题?

ni65a41a

ni65a41a1#

您应该执行下列动作:

@nb.jit(nb.float64(nb.types.Array(nb.float64, 2, "C")), nopython=True)
def f(mat):
    max_min = 0
    for i in range(mat.shape[0]):
        max_min += mat[i].max() - mat[i].min()
    return max_min / mat.shape[0]

也就是说,nb.types.Array(nb.float64, 2, "C")告诉它输入应该是float64值的2维数组,具有C风格的存储器布局。

相关问题