为什么我的scipy.optimize.curve_fit p0不将列表作为单个参数?

dvtswwa3  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(231)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.io as spio
from scipy.optimize import curve_fit

我试图使用高斯函数创建一条与curve_fit拟合的直线。

def gauss (x, peaks) : #peaks is parameters of gaussian fitting function
    result = peaks[0] + peaks[1]*x 
    for i in range(2 , len(peaks), 3) : #looping through each peak needed in curve fitting
        result += peak[i] * np.exp(-(x-peak[i+2])**2/(2*(peak[i+1])**2))
    return result

guess = [[1, 0.1, 1, 640, 1, 0.2, 651, 2]]
popt, pcov = curve_fit(gauss, x, y, p0 = guess)

其中

x= [635.        , 635.2016129 , 635.40322581, 635.60483871,
        635.80645161, 636.00806452, 636.20967742, 636.41129032,
        636.61290323, 636.81451613, 637.01612903, 637.21774194,
        637.41935484, 637.62096774, 637.82258065, 638.02419355,
        638.22580645, 638.42741935, 638.62903226, 638.83064516,
        639.03225806, 639.23387097, 639.43548387, 639.63709677,
        639.83870968, 640.04032258, 640.24193548, 640.44354839,
        640.64516129, 640.84677419, 641.0483871 , 641.25      ,
        641.4516129 , 641.65322581, 641.85483871, 642.05645161,
        642.25806452, 642.45967742, 642.66129032, 642.86290323,
        643.06451613, 643.26612903, 643.46774194, 643.66935484,
        643.87096774, 644.07258065, 644.27419355, 644.47580645,
        644.67741935, 644.87903226, 645.08064516, 645.28225806,
        645.48387097, 645.68548387, 645.88709677, 646.08870968,
        646.29032258, 646.49193548, 646.69354839, 646.89516129,
        647.09677419, 647.2983871 , 647.5       , 647.7016129 ,
        647.90322581, 648.10483871, 648.30645161, 648.50806452,
        648.70967742, 648.91129032, 649.11290323, 649.31451613,
        649.51612903, 649.71774194, 649.91935484, 650.12096774,
        650.32258065, 650.52419355, 650.72580645, 650.92741935,
        651.12903226, 651.33064516, 651.53225806, 651.73387097,
        651.93548387, 652.13709677, 652.33870968, 652.54032258,
        652.74193548, 652.94354839, 653.14516129, 653.34677419,
        653.5483871 , 653.75      , 653.9516129 , 654.15322581,
        654.35483871, 654.55645161, 654.75806452, 654.95967742,
        655.16129032, 655.36290323, 655.56451613, 655.76612903,
        655.96774194, 656.16935484, 656.37096774, 656.57258065,
        656.77419355, 656.97580645, 657.17741935, 657.37903226,
        657.58064516, 657.78225806, 657.98387097, 658.18548387,
        658.38709677, 658.58870968, 658.79032258, 658.99193548,
        659.19354839, 659.39516129, 659.59677419, 659.7983871 ,
        660.        ]

y = [1.00610685, 1.11422789, 1.05862689, 0.99178016, 0.93604815,
        0.94402921, 1.03761458, 1.15099585, 1.13679326, 1.14722848,
        1.14475381, 1.16846848, 1.19926167, 1.29001915, 1.25844276,
        1.36683512, 1.25987339, 1.40151715, 1.62535977, 1.60804808,
        1.83149123, 2.06974792, 2.02933621, 2.10553861, 1.89437556,
        1.81698382, 1.75859034, 1.63153493, 1.75228822, 1.62093472,
        1.62339783, 1.71172881, 1.52044606, 1.43718326, 1.39070237,
        1.31553793, 1.39568388, 1.22597241, 1.35875285, 1.34322095,
        1.21272945, 1.28729749, 1.32186389, 1.33674073, 1.23697054,
        1.28606296, 1.20751882, 1.29997706, 1.15509164, 1.19586015,
        1.18754458, 1.18877351, 1.15815341, 1.16502881, 1.13353086,
        1.13657737, 1.15174437, 1.22875524, 1.13051963, 1.22225213,
        1.22628403, 1.16598094, 1.10445559, 1.12308681, 1.27446711,
        1.22632468, 1.20066011, 1.21928763, 1.24986732, 1.27032673,
        1.31073976, 1.34605145, 1.37473345, 1.42251658, 1.49298894,
        1.45713902, 1.47692657, 1.4757812 , 1.43597376, 1.4392674 ,
        1.44210732, 1.46068954, 1.47488177, 1.49574065, 1.49968791,
        1.49236512, 1.46846592, 1.4510318 , 1.44610083, 1.36172342,
        1.36954188, 1.32330287, 1.32561278, 1.28688407, 1.29428899,
        1.27018023, 1.24460649, 1.22342896, 1.22246587, 1.25698733,
        1.22738147, 1.20932436, 1.19963503, 1.19483376, 1.17124403,
        1.18308687, 1.19256997, 1.17640173, 1.18847394, 1.19310498,
        1.19029582, 1.19089711, 1.21491408, 1.19287658, 1.20004702,
        1.21787214, 1.19860458, 1.23193061, 1.20111501, 1.19508743,
        1.21893013, 1.21033764, 1.17575479, 1.17496657, 1.21453702]

