scipy 如何从sympy.diff中提取函数?

bpzcxfmw  于 2023-06-29  发布在  其他
关注(0)|答案(1)|浏览(142)

我试着用sympy.diff计算一个导函数,然后求解这个函数中的'x',让导函数func(x) = 0的根。然而,求解这个导数函数非常慢,因为它返回五个解,但我只需要最接近固定值x0的解。

import sympy

def diff_dist_func(a, b, c):
    x = sympy.Symbol('x')
    x0 = sympy.Symbol('x0')
    y0 = sympy.Symbol('y0')

    dist = ((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2) ** (1 / 2)
    return sympy.diff(dist, x)

a = -0.00020129919480721813
b=0.10107634020780536
c=-12.305150031126267
shortest_dist = diff_dist_func(a, b, c)
    
x0=252.3007982720215
y0=96.55526056735049
solve_shortest_dist = shortest_dist.evalf(subs={'x0': x0, 'y0': y0})  # build the derivative function
solve_x = sympy.solve(solve_shortest_dist, sympy.Symbol('x'), simplify=False, rational=False)  # Here solve the derivative function is very slow.

为了加快求解速度,我尝试使用scipy.optimize.fsolve,它能够给予func(x) = 0的根的初始估计值x0。因此,我将sympy.solve替换为fsolve(solve_shortest_dist, np.array([x0])),但出现错误TypeError: 'Mul' object is not callable。如何从sympy.diff的输出中提取导数函数,使其能够通过scipy.optimize.fsolve求解?或者有什么方法可以加快解决过程?

swvgeqrz

swvgeqrz1#

我使用的是SymPy 1.12,solve非常快,现在不需要转向数值库:

def diff_dist_func(a, b, c):
    x = sympy.Symbol('x')
    x0 = sympy.Symbol('x0')
    y0 = sympy.Symbol('y0')

    dist = sympy.sqrt((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2)
    return sympy.diff(dist, x)

a = -0.00020129919480721813
b = 0.10107634020780536
c = -12.305150031126267
shortest_dist = diff_dist_func(a, b, c)

x0 = 252.3007982720215
y0 = 96.55526056735049
sols = sympy.solve(shortest_dist.subs({'x0': x0, 'y0': y0}), x)
sols
# out: [-5.99737132198727,
#  76.9370574126698,
#  252.303674956197,
#  256.81160524887 - 13.1689464528248*I,
#  256.81160524887 + 13.1689464528248*I]

注意复杂的解决方案。让我们计算错误wrt x0

sols = np.array(sols, dtype=complex)
error = np.abs(sols - x0)
error
# out: array([2.58298170e+02, 1.75363741e+02, 2.87668418e-03, 1.39200765e+01,
       1.39200765e+01])

最后,提取最接近x0的解:

idx = np.argmin(error)
sols[idx]
# out: 252.303674956197

相关问题