scipy curve_fit不使用自定义函数

00jrzges  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(78)

我将fit方法定义如下:

def fit(f1, f2, x_data, y_data):
    def fun(x, w1, w2):
        return w1*f1(x) + w2*f2(x)
    return curve_fit(fun, x_data, y_data)

那么这段代码就可以正常工作了:

x_data, y_data = np.array(range(0, 5)), np.array([0, 1, 2, 3, 2])

f1 = np.sin
f2 = np.cos

fit(f1, f2, x_data, y_data)

但是,以下代码生成:不可散列类型:'numpy.ndarray'

d1 = {0: 1, 1:3, 2:4, 3:5, 4:3}
d2 = {0: 2, 1:1, 2:6, 3:2, 4:1}

f1 = lambda x: d1.get(x)
f2 = lambda x: d1.get(x)

fit(f1, f2, x_data, y_data)

有什么路可以绕过去吗?或者,是否有其他选项可以对给定数据进行自定义函数的线性组合拟合?

iyfjxgzm

iyfjxgzm1#

你需要对你的函数进行向量化,以允许它们在numpy数组上工作:

from scipy.optimize import curve_fit

def fit(f1, f2, x_data, y_data):
    def fun(x, w1, w2):
        return w1*f1(x) + w2*f2(x)
    return curve_fit(fun, x_data, y_data)

x_data, y_data = np.array(range(0, 5)), np.array([0, 1, 2, 3, 2])

d1 = {0: 1, 1:3, 2:4, 3:5, 4:3}
d2 = {0: 2, 1:1, 2:6, 3:2, 4:1}

f1 = np.vectorize(d1.get)
f2 = np.vectorize(d2.get)

fit(f1, f2, x_data, y_data)

相关问题