python 从numba创建一个njit修饰的numpy数组

jtoj6r0c  于 2023-05-21  发布在  Python
关注(0)|答案(2)|浏览(175)

代码在这里:

import numba as nb
import numpy as np

@nb.njit
def func(size):
    ary = np.array([np.arange(size),np.arange(size)+1,np.arange(size)-1]).T
    X = np.array([ary[1:,0] - ary[:-1,2],
                  ary[1:,1] - ary[:-1,2],
                  ary[1:,0] - ary[1:,1]
                  ])
    return X

Z = func(10**9)

当我运行代码时,它给我一个错误消息,我真的不明白这里发生了什么。njit修饰的函数不支持在函数内部创建新数组吗?错误消息如下:

TypingError: Invalid use of Function(<built-in function array>) with argument(s) of type(s): (list(array(int64, 1d, C)))
 * parameterized
In definition 0:
    TypingError: array(int64, 1d, C) not allowed in a homogeneous sequence
    raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\npydecl.py:459
In definition 1:
    TypingError: array(int64, 1d, C) not allowed in a homogeneous sequence
    raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\npydecl.py:459
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function array>)
[2] During: typing of call at C:/Users/User/Desktop/all python file/3.2.4/nb_datatype.py (65)

编辑:我忘了在编辑前转置数组,它应该是一个10^9乘3的数组。

wbrvyc0a

wbrvyc0a1#

numba.njit不支持通过NumPy数组的列表,甚至列表的列表示例化NumPy数组。相反,使用np.empty,然后通过NumPy索引分配值:

@nb.njit
def func(size):
    row_count = 3
    ary = np.empty((row_count, size))
    ranger = np.arange(size)
    ary[0] = ranger
    ary[1] = ranger + 1
    ary[2] = ranger - 1

    X = np.empty((row_count, row_count - 1))
    X[0] = ary[1:,0] - ary[:-1,2]
    X[1] = ary[1:,1] - ary[:-1,2]
    X[2] = ary[1:,0] - ary[1:,1]

    return X

Z = func(10**2)

print(Z)

array([[-1., -4.],
       [ 0., -3.],
       [-1., -1.]])
5lhxktic

5lhxktic2#

我听从了@jpp的建议,切换到了np.float64而不是float,但我也不得不将np.empty([1,2,3], np.float64)切换到np.empty((1,2,3), np.float64)

相关问题