scipy 如何在python中求解一个具有上百对已知因变量和自变量的三次函数

bfnvny8b  于 2022-11-10  发布在  Python
关注(0)|答案(2)|浏览(118)

我试图找出如何解决一个三次函数与自变量(x)和因变量f(x)已知,但系数a,b,c和常数d未知。我尝试了sympy,但意识到它只适用于4对。现在我试图探索一些可能性来解决这个问题(通过找到系数a/b/c和常数d的实际值)。任何建议都非常感谢。
下面的代码显然不起作用,返回了一个空列表,因为有上百个对。

from sympy import Eq, solve
from sympy.abc import a,b,c,d, x

formula = a*x**3 + b*x**2 + c*x + d  # general cubic formula

xs = [28.0, 29.0, 12.0, 12.0, 42.0, 35.0, 28.0, 30.0, 32.0, 46.0, 18.0, 28.0, 28.0, 64.0, 
38.0, 18.0, 49.0, 37.0, 25.0, 24.0, 42.0, 50.0, 12.0, 64.0, 23.0, 35.0, 22.0, 16.0, 44.0, 
77.0, 26.0, 44.0, 38.0, 37.0, 45.0, 42.0, 24.0, 42.0, 12.0, 46.0, 12.0, 26.0, 37.0, 15.0, 
67.0, 36.0, 43.0, 36.0, 45.0, 82.0,
44.0, 30.0, 33.0, 51.0, 50.0]

fxs = [59.5833333333333, 59.5833333333333, 10.0, 10.0, 47.0833333333333, 51.2499999999999, 
34.5833333333333, 88.75, 63.7499999999999, 34.5833333333333, 51.2499999999999, 10.0, 
63.7499999999999, 51.0, 59.5833333333333,47.0833333333333, 49.5625, 43.5624999999999, 
63.7499999999999, 10.0, 76.25, 47.0833333333333,10.0, 51.2499999999999,47.0833333333333,10.0, 
35.0, 51.2499999999999, 76.25, 100.0, 51.2499999999999, 59.5833333333333, 63.7499999999999, 
76.25, 100.0, 51.2499999999999, 10.0, 22.5, 10.0, 88.75, 10.0, 59.5833333333333, 
47.0833333333333, 34.5833333333333, 51.2499999999999, 63.7499999999999,63.7499999999999, 10.0, 
76.25, 62.1249999999999, 47.0833333333333, 10.0, 76.25, 47.0833333333333, 88.75]

sol = solve([Eq(formula.subs(x, xi), fx) for xi, fx in zip(xs, fxs)])
print(sol)  
[]
flvtvl50

flvtvl501#

如何使用SymPy解决此问题(假设您需要最小平方误差解决方案):

In [2]: errors_squared = [(fx - formula.subs(x, xi))**2 for fx, xi in zip(xs, fxs)]

In [3]: error = Add(*errors_squared)

In [4]: sympy.linsolve([error.diff(v) for v in [a, b, c, d]], [a, b, c, d])
Out[4]: {(0.00019197277106452, -0.0310483217324413, 1.68127292155383, 7.51784205803798)}
a0x5cqrl

a0x5cqrl2#

对于曲线拟合,我建议使用scipy.optimize.curve_fit

import numpy as np
import plotly.graph_objects as go  # only used to show output; not needed in answer
from scipy.optimize import curve_fit

def cubic(x, a, b, c, d):
    return a * x**3 + b * x**2 + c * x + d

(a, b, c, d), _ = curve_fit(cubic, xs, fxs)  # `xs` and `fxs` copied from the OP

x = np.linspace(min(xs), max(xs), 1000)

fig = go.Figure(go.Scatter(x=xs, y=fxs, mode='markers', name='data'))
fig.add_trace(go.Scatter(x=x, y=cubic(x, a, b, c, d), name='fit'))
fig.show()

输出量:

相关问题