Scipy曲线拟合错误(_F):函数调用的结果不是正确的浮点数数组

mcdcgff0  于 2022-11-10  发布在  其他
关注(0)|答案(1)|浏览(135)

如何求解curve_fit?ValueError:对象对于所需数组太深错误:函数调用的结果不是正确的浮点数数组。


### Import Libraries

import numpy as np
from scipy.optimize import curve_fit

### Define Function

def Func(vars, C1, C2):
        (X, Y) = vars
        Z1 = (C1*Y**2) / (1+(1-(C1*Y)**2)**0.5)
        Z2 = (C2*X**2) / (1+(1-(C2*X)**2)**0.5)
        return Z1 + Z2

### Y Data

xL = np.linspace(0.0, 10, 11).flatten()   ## Sub
yL = np.linspace(0.0, 100, 101).flatten() ## Main 
X, Y = np.meshgrid(abs(xL), abs(yL))

### Coefficient

C1 = 0.002
C2 = 0.005

### Calculate : Original and Noise Data

Z_original = Func((X, Y), C1, C2)

Z_noise = np.random.normal(size=(len(xL)*len(yL)), scale=0.5)
Z_noise.resize(len(yL), len(xL)) 
Z_noise = Z_original + Z_noise

### Curve_Fit """ ???????????????????????? """

p0 = (0.002, 0.005)
popt, pcov = curve_fit(Func, (X,Y), Z_noise, p0)
Z_curvefit = Func((X,Y), *popt)
llycmphe

llycmphe1#

scipy.optimize.curve_fit可用于拟合2D数据,但从属数据(模型函数的输出)必须仍为1D(如scipy doc中所述)。
一个解决方案是使用np.ravel()来展平func的返回值:

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

def func(data, c1, c2):
        (X, Y) = data
        Z1 = (c1*Y**2) / (1+(1-(c1*Y)**2)**0.5)
        Z2 = (c2*X**2) / (1+(1-(c2*X)**2)**0.5)
        return (Z1 + Z2).ravel()  # <-- add ravel() here

xL = np.linspace(0.0, 10, 11).flatten()   ## Sub
yL = np.linspace(0.0, 100, 101).flatten() ## Main 
X, Y = np.meshgrid(abs(xL), abs(yL))

c1 = 0.002
c2 = 0.005

# Original and noisy Data

Z_original = func((X, Y), c1, c2)
Z_noise = np.random.normal(size=(len(xL)*len(yL)), scale=0.5)  # <-- no resizing necessary now
Z = Z_original + Z_noise

# Curve fiting

p0 = (0.002, 0.005)
popt, pcov = curve_fit(func, (X,Y), Z, p0)
Z_curvefit = func((X,Y), *popt)

# plot (reshape Z and Z_curvefit first)

fig, ax = plt.subplots(1, 1)
ax.imshow(Z.reshape(101, 11), 
    cmap=plt.cm.jet, origin='lower', 
    extent=(X.min(), X.max(), Y.min(), Y.max()))
ax.contour(X, Y, Z_curvefit.reshape(101, 11), 8, colors='w')
plt.show()

相关问题