如何使用Scipy Optimization提高模型拟合的准确性

gorkyyrv  于 2023-05-17  发布在  其他
关注(0)|答案(1)|浏览(143)

我想对以下数据进行曲线拟合:munich_temperatures_average.txt
我尝试的代码如下：

import numpy as np
import matplotlib.pyplot as plt
from scipy import optimize

def func(temp, a, b, c):
    """Evaluate a sinusoid whose period is 1 in the units of ``temp``.

    Despite its name, ``temp`` is the independent (time) variable: the
    script passes the ``date`` column here as curve_fit's xdata.
    ``a`` is the amplitude, ``b`` the phase offset in radians and ``c`` the
    vertical offset.  Returns ``a * cos(2*pi*temp + b) + c`` (element-wise
    for array input).
    """
    phase = 2 * np.pi * temp + b
    return c + a * np.cos(phase)

date, temperature = np.loadtxt('munich_temperatures_average.txt', unpack=True)

# Ordinary least squares (curve_fit's default) lets a handful of extreme
# points pull the whole curve away from the bulk of the data.  A robust loss
# ('soft_l1', available with the 'trf' method; curve_fit forwards extra
# keyword arguments to scipy.optimize.least_squares) down-weights large
# residuals, so isolated outliers have little influence on the fit.
# NOTE(review): if the file flags missing days with a sentinel value,
# masking those rows explicitly before fitting is better still — check the data.
#
# Start the optimizer from data-driven values instead of curve_fit's default
# a = b = c = 1: half of the 5th-95th percentile spread approximates the
# amplitude, and the median approximates the vertical offset.
lo, hi = np.percentile(temperature, [5, 95])
p0 = [(hi - lo) / 2, 0.0, np.median(temperature)]
result = optimize.curve_fit(func, date, temperature, p0=p0,
                            method='trf', loss='soft_l1')
popt, pcov = result

plt.plot(date, temperature, '.')
plt.plot(date, func(date, *popt), c='red', zorder=10)
plt.ylim([-20, 30])
plt.xlabel("Year", fontsize=18)
plt.ylabel("Temperature", fontsize=18)
plt.show()

但是从输出图中可以看出，拟合模型的振荡幅度似乎比实际数据小。请问如何才能使拟合更准确？先谢谢了。

mo49yndu

mo49yndu1#

感谢 @Reinderien 的解释。我使用一维中值滤波器（1D median filter）过滤掉离群值后，模型拟合看起来就准确了：

import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy import optimize
from scipy.signal import medfilt

def func(temp, a, b, c):
    """Evaluate a sinusoid whose period is 1 in the units of ``temp``.

    Despite its name, ``temp`` is the independent (time) variable: the
    script passes the ``date`` column here as curve_fit's xdata.
    ``a`` is the amplitude, ``b`` the phase offset in radians and ``c`` the
    vertical offset.  Returns ``a * cos(2*pi*temp + b) + c`` (element-wise
    for array input).
    """
    phase = 2 * np.pi * temp + b
    return c + a * np.cos(phase)

date, temperature = np.loadtxt('./data/munich_temperatures_average.txt', unpack=True)

# Baseline: fit the model to the raw (unfiltered) series.
popt, pcov = optimize.curve_fit(func, date, temperature)

# Median filter: suppress isolated outliers before fitting.  A 21-sample
# window is applied 11 times in total.  scipy.signal.medfilt zero-pads the
# array ends, which drags the first/last samples towards 0 on every pass;
# ndimage.median_filter with mode='nearest' pads by repeating the edge
# values instead, and is identical to medfilt away from the edges.
KERNEL_SIZE = 21
N_PASSES = 11
filtered = temperature
for _ in range(N_PASSES):
    filtered = ndimage.median_filter(filtered, size=KERNEL_SIZE, mode='nearest')

fig = plt.figure(figsize=(14, 6), dpi=80)

# Panel 1: histogram of the raw values, with dashed reference lines at
# 0 and 20 on the temperature axis spanning frequencies 0..400.
ax1 = fig.add_subplot(131)
ax1.hist(temperature, color="lightblue", ec="green", bins=150, orientation="horizontal")
ax1.hlines([0, 20], 0, 400, colors='r', linestyles='dashed', linewidth=1)
ax1.set_ylim([-20, 30])
ax1.set_ylabel("Temperature", fontsize=14)
ax1.set_xlabel("Frequency", fontsize=14)

# Panel 2: raw data with the model fitted to the raw data.
ax2 = fig.add_subplot(132)
ax2.plot(date, temperature, '.', zorder=0, label='data', alpha=0.1)
ax2.plot(date, func(date, *popt), 'm', zorder=10, label='model')
ax2.set_ylim([-20, 30])
ax2.set_xlabel("Year", fontsize=14)
ax2.legend(loc='best')

# Panel 3: filtered data with the model fitted to the filtered data.
popt2, pcov2 = optimize.curve_fit(func, date, filtered)

ax3 = fig.add_subplot(133)
ax3.plot(date, filtered, '.', zorder=0, label='filtered data')
ax3.plot(date, func(date, *popt2), 'm', zorder=10, label='model')
ax3.set_ylim([-20, 30])
ax3.set_xlabel("Year", fontsize=14)
ax3.legend(loc='best')

plt.show()

相关问题