matplotlib 绘制线性分类器输出

hfyxw5xn  于 2023-03-30  发布在  其他
关注(0)|答案(1)|浏览(129)

我尝试在不使用任何API的情况下编写一个小的线性分类器代码,以理解线性分类器的逻辑。我的代码如下:

import numpy as np
 import matplotlib.pyplot as plt 
    
    # Generate a random dataset with 2 features and two classes

np.random.seed(42)
x = np.random.randn(100,2)
y = np.concatenate([np.ones(50), -1*np.ones(50)])

# Generate a random test set with the same number of features

    X_test = np.random.randn(50,2)
    y_test = np.concatenate([np.ones(25), -1*np.ones(25)])
    
    
    
    # Define a function to train a linear classifier on the data
    
    def linear_classifier(X, y, learning_rate=0.01, num_epochs=100):
        num_features = X.shape[1] 
        weights = np.zeros(num_features)
        bias = 0
        
        for epoch in range(num_epochs):
            for i in range(X.shape[0]):
                linear_output = np.dot(X[i], weights) + bias
               
                y_pred = np.sign(linear_output)
                
               
                error = y[i] - y_pred
                # print("The value of error=", error)
                weights = weights + learning_rate * error * X[i]
    
                bias += learning_rate * error
                
        return weights, bias
    
    
    # Train the linear classifier on the training set 
    weights, bias = linear_classifier(X, y) 
    
    # Apply the learned weights and bias to the test set
    linear_output = np.dot(X_test, weights) + bias
    
    
    y_pred = np.sign(linear_output)
    
    # Compute the accuracy of the classifier on the test set
    accuracy = np.mean(y_pred == y_test)
    print("Accuracy:", accuracy)

我在各种链接中看到,为了绘制二维数据,他们使用一个称为散点图的函数。谁能指导我如何使用散点图。我尝试使用下面的语句,但我无法使其工作。

# Plot the data points with different colors for different classes
plt.scatter(X[:, 0], X[:, 1], y)

谁能教我如何用超平面绘制线性分类器特征。

编辑我得到的错误是:

Output exceeds the size limit. Open the full output data in a text editor
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[77], line 2
      1 # Plot the data points with different colors for different classes
----> 2 plt.scatter(X[:, 0], X[:, 1], y)
brccelvz

brccelvz1#

您只需要提供y作为“c”的参数
plt.scatter(x[:, 0], x[:, 1], c=y)
此外,您需要重新定义“x”为“X”,或者将代码中的“X”更改为“x”。
Here's the desired output

相关问题