Scipy Optimize即使使用强启动参数也无法优化

hwazgwia  于 2023-05-17  发布在  其他
关注(0)|答案(1)|浏览(140)

我有一个相当简单的函数,我想优化参数,但我不能让scipy.optimize.minimize成功。
以下是数据和问题的简化版本:

ref = np.array([0.586, 0.659, 0.73 , 0.799, 0.865, 0.929, 0.991, 1.05 , 1.107, 1.162])

input = np.array([70.0, 77.0, 82.0, 87.0, 93.0, 98.0, 98.0, 102.0, 106.0, 109.0])

x = np.array([6.96,  9.24, 10.92, 12.24, 13.92, 15.24, 15.24, 16.32, 17.64, 18.96])

## Function 
def fun(beta, x):
    return ((input**beta[0])*beta[2])*(x**beta[1])

## Starting parameters 
initial_guess = [0.15, 0.9475, 0.0427]

## Objective to be minimized 
def objective(beta, model, x, ref):
    return sum(((np.log(model(beta, x))-np.log(ref)))**2)

minimize(objective, initial_guess, args = (fun, x, ref))

我知道这些起始参数几乎是正确的,因为print(fun(initial_guess, x))返回的估计值接近参考数据(在我的实际情况中,它们比这个最小可重复示例中更接近)。
我已经尝试了许多组合的起始参数,并没有找到任何导致成功的优化。
我试着让这个函数更基本(例如,删除额外的beta项和x,只留下beta[0])。这成功地优化了(success: True),但是预测是不充分的(可能是因为函数不够复杂,无法将输入转换为相对于参考的期望输出)。
我最近已经最小化了明显比这个更复杂的函数(在这个例子中使用了与以前相同的方法),所以我很困惑为什么这个函数不起作用。

uurv41yg

uurv41yg1#

minimize不是正确的函数调用。使用curve_fit,即使没有日志步骤,它也能正常工作。此外,总是给予minimize(或curve_fit)一个合理的界限;如果你“使用了同样的方法”,并且它在过去毫无限制地工作,那只是巧合。
在某种意义上,这实际上是三维上的表面拟合,并且这样解释它没有足够的输入数据。对于这样的方案,我希望在xinput中的一个中有多个非单调跳跃。这个 * 应该 * 看起来像什么(在ix中有不同的值):

import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit

def fun(ix: np.ndarray, b0: float, b1: float, b2: float) -> np.ndarray:
    input_, x = ix
    return input_**b0 * b2 * x**b1

ref = np.array([0.586, 0.659, 0.73, 0.799, 0.865, 0.929, 0.991, 1.05, 1.107, 1.162])
ix = np.array((
    [70.0, 77.0, 82.0, 87.0, 93.0, 98.0, 98.0, 102.0, 106.0, 109.0],
    [6.96, 9.24, 10.92, 12.24, 13.92, 15.24, 15.24, 16.32, 17.64, 18.96],
))
initial_guess = (2, -0.5, 4e-4)
fit_param, _ = curve_fit(
    f=fun, xdata=ix, ydata=ref, p0=initial_guess,
    bounds=((-1,-1,0), (10, 10, 10)),
)
print(fit_param)

fig, ax = plt.subplots()
ax.plot(ix[0], ref, label='experiment')
ax.plot(ix[0], fun(ix, *initial_guess), label='guess')
ax.plot(ix[0], fun(ix, *fit_param), label='fit')
ax.legend()
plt.show()

相关问题