使用sympy和scipy会导致不良行为,我的错误在哪里?

jm2pwxwz  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(154)

我正在写一个计算导数和积分的代码。为此,我分别使用了sympyscipy.integrate。但是,这种使用导致了一个奇怪的错误行为。下面是一个最小的代码,它重现了这种行为:

from scipy.integrate import quadrature
import sympy
import math

def func(z):
    return math.e**(2*z**2)+3*z

z = symbols('z')
f = func(z)
dlogf_dz = sympy.Lambda(z, sympy.log(f).diff(z))
print(dlogf_dz)
print(dlogf_dz(10))

def integral(z):
    # WHY DO I HAVE TO USE [0] BELOW ?!?!
    d_ARF=dlogf_dz(z[0])
    return d_ARF

result = quadrature(integral, 0, 3)
print(result)

>>> Lambda(z, (4.0*2.71828182845905**(2*z**2)*z + 3)/(2.71828182845905**(2*z**2) + 3*z))
>>> 40.0000000000000
>>> (8.97457203290041, 0.00103422711649337)

前两个print语句提供了数学上正确的结果,但积分结果result是错误的--它应该正好是18,而不是~8.975。我知道这一点,因为我用Wolfram Alpha进行了双重检查,因为dlogf_dz在数学上非常简单,您可以自己检查here
奇怪的是,如果我把第一个print语句中的数学表达式硬编码到integral(z)中,我会得到正确的答案:

def integral(z):
    d_ARF=(4.0*2.71828182845905**(2*z**2)*z + 3)/(2.71828182845905**(2*z**2) + 3*z)
    return d_ARF

result = quadrature(integral, 0, 3)
print(result)

>>> (18.000000063540558, 1.9408245677254854e-07)

我认为问题是我没有以正确的方式将dlogf_dz提供给integral(z),我必须以其他方式将dlogf_dz定义为函数。注意,我在integral(z)中定义了d_ARF=dlogf_dz(z[0]),否则函数会给出错误。

**我做错了什么?如何修复代码,更精确地使dlogf_dzsympy集成兼容?**Tnx

2hh7jdfx

2hh7jdfx1#

在一个符号为z的symmy环境中,我可以使用lambdify(而不是Lambda)从symmy表达式创建一个与numpy兼容的函数。

In [22]: expr = E**(2*z**2)+3*z

In [23]: fn = lambdify(z, log(expr).diff(z))

In [24]: quad(fn,0,3)

Out[24]:

lambdify生成了此函数:

In [25]: help(fn)
Help on function _lambdifygenerated:

_lambdifygenerated(z)
    Created with lambdify. Signature:

    func(z)

    Expression:

    (4*z*exp(2*z**2) + 3)/(3*z + exp(2*z**2))

lambdify不是傻瓜,但它是将sympyscipy结合使用的最佳工具。
https://docs.sympy.org/latest/modules/utilities/lambdify.html

kgsdhlau

kgsdhlau2#

尽可能不要将sympy和其他数值包混在一起。例如,你写道:

def func(z):
    return math.e**(2*z**2)+3*z

让我们将math.e替换为sympy.E

def func(z):
    return sympy.E**(2*z**2)+3*z

现在,在与Scipy集成之前,由于这似乎是一个相对容易的表达式,我将尝试使用SymPy:

dlogf_dz = sympy.log(f).diff(z)
dlogf_dz.integrate((z, 0, 3)).n()

# out: 18.0000001370698

相关问题