matplotlib 使用Scikit-Learn绘制多项式回归图

e37o9pze  于 2023-10-24  发布在  其他
关注(0)|答案(1)|浏览(106)

我正在编写一个python代码,用于研究在[0,1]范围内使用函数sin(2.pi.x)的过拟合。我首先使用mu=0和sigma=1的高斯分布添加一些随机噪声来生成N个数据点。我使用M次多项式拟合模型。下面是我的代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

# generate N random points
N=30
X= np.random.rand(N,1)
y= np.sin(np.pi*2*X)+ np.random.randn(N,1)

M=2
poly_features=PolynomialFeatures(degree=M, include_bias=False)
X_poly=poly_features.fit_transform(X) # contain original X and its new features
model=LinearRegression()
model.fit(X_poly,y) # Fit the model

# Plot
X_plot=np.linspace(0,1,100).reshape(-1,1)
X_plot_poly=poly_features.fit_transform(X_plot)
plt.plot(X,y,"b.")
plt.plot(X_plot_poly,model.predict(X_plot_poly),'-r')
plt.show()

我不知道为什么我有M=2行的m次多项式线?我认为它应该是1行,不管M。你能帮我弄清楚这个问题吗?

llycmphe

llycmphe1#

多项式特征变换后的数据为(n_samples,2)形状。因此pyplot是用两列绘制预测变量。
将打印代码更改为

plt.plot(X_plot_poly[:,i],model.predict(X_plot_poly),'-r')
where i your column number

相关问题