python—在项目中将numba函数拆分为单独的模块进行打包

2j4z5cfb  于 2021-07-14  发布在  Java
关注(0)|答案(1)|浏览(299)

我的项目中有几个模块,每个模块包含几个 numba 功能。我知道在第一次导入时,函数会被编译。我现在才注意到的是,即使我只从一个模块中导入一个函数,似乎所有的函数都会被编译,因为导入所花费的时间是相同的。
我想实现一种更细粒度的方法来解决这个问题,因为对于某些应用程序来说,实际上只需要一个函数,因此编译所有函数都是浪费时间。
为此,我将函数分解为如下单独的模块:

Project/
|--src/
|  |-- __init__.py
|  |-- fun1.py
|  |-- fun2.py
|  |-- fun3.py 
|  |-- fun4.py
|  |-- ...

__init__.py 包括

from .fun1 import fun1
from .fun2 import fun2
...

所以它们可以像 from src import fun1 .
这似乎工作正常,但在导入级别有一点重复,例如每个函数都需要 from numba import jit ,他们中的一些人需要 from numpy import zeros 等等。
所以我的问题是,这是一个好方法,还是有一个更好的方法来 Package 许多 numba 功能。

编辑:

将所有导入语句放入 __init__.py 显然,这意味着所有的函数在导入一个函数后都会被编译,所以根本没有任何好处。
我仍然可以导入如下函数

from src.fun1 import fun1

这似乎管用。但是语法有点笨拙。

y4ekin9u

y4ekin9u1#

有趣的问题-你本质上是在问如何延迟函数的定义,直到它被显式导入。我认为最好的方法就是像你说的,用 from src.fun1 import fun1 每个文件有一个函数。
我认为在同一个文件中有多个函数时实现这一点可能非常棘手,因此我将问题放宽到“如何延迟函数的定义,直到显式调用(而不是导入)”上。

琐碎的解决方案

一个简单的方法就是将你的函数 Package 在一个虚拟的外部函数中。
这并不是我们想要的,因为后续调用 fun1 将导致内部功能和 numba.jit 正在重新创建的装饰器,需要重新编译。


# main.py

# This lets us see when numba is compiling.

# See https://numba.pydata.org/numba-doc/dev/reference/envvars.html

import os
os.environ["NUMBA_DEBUG_FRONTEND"] = "1"

import fun1
print("note no numba debug output yet for fun1")
print("fun1 result is", fun1.fun1(1, 2))
print("fun1 result is", fun1.fun1(2, 1))
print("note the function was compiled twice :(")

# fun1.py

import numba

# Naively wrap fun1 in another function so it's only declared

# when the outer function is called.

def fun1(*args,**kwargs):
    @numba.jit("float32(float32, float32)", cache=False)  # No cache, for debugging
    def __fun1(a, b):
        return a + b
    return __fun1(*args,**kwargs)

使用装饰器的更高级解决方案

简单的解决方案是将您的函数 Package 到另一个函数中。。。。闻起来很像装饰师。。。。
我创建了一个decorator(外部decorator),它将另一个decorator(内部decorator)作为输入。外装饰应用内装饰( numba.jit 在本例中)仅在第一次调用函数时。然后在后续调用中重新使用内部修饰函数。


# main.py

# This lets us see when numba is compiling.

# See https://numba.pydata.org/numba-doc/dev/reference/envvars.html

import os
os.environ["NUMBA_DEBUG_FRONTEND"] = "1"

import fun2
print("note no numba debug output yet for fun2")
print("fun2 result is", fun2.fun2(3, 4))
print("fun2 result is", fun2.fun2(5, 6))
print("note the function was compiled only once :)")

# fun2.py

import numba
from functools import wraps

def delayed(internal_decorator):
    def _delayed(f):
        inner_decorated = None
        @wraps(f)
        def wrapper(*args,**kwds):
            nonlocal inner_decorated
            if inner_decorated is None:
                inner_decorated = internal_decorator(f)
            return inner_decorated(*args,**kwds)
        return wrapper
    return _delayed

@delayed(numba.jit("float32(float32, float32)", cache=False))
def fun2(a, b):
    return a * b

相关问题