scipy曲线拟合

fdx2calv  于 2022-12-18  发布在  其他
关注(0)|答案(1)|浏览(173)

我试图用SciPy中的curve_fit来拟合一条曲线,但是它没有按预期工作,我不知道为什么。

xdata = np.asarray(std_ex_90degree[5050:5150,0])
ydata = np.asarray(std_ex_90degree[5050:5150,1])
print(xdata,ydata)
  

def Gauss(x, A, B):
    y = A*np.exp(-1*B*x**2)
    return y

popt, covariance = curve_fit(Gauss, xdata, ydata)

fit_A, fit_B = popt
  
fit_y = Gauss(xdata, fit_A, fit_B)

plt.scatter(xdata, ydata, label='data',s=5)
plt.plot(xdata, fit_y, '-', label='fit')
plt.legend()

正如你所看到的高斯拟合没有工作,我只得到了一条直线。
数据如下:

[2834.486 2834.968 2835.45  2835.932 2836.414 2836.896 2837.378 2837.861
 2838.343 2838.825 2839.307 2839.789 2840.271 2840.753 2841.235 2841.718
 2842.2   2842.682 2843.164 2843.646 2844.128 2844.61  2845.093 2845.575
 2846.057 2846.539 2847.021 2847.503 2847.985 2848.468 2848.95  2849.432
 2849.914 2850.396 2850.878 2851.36  2851.843 2852.325 2852.807 2853.289
 2853.771 2854.253 2854.735 2855.218 2855.699 2856.182 2856.664 2857.146
 2857.628 2858.11  2858.592 2859.074 2859.557 2860.039 2860.521 2861.003
 2861.485 2861.967 2862.449 2862.932 2863.414 2863.896 2864.378 2864.86
 2865.342 2865.824 2866.307 2866.789 2867.271 2867.753 2868.235 2868.717
 2869.199 2869.682 2870.164 2870.646 2871.128 2871.61  2872.092 2872.574
 2873.056 2873.539 2874.021 2874.503 2874.985 2875.467 2875.949 2876.431
 2876.914 2877.396 2877.878 2878.36  2878.842 2879.324 2879.806 2880.289
 2880.771 2881.253 2881.735 2882.217] 
[0.5027119 0.5155925 0.5296563 0.5450429 0.5619112 0.5804411 0.6008373
 0.6233361 0.6482099 0.67577   0.7063611 0.7403504 0.7781109 0.8200049
 0.8663718 0.9175249 0.9737514 1.035319  1.102472  1.175419  1.254304
 1.339163  1.429889  1.526202  1.627649  1.733603  1.84322   1.955248
 2.067605  2.176702  2.276757  2.359875  2.417753  2.445059  2.441798
 2.41245   2.362954  2.298523  2.223243  2.14052   2.05336   1.964326
 1.87539   1.787885  1.702644  1.620191  1.540921  1.465193  1.393333
 1.325607  1.262171  1.203057  1.148185  1.097403  1.050529  1.007382
 0.9678    0.9316369 0.8987471 0.8689752 0.8421496 0.8180863 0.7965991
 0.7775094 0.76065   0.7458642 0.732995  0.7218768 0.7123291 0.7041584
 0.6971676 0.6911709 0.6860058 0.6815417 0.6776828 0.674363  0.6715436
 0.6692089 0.6673671 0.6660498 0.6653103 0.6652156 0.6658351 0.6672268
 0.6694273 0.6724483 0.676279  0.6808962 0.686272  0.6923797 0.699192
 0.7066767 0.7147906 0.7234787 0.7326793 0.7423348 0.7524015 0.7628553
 0.7736901 0.7849081]
vyswwuz2

vyswwuz21#

模型不适当。对模型进行细微调整会产生粗略拟合:

import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit

