用PyTorch实现XOR

3xiyfsfu  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(100)

我想用PyTorch实现XOR（异或）逻辑运算，但得不到正确的结果。我需要三个输出列供后续操作使用。数据文件data.csv（XOR真值表）如下，其中out3是in1~in4四个输入的异或（即奇偶校验），out1和out2恒为0：

in1,in2,in3,in4,out1,out2,out3
0,0,0,0,0,0,0
0,0,0,1,0,0,1
0,0,1,0,0,0,1
0,0,1,1,0,0,0
0,1,0,0,0,0,1
0,1,0,1,0,0,0
0,1,1,0,0,0,0
1,0,0,0,0,0,1
1,0,0,1,0,0,0
1,0,1,0,0,0,0
1,0,1,1,0,0,1
1,1,0,0,0,0,0
1,1,0,1,0,0,1
1,1,1,0,0,0,1

字符串
Python代码是:

# Asker's original script: trains a 4-5-3 MLP on data.csv and prints the
# model's (re-squashed) outputs for two inputs. Kept as posted; the review
# notes below explain why the printed values never approach 0.
import torch 
import torch.nn as nn
import pandas as pd
import numpy as np  # NOTE(review): np is never used below

# Input size, hidden layer size and output size respectively (there is no
# batch size: every step below trains on all rows at once)
n_in, n_h, n_out = 4, 5, 3

# Load the input and target columns from data.csv
df = pd.read_csv('data.csv')

input_cols = ['in1', 'in2', 'in3', 'in4'] 
output_cols = ['out1', 'out2', 'out3']

input_np_array = df[input_cols].to_numpy()
target_np_array = df[output_cols].to_numpy()

inputs = torch.tensor(input_np_array, dtype=torch.float32)
targets = torch.tensor(target_np_array, dtype=torch.float32)

# 4 -> 5 -> 3 MLP. The final nn.Sigmoid() means model(...) already returns
# values in (0, 1).
model = nn.Sequential(
    nn.Linear(n_in, n_h),
    nn.Sigmoid(),
    nn.Linear(n_h, n_out),
    nn.Sigmoid())

# Construct the loss function
# NOTE(review): for 0/1 targets behind a sigmoid, BCELoss (or
# BCEWithLogitsLoss with the final Sigmoid removed) is the usual choice;
# MSE on a sigmoid output tends to learn slowly.
criterion = torch.nn.MSELoss()
# Construct the optimizer (Stochastic Gradient Descent in this case)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

# Full-batch gradient descent: 500000 steps, logging every 100 epochs
# (i.e. 5000 printed lines)
for epoch in range(500000):
   # Forward pass: Compute predicted y by passing x to the model
   y_pred = model(inputs)

   # Compute and print loss
   loss = criterion(y_pred, targets)
   if epoch % 100==0:
     print('epoch: ', epoch,' loss: ', loss.item())

   # Clear gradients accumulated by the previous step
   optimizer.zero_grad()

   # perform a backward pass (backpropagation)
   loss.backward()

   # Update the parameters
   optimizer.step()

# NOTE(review): [1,1,1,1] and [0,1,1,1] are not among the rows of the
# data.csv shown in the question, so the model never saw them in training.
# NOTE(review): no torch.no_grad() here, hence the grad_fn in the output.
test_data = torch.tensor([1,1,1,1], dtype=torch.float32) # expected: 0,0,0
output = model(test_data)
# NOTE(review): BUG -- model already ends in nn.Sigmoid(), so this applies a
# second sigmoid. sigmoid of a value in (0, 1) always lies in (0.5, ~0.731),
# so these "probabilities" can never approach 0. Use `output` directly, e.g.
# `(output >= 0.5).float()` for 0/1 labels.
probabilities = torch.nn.functional.sigmoid(output)
# Alternative the author tried (left commented out):
# probabilities = torch.nn.functional.softmax(output, dim=0)
print(probabilities)

test_data = torch.tensor([0,1,1,1], dtype=torch.float32) # expected: 0,0,1
output = model(test_data)
# NOTE(review): same double-sigmoid bug as above.
probabilities = torch.nn.functional.sigmoid(output)
print(probabilities)


输出为:

tensor([0.5058, 0.5060, 0.6247], grad_fn=<SigmoidBackward0>)
tensor([0.5052, 0.5058, 0.6214], grad_fn=<SigmoidBackward0>)


为什么我得不到正确的结果？

p8h8hvxi

p8h8hvxi1#

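问题主要出在两处：(1) 模型最后一层已经是nn.Sigmoid()，测试时又调用了一次sigmoid，相当于施加了两次sigmoid，输出只能落在(0.5, 0.731)之间，永远不会接近0；(2) 对0/1目标应使用二元交叉熵损失，而不是MSELoss。只需施加一次sigmoid，并以0.5为阈值得到0/1结果。另外注意：测试用的[1,1,1,1]和[0,1,1,1]两行并不在data.csv中，网络没见过这两个样本，不一定能正确推断出奇偶性；如需正确预测，请把它们加入训练数据。你可以这样修改你的代码：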

"""4-input XOR (parity) with PyTorch.

In the question's data.csv, out3 is the XOR of in1..in4 and out1/out2 are
always 0. Run this file as a script to train on data.csv and print 0/1
predictions for two test inputs.
"""
import torch
import torch.nn as nn
import pandas as pd

# Input size, hidden layer size and output size. Training below is
# full-batch, so there is no separate batch size.
n_in, n_h, n_out = 4, 8, 3

input_cols = ['in1', 'in2', 'in3', 'in4']
output_cols = ['out1', 'out2', 'out3']


def load_data(path='data.csv'):
    """Read ``path`` and return ``(inputs, targets)`` as float32 tensors.

    ``inputs`` has one column per name in ``input_cols`` and ``targets`` one
    column per name in ``output_cols``; rows keep the CSV order.
    """
    df = pd.read_csv(path)
    inputs = torch.tensor(df[input_cols].to_numpy(), dtype=torch.float32)
    targets = torch.tensor(df[output_cols].to_numpy(), dtype=torch.float32)
    return inputs, targets


def build_model(n_in=n_in, n_h=n_h, n_out=n_out):
    """Return a one-hidden-layer MLP that outputs raw logits.

    There is deliberately no final ``nn.Sigmoid()``: the loss
    (``BCEWithLogitsLoss``) and ``predict`` each apply the sigmoid exactly
    once. Applying it twice -- once inside the model and again at test time,
    as the question's code did -- confines every output to (0.5, ~0.731).
    """
    return nn.Sequential(
        nn.Linear(n_in, n_h),
        nn.Tanh(),  # zero-centred hidden activation (typically trains faster than Sigmoid)
        nn.Linear(n_h, n_out),
    )


def train(model, inputs, targets, epochs=5000, lr=0.01, log_every=500):
    """Fit ``model`` to ``(inputs, targets)`` with full-batch Adam.

    Args:
        model: module mapping ``inputs`` to logits shaped like ``targets``.
        inputs: float tensor of input rows.
        targets: float tensor of 0/1 labels, same number of rows as inputs.
        epochs: number of optimisation steps (one full-batch step per epoch).
        lr: Adam learning rate.
        log_every: print the loss every this many epochs; 0 disables logging.

    Returns:
        The loss of the last epoch, measured before that epoch's update
        (``nan`` when ``epochs`` is 0).
    """
    # Sigmoid + binary cross-entropy fused in a numerically stable way.
    criterion = nn.BCEWithLogitsLoss()
    # Adam typically needs far fewer steps here than plain SGD at lr=0.01
    # (the original answer ran 500000 SGD epochs).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    last_loss = float('nan')
    for epoch in range(epochs):
        loss = criterion(model(inputs), targets)
        last_loss = loss.item()
        if log_every and epoch % log_every == 0:
            print('epoch: ', epoch, ' loss: ', last_loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return last_loss


def predict(model, x, threshold=0.5):
    """Return 0/1 float predictions for ``x`` (a single sample or a batch)."""
    with torch.no_grad():
        probabilities = torch.sigmoid(model(x))
    return (probabilities >= threshold).float()


if __name__ == '__main__':
    torch.manual_seed(0)  # reproducible weight initialisation
    inputs, targets = load_data('data.csv')
    model = build_model()
    train(model, inputs, targets)

    # NOTE: in the 14-row data.csv shown in the question, rows 1,1,1,1 and
    # 0,1,1,1 are missing, so both test inputs are unseen. Parity flips with
    # every input bit, so a small MLP has no reason to extrapolate it to
    # unseen rows -- add them to data.csv (the full 16-row truth table) if
    # they must be predicted correctly.
    test_data = torch.tensor([[1, 1, 1, 1],   # expected 0, 0, 0
                              [0, 1, 1, 1]],  # expected 0, 0, 1
                             dtype=torch.float32)
    print(predict(model, test_data))

字符串

相关问题