Python numba不能合并数组

y1aodyip  于 2023-01-16  发布在  Python
关注(0)|答案(2)|浏览(285)

当尝试对组合两个numpy数组的函数执行JIT时

import numpy as np
import numba as nb

@nb.njit
def combine(a: nb.float64[:], b: nb.float64[:]):
    return np.array([a, b])

使用浮点参数运行函数不会抛出错误,即

>>> combine(1., 2.)
array([1., 2.])

但是当我尝试合并两个数组时,我得到

>>> combine(np.array([1., 2.]), np.array([3., 4.]))
Traceback (most recent call last):
  File "c:\Users\Lucas Gruwez\Documents\test.py", line 14, in <module>
    combine(np.array([1., 2.]), np.array([3., 4.]))
  File "C:\Users\Lucas Gruwez\AppData\Local\Programs\Python\Python310\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Lucas Gruwez\AppData\Local\Programs\Python\Python310\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function array>) found for signature:

 >>> array(list(array(float64, 1d, C))<iv=None>)

There are 4 candidate implementations:
  - Of which 4 did not match due to:
  Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba\core\overload_glue.py: Line 129.
    With argument(s): '(list(array(float64, 1d, C))<iv=None>)':
   Rejected as the implementation raised a specific error:
     TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence
  raised from C:\Users\Lucas Gruwez\AppData\Local\Programs\Python\Python310\lib\site-packages\numba\core\typing\npydecl.py:488

During: resolving callee type: Function(<built-in function array>)
During: typing of call at c:\Users\Lucas Gruwez\Documents\test.py (12)

File "src\test.py", line 12:
def combine(a: nb.float64[:], b: nb.float64[:]):
    return np.array([a, b])
6pp0gazn

6pp0gazn1#

使用np.array([a, b])连接两个numpy数组是一种快捷方式,由于numba的严格类型要求,这种快捷方式是不允许的。numba需要类似[0,1,2]的同构序列,因此失败。
键入错误:同类序列中不允许使用数组(float64,1d,C)
相反,只需使用np.array()从基本类型(如numba中的float或int)创建新数组,并使用np.stacknp.concatenate组合现有数组。

    • 例如:**
@nb.njit
def combine(a: nb.float64[:], b: nb.float64[:]):
    return np.stack((a, b))

combine(np.array([1., 2.]), np.array([3., 4.]))
array([[1., 2.],
       [3., 4.]])
    • 或:**

一个二个一个一个

    • PS**由于静态类型,一旦使用np.float[:]数组作为dtype进行编译,就不应该使用np.float64标量调用函数。同样,使用标量数字调用np.stack((1.,2.))也会失败。此时,您必须再次使用np.array([1.,2.])
rryofs0p

rryofs0p2#

下面的代码可能会有帮助,它可以在两个单浮点数或两个具有相同形状的1D数组上运行。对于这两个数组,我建议使用索引来填充结果数组,这将工作。isinstance现在是实验特性,可以用其他模块替换/编写,例如hassattrnp.isscaler等。我没有检查他们是否可以与numba njit一起使用。这只是一个解决方案,以显示我们如何面对这样的问题。

@nb.njit  # (["float64[::1](float64, float64)", "float64[:, ::1](float64[::1], float64[::1])"])
def combine(a, b):
    if isinstance(a, float):
        return np.array([a, b])
    else:
        arr = np.empty((2, *a.shape))
        arr[0, :] = a
        arr[1, :] = b
        return arr

相关问题