numpy 在numba中连接python元组

bvjxkvbb  于 2023-03-02  发布在  Python
关注(0)|答案(2)|浏览(122)

我想用从元组中取出的数字填充一个零数组,就这么简单。
通常情况下,即使元组长度不同(这是这里的重点),这也不是问题。但似乎无法编译,我无法找到解决方案。

from numba import jit    

def cant_jit(ls):

    # Array total lenth
    tl = 6
    # Type
    typ = np.int64

    # Array to modify and return
    start = np.zeros((len(ls), tl), dtype=typ)

    for i in range(len(ls)):

        a = np.array((ls[i]), dtype=typ)
        z = np.zeros((tl - len(ls[i]),), dtype=typ)
        c = np.concatenate((a, z))
        start[i] = c

    return start

# Uneven tuples would be no problem in vanilla
cant_jit(((2, 4), (6, 8, 4)))

jt = jit(cant_jit)    
# working fine
jt(((2, 4), (6, 8)))
# non working
jt(((2, 4), (6, 8, 4)))

在误差范围内。

  • getitem(元组(单元组(int 64 x 3),单元组(int 64 x 2)),int 64)有22个候选实现:- 其中22个不匹配,原因是:函数“getitem”重载:文件::行N/A。带参数:'(元组(单元组(整数64 x 3),单元组(整数64 x 2)),整数64)':无匹配项。*

我在这里尝试了一些东西,但没有成功。有没有人知道绕过这个的方法,这样函数就可以编译并仍然做它的事情?

laik7k3q

laik7k3q1#

据我所知这是不可能的,numba文档告诉我们,除非你使用forceobj=True,否则长度不等的嵌套元组是不法律的的。你甚至不能解压缩 *args,这是令人沮丧的。你总是会收到警告/错误:
只需将该参数添加到jit()中,如下所示:

from numba import jit    
import numpy as np

def cant_jit(ls):

    # Array total lenth
    tl = 6
    # Type
    typ = np.int64

    # Array to modify and return
    start = np.zeros((len(ls), tl), dtype=typ)

    for i in range(len(ls)):

        a = np.array((ls[i]), dtype=typ)
        z = np.zeros((tl - len(ls[i]),), dtype=typ)
        c = np.concatenate((a, z))
        start[i] = c

    return start

# Uneven tuples would be no problem in vanilla
cant_jit(((2, 4), (6, 8, 4)))

jt = jit(cant_jit, forceobj=True)    
# working fine
jt(((2, 4), (6, 8)))
# now working
jt(((2, 4), (6, 8, 4)))

这是可行的,但有点无意义,你也可以使用core python。

ovfsdjhp

ovfsdjhp2#

我想知道numba是否会更喜欢这个非 numpy 版本:

def foo1(ls):
    res = []
    for row in ls:
        res.append(row+((0,)*(6-len(ls))))
    return res

相关问题