我目前正在尝试使用SymPy来生成和数值计算函数及其梯度。为了简单起见,我将使用以下函数作为示例(请记住,真实的函数要长得多):
import sympy as sp
def g(x):
return sp.cos(x) + sp.cos(x)**2 + sp.cos(x)**3
数值计算这个函数及其导数很容易:
import numpy as np
g_expr = sp.lambdify(x,g(x),modules='numpy')
dg_expr = sp.lambdify(x,sp.diff(g(x)),modules='numpy')
print g_expr(np.linspace(0,1,50))
print dg_expr(np.linspace(0,1,50))
然而,对于我的真实的函数,lambdify很慢,无论是在生成数值函数还是在计算方面。由于我的函数中的许多元素都是相似的,我想在lambdify中使用公共子表达式消除(cse)来加速这个过程。我知道SymPy有一个内置函数来执行cse,
>>> print sp.cse(g(x))
([(x0, cos(x))], [x0**3 + x0**2 + x0])
但是不知道要使用什么语法来在我的lambdify函数中使用这个结果(我仍然希望使用x作为输入参数):
>>> g_expr_fast = sp.lambdify(x,sp.cse(g(x)),modules='numpy')
>>> print g_expr_fast(np.linspace(0,1,50))
Traceback (most recent call last):
File "test3.py", line 34, in <module>
print g_expr1(nx1)
File "<string>", line 1, in <lambda>
NameError: global name 'x0' is not defined
任何关于如何在lambdify中使用cse的帮助都将不胜感激。或者,如果有更好的方法来加速我的梯度计算,我也会很感激听到这些。
如果它是相关的,我使用Python 2。7.3和SymPy 0。7.6.
3条答案
按热度按时间u1ehiz5o1#
计算速度可以提高:
前言
我假设这是“计算函数在渐近一次和使用它在不同的项目以后多次”类型的情况。因此,有一些手动复制粘贴和创建文件包括在内。* 然而 * 它可以自动创建新文件的功能,也编译步骤,但我离开了这一点。
我也遇到了类似的问题,我对不同的方法做了一些基准测试。我使用的函数很长(
len(str(expr)) = 45857
),cse(expr)
将其分解为72个子表达式。在这里复制粘贴太长了,但这里有一些步骤,可以使使用sympy创建的函数的速度提高100 - 1000倍。基准测试
A)求单浮点数
对每个参数使用 * 一个浮点值 * 来计算函数的时间。使用
timeit myfunc(*params)
。modules="numpy"
进行lambda定义:277µsstr(expr)
复制粘贴到函数定义:275µs(无差异)cse
后表达式的复制粘贴:8.2 µs(提高33倍)cse(optimizations="basic")
后表达式的复制粘贴:7.6µs(提高36倍)func_numba_f()
:0.25µs(提高1090倍)autowrap
:0.47 µs(提高589倍)B)评估np.数组1000浮点数
str(expr)
复制粘贴到函数定义:15100 µs|15.1µs/值cse
后表达式的复制粘贴:493微秒|每个值0.49µs(31倍改进)cse(optimizations="basic")
后表达式的复制粘贴:413微秒|每个值0.41µs(37倍改进)func_numba_arr()
:114µs|0.11µs/值(132倍改善)(1)
str(expr)
的复制粘贴(2)
cse
后表达式的复制粘贴print(redu[0])
.(3)
cse(optimizations="basic")
后表达式的复制粘贴optimizations="basic"
(4)使用numba编译代码
src_mymodule.py
func_numba_f()
中有 * 5个 * 浮点值输入变量和一个浮点值输出变量。f8
表示浮点数。func_numba_arr()
是处理np的版本。dtype="float64"
或dtype="float32"
的数组,具体取决于编译时使用的内容。python src_mymodule.py
编译一次代码。这将创建my_numba_module.cp38-win_amd64.pyd
或类似。它只能与文件名中的 * 相同的python版本和位数 * 一起使用。(5)使用sympy
autowrap
temp_dir
参数,它将保存所有源文件(。c,.h,.pyx)和a .pyd(win)/ .so(unix)文件,可用于稍后导入函数(假设temp_dir
在sys.path
中):gab6jxml2#
所以这可能不是最好的方法,但对于我的小例子来说,它是有效的。
下面代码的思想是对每个公共子表达式进行lambdifying,并生成一个可能包含所有参数的新函数。我添加了一些额外的sin和cos项,以添加来自先前子表达式的可能依赖项。
repl包含:
和redu包含
所以
funs
包含所有子表达式lambdified,列表xs
包含每个子表达式,这样最终可以正确地馈送glam
。xs
随着每个子表达式的增长而增长,最终可能会成为瓶颈。您可以对
sp.cse(sp.diff(g(sp.abc.x)))
的表达式执行相同的方法。2ul0zpep3#
从SymPy 1开始9,lambdify可以使用kwarg
cse=True
应用公共子表达式消除。参见: www.example.com