opencv: quadratic model estimation with RANSAC

d8tt03nd posted on 2023-10-24 in Other

I am trying to fit a model with RANSAC, following this example: https://scikit-image.org/docs/dev/auto_examples/transform/plot_ransac.html#sphx-glr-auto-examples-transform-plot-ransac-py
According to https://scikit-image.org/docs/0.13.x/api/skimage.measure.html#skimage.measure.LineModelND, if I choose LineModel as model_class, it fits my data with the standard line model y = ax + b. Instead, I want to fit my data with a quadratic function y = ax^2 + bx + c. Is there a way to do this with the scikit-image or OpenCV libraries?
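On the scikit-image side of the question: `skimage.measure.ransac` does not require one of the built-in models. Its docs describe `model_class` as any class with an `estimate(data)` method (returning whether the fit succeeded) and a `residuals(data)` method. A minimal sketch of such a class for y = ax^2 + bx + c follows, assuming that `estimate`/`residuals` protocol (as documented in the 0.1x/0.2x releases; newer releases may differ, so check your installed version). The class name `QuadraticModel` and its `predict_y` helper are hypothetical. Because the quadratic is linear in (a, b, c), each fit is an ordinary least-squares solve on the design matrix [x^2, x, 1].

```python
import numpy as np


class QuadraticModel:
    """Hypothetical y = a*x**2 + b*x + c model for skimage.measure.ransac.

    Assumes ransac's documented model_class protocol: estimate(data) -> bool
    and residuals(data) -> per-point residuals, with data an (N, 2) array.
    """

    def __init__(self):
        # Coefficients (a, b, c); set by estimate()
        self.params = None

    def estimate(self, data):
        """Least-squares fit to (x, y) rows; needs at least 3 distinct x values."""
        data = np.asarray(data, dtype=float)
        if data.ndim != 2 or data.shape[1] != 2 or data.shape[0] < 3:
            return False
        x, y = data[:, 0], data[:, 1]
        # Design matrix [x^2, x, 1]: the model is linear in (a, b, c)
        A = np.column_stack([x**2, x, np.ones_like(x)])
        coeffs, _, rank, _ = np.linalg.lstsq(A, y, rcond=None)
        if rank < 3:
            # Degenerate sample (e.g. repeated x values): report failure
            return False
        self.params = coeffs
        return True

    def predict_y(self, x):
        """Evaluate the fitted parabola at x."""
        a, b, c = self.params
        return a * np.asarray(x, dtype=float) ** 2 + b * x + c

    def residuals(self, data):
        """Absolute vertical distance of each point from the fitted curve."""
        data = np.asarray(data, dtype=float)
        return np.abs(data[:, 1] - self.predict_y(data[:, 0]))
```

Usage would then be along the lines of `model, inliers = ransac(np.column_stack([x, y]), QuadraticModel, min_samples=3, residual_threshold=0.15)` with `ransac` imported from `skimage.measure`; `min_samples=3` is the smallest sample that determines a parabola, and the threshold value here is only illustrative.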

Answer #1 by 2ul0zpep

Here is a simple example that uses scikit-learn's RANSACRegressor on quadratic polynomial features:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, RANSACRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import r2_score

# Get example sea-ice data
df = pd.read_csv(
    "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/seaice.csv"
)
df["index1"] = df.index

# Define X and y
start=2950
X = (df[["index1"]].values)[start:start+100]
y = (df[["Extent"]].values)[start:start+100]

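# Define RANSAC regressor: each trial fits LinearRegression on min_samples
# randomly drawn points; a point is an inlier if |y - y_pred| <= residual_threshold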
ransac = RANSACRegressor(
    LinearRegression(),
    max_trials=100,
    min_samples=50,
    residual_threshold=0.15,
    random_state=0,
)

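# Fit RANSAC model to data. PolynomialFeatures(degree=2) maps x to [1, x, x^2],
# so a linear model on these features is the quadratic y = c + b*x + a*x^2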
quadratic = PolynomialFeatures(degree=2)
X_quad = quadratic.fit_transform(X)
ransac = ransac.fit(X_quad, y)

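# Get fitted RANSAC curve. Use transform (not fit_transform) to reuse the
# fitted PolynomialFeatures; np.arange excludes its stop value, hence X.max() + 1
X_fit = np.arange(X.min(), X.max() + 1, 1)[:, np.newaxis]
y_quad_fit = ransac.predict(quadratic.transform(X_fit))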

# Get R2 value (computed over all points, outliers included)
quadratic_r2 = r2_score(y, ransac.predict(X_quad))

# Plot inliers
inlier_mask = ransac.inlier_mask_
plt.scatter(X[inlier_mask], y[inlier_mask], c="blue", marker="o", label="Inliers")

# Plot outliers
outlier_mask = np.logical_not(inlier_mask)
plt.scatter(
    X[outlier_mask], y[outlier_mask], c="lightgreen", marker="s", label="Outliers"
)

# Plot fitted RANSAC curve
plt.plot(
    X_fit,
    y_quad_fit,
    label="quadratic (d=2), $R^2=%.2f$" % quadratic_r2,
    color="red",
    lw=2,
    linestyle="-",
)

plt.xlabel("X")
plt.ylabel("Sea ice extent")
plt.legend(loc="upper left")
plt.tight_layout()
plt.show()
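To make it clearer what RANSACRegressor does with the polynomial features above, here is a minimal from-scratch RANSAC loop for the same quadratic model using only NumPy. It is a simplified sketch of the general RANSAC idea, not scikit-learn's actual implementation (which adds scoring, stopping criteria and validity callbacks); the function name `ransac_quadratic` and its parameters are hypothetical. Each trial fits a parabola exactly through 3 random points with `np.polyfit`, counts the points within `residual_threshold` of it, keeps the largest consensus set, and finally refits on those inliers.

```python
import numpy as np


def ransac_quadratic(x, y, residual_threshold, max_trials=100, min_samples=3, seed=0):
    """Hypothetical minimal RANSAC for y = a*x**2 + b*x + c.

    Returns (coeffs, inlier_mask): coeffs = [a, b, c] refit on the best
    consensus set, or (None, all-False mask) if no trial produced one.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    rng = np.random.default_rng(seed)
    best_mask = np.zeros(x.size, dtype=bool)
    for _ in range(max_trials):
        # 1. Random minimal subset (3 points with distinct x fix a parabola)
        idx = rng.choice(x.size, size=min_samples, replace=False)
        if np.unique(x[idx]).size < 3:
            continue  # degenerate sample: skip this trial
        coeffs = np.polyfit(x[idx], y[idx], deg=2)
        # 2. Consensus set: points whose vertical residual is within the threshold
        mask = np.abs(y - np.polyval(coeffs, x)) <= residual_threshold
        # 3. Keep the largest consensus set seen so far
        if mask.sum() > best_mask.sum():
            best_mask = mask
    if best_mask.sum() < 3:
        return None, best_mask
    # 4. Final least-squares refit on all inliers of the best model
    return np.polyfit(x[best_mask], y[best_mask], deg=2), best_mask
```

For the sea-ice data in the answer, the call would be `ransac_quadratic(X, y, residual_threshold=0.15)`; the returned coefficients are in `np.polyfit` order (highest degree first), and the mask plays the same role as `ransac.inlier_mask_`.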
