scipy 如何在python中求解一个具有上百对已知因变量和自变量的三次函数

bfnvny8b  于 2022-11-10  发布在  Python
关注(0)|答案(2)|浏览(130)

我试图找出如何解决一个三次函数与自变量(x)和因变量f(x)已知,但系数a,b,c和常数d未知。我尝试了sympy,但意识到它只适用于4对。现在我试图探索一些可能性来解决这个问题(通过找到系数a/b/c和常数d的实际值)。任何建议都非常感谢。
下面的代码显然不起作用,返回了一个空列表,因为有上百个对。

  1. from sympy import Eq, solve
  2. from sympy.abc import a,b,c,d, x
  3. formula = a*x**3 + b*x**2 + c*x + d # general cubic formula
  4. xs = [28.0, 29.0, 12.0, 12.0, 42.0, 35.0, 28.0, 30.0, 32.0, 46.0, 18.0, 28.0, 28.0, 64.0,
  5. 38.0, 18.0, 49.0, 37.0, 25.0, 24.0, 42.0, 50.0, 12.0, 64.0, 23.0, 35.0, 22.0, 16.0, 44.0,
  6. 77.0, 26.0, 44.0, 38.0, 37.0, 45.0, 42.0, 24.0, 42.0, 12.0, 46.0, 12.0, 26.0, 37.0, 15.0,
  7. 67.0, 36.0, 43.0, 36.0, 45.0, 82.0,
  8. 44.0, 30.0, 33.0, 51.0, 50.0]
  9. fxs = [59.5833333333333, 59.5833333333333, 10.0, 10.0, 47.0833333333333, 51.2499999999999,
  10. 34.5833333333333, 88.75, 63.7499999999999, 34.5833333333333, 51.2499999999999, 10.0,
  11. 63.7499999999999, 51.0, 59.5833333333333,47.0833333333333, 49.5625, 43.5624999999999,
  12. 63.7499999999999, 10.0, 76.25, 47.0833333333333,10.0, 51.2499999999999,47.0833333333333,10.0,
  13. 35.0, 51.2499999999999, 76.25, 100.0, 51.2499999999999, 59.5833333333333, 63.7499999999999,
  14. 76.25, 100.0, 51.2499999999999, 10.0, 22.5, 10.0, 88.75, 10.0, 59.5833333333333,
  15. 47.0833333333333, 34.5833333333333, 51.2499999999999, 63.7499999999999,63.7499999999999, 10.0,
  16. 76.25, 62.1249999999999, 47.0833333333333, 10.0, 76.25, 47.0833333333333, 88.75]
  17. sol = solve([Eq(formula.subs(x, xi), fx) for xi, fx in zip(xs, fxs)])
  18. print(sol)
  19. []
flvtvl50

flvtvl501#

如何使用SymPy解决此问题(假设您需要最小平方误差解决方案):

  1. In [2]: errors_squared = [(fx - formula.subs(x, xi))**2 for fx, xi in zip(xs, fxs)]
  2. In [3]: error = Add(*errors_squared)
  3. In [4]: sympy.linsolve([error.diff(v) for v in [a, b, c, d]], [a, b, c, d])
  4. Out[4]: {(0.00019197277106452, -0.0310483217324413, 1.68127292155383, 7.51784205803798)}
a0x5cqrl

a0x5cqrl2#

对于曲线拟合,我建议使用scipy.optimize.curve_fit

  1. import numpy as np
  2. import plotly.graph_objects as go # only used to show output; not needed in answer
  3. from scipy.optimize import curve_fit
  4. def cubic(x, a, b, c, d):
  5. return a * x**3 + b * x**2 + c * x + d
  6. (a, b, c, d), _ = curve_fit(cubic, xs, fxs) # `xs` and `fxs` copied from the OP
  7. x = np.linspace(min(xs), max(xs), 1000)
  8. fig = go.Figure(go.Scatter(x=xs, y=fxs, mode='markers', name='data'))
  9. fig.add_trace(go.Scatter(x=x, y=cubic(x, a, b, c, d), name='fit'))
  10. fig.show()

输出量:

展开查看全部

相关问题