bounty将在5天后过期。回答此问题可获得+50的声望奖励。Jaffer Wilson希望吸引更多人关注此问题:请让我知道我如何才能找到这个点所在的区域。提供这个问题的解决方案,我将奖励对我有帮助的人。
我有一个Sklearn的示例代码取自网站。我试图学习如何使用Sklearn(Scikit-Learn)对点进行分类。以下是代码:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.inspection import DecisionBoundaryDisplay
names = [
"Nearest Neighbors",
]
classifiers = [
KNeighborsClassifier(3),
]
X, y = make_classification(
n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1
)
rng = np.random.RandomState(2)
X += 2 * rng.uniform(size=X.shape)
linearly_separable = (X, y)
datasets = [
linearly_separable,
]
figure = plt.figure(figsize=(27, 9))
i = 1
# iterate over datasets
for ds_cnt, ds in enumerate(datasets):
# preprocess dataset, split into training and test part
X, y = ds
X = StandardScaler().fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.4, random_state=42
)
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
# just plot the dataset first
cm = plt.cm.RdBu
cm_bright = ListedColormap(["#FF0000", "#0000FF"])
ax = plt.subplot(len(datasets), len(classifiers) + 1, i)
if ds_cnt == 0:
ax.set_title("Input data")
# Plot the training points
ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors="k")
# Plot the testing points
ax.scatter(
X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6, edgecolors="k"
)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xticks(())
ax.set_yticks(())
i += 1
# iterate over classifiers
for name, clf in zip(names, classifiers):
ax = plt.subplot(len(datasets), len(classifiers) + 1, i)
clf.fit(X_train, y_train)
score = clf.score(X_test, y_test)
All_Value_Response = DecisionBoundaryDisplay.from_estimator(
clf, X, cmap=cm, alpha=0.8, ax=ax, eps=0.5
)
# Plot the training points
ax.scatter(
X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors="k"
)
# Plot the testing points
ax.scatter(
X_test[:, 0],
X_test[:, 1],
c=y_test,
cmap=cm_bright,
edgecolors="k",
alpha=0.6,
)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xticks(())
ax.set_yticks(())
if ds_cnt == 0:
ax.set_title(name)
ax.text(
x_max - 0.3,
y_min + 0.3,
("%.2f" % score).lstrip("0"),
size=15,
horizontalalignment="right",
)
i += 1
plt.tight_layout()
plt.show()
输出如下:
现在我们可以看到,形成的区域不是规则的形状,因此,要了解如何知道新点是否到达以及将位于哪个区域变得有点困难。我设法捕获了区域的数据(All_Value_Response
变量存储了这些信息),但似乎对我没有帮助。
所以我想知道,如果我想知道点(1,3)
在哪个区域,那么我怎么通过代码来推导它,我可以通过在图上看来做,但是如何用代码来做呢?
请帮我找到解决问题的方法。
3条答案
按热度按时间2nc8po8w1#
试试这个
我在中找到了你所问问题的参考资料:https://stackoverflow.com/a/74613354/4948889
以上代码的输出如下所示:
4dbbbstv2#
所以,你有
X_train
和X_test
,这两个列表都包含元组,元组(a,b)中的值有一定的范围,比如0 -〉1,在你的图中,这是点的x和y坐标。还有
y_train
和y_test
。它们是X_train
和X_test
中所有元组的已知分类。这些值可以是0或1,不能介于两者之间。如果图形中的点位于蓝色区域,则意味着该点的预测值如果点在红色区域中,这意味着预测值是1。如果你在这个(但通常是更多的数据)上训练一个分类器,那么你可以问它任何一点(a,b),它会告诉你0或1(也就是蓝色或红色)。
例如,我预测一个点(a,b),它在X_train中没有看到(也就是X_test中的东西):
result
则等于:[0]
。这是因为看你的图,假设x轴和y轴的范围是0 -〉1。元组(0.2,0,2)落在蓝色区域。它知道这一点,因为它已经从
X_train
和X_test
中学习了图中的蓝红分类,所以当它得到新的元组时,它会看到点落在哪个区域,并将其分类为0或1,区域蓝或区域红。总而言之。彩色区域显示了任何给定元组(a,b)的预测值。点的位置(在散点图中)由元组中的值(a,b)给出。元组的a和b在0-〉1之间。颜色不是一个范围,而是一个分类0或1。
希望能有所帮助,祝你好运!
ymdaylpp3#
我们可以确定新的点在哪个区域,但在此之前,我想提醒大家注意,你在代码中所做的一些事情,这些事情会反过来影响你。
这条线会回来狠狠地打你。
请记住,使用StandardScaler的目的是将数据标准化(0均值,单位方差)。还请记住,在训练集上执行的操作必须在测试集上执行。这里需要注意的是,在测试集上执行的操作将从训练集学习。我将提供一个精简的代码来帮助说明这一点。
因此,尽管在这两种情况下都可以使用
predict
,但您需要确保将转换应用于正在输入的数据点。