OpenCV: quadratic model estimation with RANSAC

d8tt03nd, posted 2023-10-24 in Other

I am trying to fit a model with RANSAC, following this example: https://scikit-image.org/docs/dev/auto_examples/transform/plot_ransac.html#sphx-glr-auto-examples-transform-plot-ransac-py
According to https://scikit-image.org/docs/0.13.x/api/skimage.measure.html#skimage.measure.LineModelND, if I pass LineModel as model_class, it fits my data with the standard line model y = ax + b. Instead, I want to fit my data with a quadratic function y = ax^2 + bx + c. Is there a way to do this with scikit-image or OpenCV?
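For reference, the RANSAC loop itself is short enough to write directly for a quadratic, independent of any library's model class. The sketch below is a minimal illustration assuming plain NumPy only; `ransac_quadratic` is a hypothetical helper, not a scikit-image or OpenCV API. It samples 3 points (the minimum that determines a parabola), fits them exactly with `np.polyfit`, counts points whose vertical residual is under the threshold, keeps the candidate with the most inliers, and finally refits on those inliers with least squares.

```python
import numpy as np


def ransac_quadratic(x, y, residual_threshold, max_trials=100, rng=None):
    """Hypothetical helper: fit y = a*x**2 + b*x + c with a basic RANSAC loop.

    Returns (coeffs, inlier_mask); coeffs is [a, b, c] in np.polyfit order,
    inlier_mask marks the inliers of the best minimal-sample candidate.
    """
    rng = np.random.default_rng(rng)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    best_mask, best_count = None, 0
    for _ in range(max_trials):
        # A parabola is determined by 3 points: the minimal sample size.
        idx = rng.choice(len(x), size=3, replace=False)
        # Skip degenerate samples with repeated x values.
        if len(np.unique(x[idx])) < 3:
            continue
        coeffs = np.polyfit(x[idx], y[idx], deg=2)
        # Absolute vertical residual of every point to the candidate curve.
        residuals = np.abs(np.polyval(coeffs, x) - y)
        mask = residuals < residual_threshold
        count = int(mask.sum())
        if count > best_count:
            best_count, best_mask = count, mask
    if best_mask is None or best_count < 3:
        raise RuntimeError("RANSAC did not find a valid quadratic model")
    # Refit on all inliers of the best candidate with ordinary least squares.
    coeffs = np.polyfit(x[best_mask], y[best_mask], deg=2)
    return coeffs, best_mask


# Made-up data on y = 0.5x^2 - 2x + 1, with every 10th point pushed far off.
x = np.linspace(-5, 5, 50)
y = 0.5 * x**2 - 2 * x + 1
y[::10] += 20
coeffs, inliers = ransac_quadratic(x, y, residual_threshold=0.5, rng=0)
print(coeffs)  # should be close to [0.5, -2.0, 1.0]
```

Inside scikit-image the same idea would take the form of a custom model class passed to `skimage.measure.ransac`; the loop above just makes the mechanics explicit without depending on that interface.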

Answer by 2ul0zpep:

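Here is a simple example. It uses scikit-learn's RANSACRegressor together with PolynomialFeatures rather than scikit-image or OpenCV: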

```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, RANSACRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import r2_score

# Get example sea-ice data
df = pd.read_csv(
    "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/seaice.csv"
)
df["index1"] = df.index

# Define X and y (a 100-row window of the series)
start = 2950
X = df[["index1"]].values[start:start + 100]
y = df[["Extent"]].values[start:start + 100]

# Define RANSAC regressor wrapping an ordinary linear model
ransac = RANSACRegressor(
    LinearRegression(),
    max_trials=100,
    min_samples=50,
    residual_threshold=0.15,
    random_state=0,
)

# Expand x into [1, x, x^2]; a linear model on these features is a quadratic in x
quadratic = PolynomialFeatures(degree=2)
X_quad = quadratic.fit_transform(X)
ransac = ransac.fit(X_quad, y)

# Get fitted RANSAC curve over the full x range (reuse the fitted expansion)
X_fit = np.arange(X.min(), X.max() + 1, 1)[:, np.newaxis]
y_quad_fit = ransac.predict(quadratic.transform(X_fit))

# Get R2 value (computed over all points, outliers included)
quadratic_r2 = r2_score(y, ransac.predict(X_quad))

# Plot inliers
inlier_mask = ransac.inlier_mask_
plt.scatter(X[inlier_mask], y[inlier_mask], c="blue", marker="o", label="Inliers")

# Plot outliers
outlier_mask = np.logical_not(inlier_mask)
plt.scatter(
    X[outlier_mask], y[outlier_mask], c="lightgreen", marker="s", label="Outliers"
)

# Plot fitted RANSAC curve
plt.plot(
    X_fit,
    y_quad_fit,
    label="quadratic (d=2), $R^2=%.2f$" % quadratic_r2,
    color="red",
    lw=2,
    linestyle="-",
)
plt.xlabel("X")
plt.ylabel("Sea ice extent")
plt.legend(loc="upper left")
plt.tight_layout()
plt.show()
```
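Why the answer's approach yields a quadratic: PolynomialFeatures(degree=2) maps each x to the columns [1, x, x^2], and y = ax^2 + bx + c is linear in the unknowns a, b, c even though it is quadratic in x, so a linear regression on the expanded columns is a quadratic fit. RANSAC then only decides which rows that regression sees. The sketch below reproduces the expansion with plain NumPy on made-up noiseless data (no RANSAC, no scikit-learn), just to show that least squares on [1, x, x^2] recovers the coefficients.

```python
import numpy as np

# Made-up noiseless data on y = 2x^2 - 3x + 4.
x = np.arange(0.0, 10.0)
y = 2 * x**2 - 3 * x + 4

# Same columns PolynomialFeatures(degree=2) produces for a single feature.
design = np.column_stack([np.ones_like(x), x, x**2])

# Ordinary least squares on the expanded features; solution order is [c, b, a].
(c, b, a), *_ = np.linalg.lstsq(design, y, rcond=None)
print(a, b, c)  # approximately 2.0 -3.0 4.0
```

For the fitted RANSACRegressor in the answer, the corresponding numbers should be readable from the best inlier model, `ransac.estimator_`: a and b as the last two entries of `coef_`, and c mostly in `intercept_`, since LinearRegression fits its own intercept and the bias column from PolynomialFeatures is then redundant.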