xdata = np.array((
    2834.486, 2834.968, 2835.45 , 2835.932, 2836.414, 2836.896, 2837.378, 2837.861,
    2838.343, 2838.825, 2839.307, 2839.789, 2840.271, 2840.753, 2841.235, 2841.718,
    2842.2  , 2842.682, 2843.164, 2843.646, 2844.128, 2844.61 , 2845.093, 2845.575,
    2846.057, 2846.539, 2847.021, 2847.503, 2847.985, 2848.468, 2848.95 , 2849.432,
    2849.914, 2850.396, 2850.878, 2851.36 , 2851.843, 2852.325, 2852.807, 2853.289,
    2853.771, 2854.253, 2854.735, 2855.218, 2855.699, 2856.182, 2856.664, 2857.146,
    2857.628, 2858.11 , 2858.592, 2859.074, 2859.557, 2860.039, 2860.521, 2861.003,
    2861.485, 2861.967, 2862.449, 2862.932, 2863.414, 2863.896, 2864.378, 2864.86 ,
    2865.342, 2865.824, 2866.307, 2866.789, 2867.271, 2867.753, 2868.235, 2868.717,
    2869.199, 2869.682, 2870.164, 2870.646, 2871.128, 2871.61 , 2872.092, 2872.574,
    2873.056, 2873.539, 2874.021, 2874.503, 2874.985, 2875.467, 2875.949, 2876.431,
    2876.914, 2877.396, 2877.878, 2878.36 , 2878.842, 2879.324, 2879.806, 2880.289,
    2880.771, 2881.253, 2881.735, 2882.217,
))
ydata = np.array((
    0.5027119, 0.5155925, 0.5296563, 0.5450429, 0.5619112, 0.5804411, 0.6008373,
    0.6233361, 0.6482099, 0.67577  , 0.7063611, 0.7403504, 0.7781109, 0.8200049,
    0.8663718, 0.9175249, 0.9737514, 1.035319 , 1.102472 , 1.175419 , 1.254304 ,
    1.339163 , 1.429889 , 1.526202 , 1.627649 , 1.733603 , 1.84322  , 1.955248 ,
    2.067605 , 2.176702 , 2.276757 , 2.359875 , 2.417753 , 2.445059 , 2.441798 ,
    2.41245  , 2.362954 , 2.298523 , 2.223243 , 2.14052  , 2.05336  , 1.964326 ,
    1.87539  , 1.787885 , 1.702644 , 1.620191 , 1.540921 , 1.465193 , 1.393333 ,
    1.325607 , 1.262171 , 1.203057 , 1.148185 , 1.097403 , 1.050529 , 1.007382 ,
    0.9678   , 0.9316369, 0.8987471, 0.8689752, 0.8421496, 0.8180863, 0.7965991,
    0.7775094, 0.76065  , 0.7458642, 0.732995 , 0.7218768, 0.7123291, 0.7041584,
    0.6971676, 0.6911709, 0.6860058, 0.6815417, 0.6776828, 0.674363 , 0.6715436,
    0.6692089, 0.6673671, 0.6660498, 0.6653103, 0.6652156, 0.6658351, 0.6672268,
    0.6694273, 0.6724483, 0.676279 , 0.6808962, 0.686272 , 0.6923797, 0.699192 ,
    0.7066767, 0.7147906, 0.7234787, 0.7326793, 0.7423348, 0.7524015, 0.7628553,
    0.7736901, 0.7849081,
))

def gauss(x: np.ndarray, *args: float) -> np.ndarray:
    a, b, c, d = args
    return a*np.exp(-b*(x - c)**2) + d

popt, _ = curve_fit(
    gauss, xdata, ydata,
    p0=(1.7, 0.02, 2851, 0.7),
    maxfev=100_000,
)
print(popt)
fit_y = gauss(xdata, *popt)

plt.scatter(xdata, ydata, label='data', s=5)
plt.plot(xdata, fit_y, '-', label='fit')
plt.legend()
plt.show()
[1.68927347e+00 2.10977276e-02 2.85117456e+03 6.81806648e-01]

要做得更好,您的模型需要进行更多更改。

相关问题