我很困惑为什么这样做不起作用。我试图将guess作为一个列表传递给gauss作为单个输入,但是curve_fit将该列表作为一个输入列表,而不是单个输入。
错误如下:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [6], in <cell line: 8>()
      5     return result
      7 guess = [[1, 0.1, 1, 640, 1, 0.2, 651, 2]]
----> 8 popt, pcov = curve_fit(gauss, x, y, p0 = guess)

File /opt/conda/lib/python3.10/site-packages/scipy/optimize/_minpack_py.py:834, in curve_fit(f, xdata, ydata, p0, sigma, absolute_sigma, check_finite, bounds, method, jac, full_output, **kwargs)
    831 if ydata.size != 1 and n > ydata.size:
    832     raise TypeError(f"The number of func parameters={n} must not"
    833                     f" exceed the number of data points={ydata.size}")
--> 834 res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
    835 popt, pcov, infodict, errmsg, ier = res
    836 ysize = len(infodict['fvec'])

File /opt/conda/lib/python3.10/site-packages/scipy/optimize/_minpack_py.py:410, in leastsq(func, x0, args, Dfun, full_output, col_deriv, ftol, xtol, gtol, maxfev, epsfcn, factor, diag)
    408 if not isinstance(args, tuple):
    409     args = (args,)
--> 410 shape, dtype = _check_func('leastsq', 'func', func, x0, args, n)
    411 m = shape[0]
    413 if n > m:

File /opt/conda/lib/python3.10/site-packages/scipy/optimize/_minpack_py.py:24, in _check_func(checker, argname, thefunc, x0, args, numinputs, output_shape)
     22 def _check_func(checker, argname, thefunc, x0, args, numinputs,
     23                 output_shape=None):
---> 24     res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
     25     if (output_shape is not None) and (shape(res) != output_shape):
     26         if (output_shape[0] != 1):

File /opt/conda/lib/python3.10/site-packages/scipy/optimize/_minpack_py.py:485, in _wrap_func.<locals>.func_wrapped(params)
    484 def func_wrapped(params):
--> 485     return func(xdata, *params) - ydata

TypeError: gauss() takes 2 positional arguments but 9 were given

正如您所看到的,给出了9个输入,但guess的长度只有8,这意味着它工作得比较正常,并将一个x作为输入,但列表并没有被当作列表。
我也尝试过使用np.array和其他一些小的修改,但是没有任何效果

kiayqfof

kiayqfof1#

要拟合的函数应该具有签名def f(xdata, *params),即函数f具有任意数目的params。在函数params内,只是一个包含所有参数的元组。
在代码中:

from scipy.optimize import curve_fit

def gauss(x, *peaks):
    result = peaks[0] + peaks[1]*x 
    for i in range(2 , len(peaks), 3):
        result += peaks[i] * np.exp(-(x-peaks[i+2])**2/(2*(peaks[i+1])**2))
    return result

guess = [1, 0.1, 1, 640, 1, 0.2, 651, 2]
popt, pcov = curve_fit(gauss, x, y, p0 = guess)

还请注意,循环中有一个排印错误,guess不应该是嵌套列表。

相关问题