pytorch中绘图训练和验证损失

pn9klfpd  于 2022-12-13  发布在  其他
关注(0)|答案(1)|浏览(195)

我正在使用pytorch来训练我的CNN网络。我想绘制我的训练和验证损失曲线来可视化模型性能。我如何绘制两条曲线?
我有以下代码

# create a function (this my favorite choice)
def RMSELoss(predicted,target):
    return torch.sqrt(torch.mean((predicted-target)**2))

criterion = RMSELoss

# loss = torch.sqrt(criterion(x, y))
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 300

n_total_steps = len(train_dataset)

trainingEpoch_loss = []
validationEpoch_loss = []

for epoch in range(epochs):
    step_loss = []
    model.train()
    for i, data in enumerate(train_dataset):
        feature,target = data['data'].type(torch.FloatTensor),torch.tensor(data['target']).type(torch.FloatTensor)
         
        # Clear the gradients
        optimizer.zero_grad()
        # Forward Pass
        outputs = model(feature)
        # Find the Loss
        training_loss = criterion(outputs, target)
        # Calculate gradients
        training_loss.backward()
        # Update Weights
        optimizer.step()
        # Calculate Loss
        step_loss.append(training_loss.item())
        if (i+1) % 1 == 0:
            print (f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {training_loss.item():.4f}')
    trainingEpoch_loss.append(np.array(step_loss).mean())
 
    model.eval()     # Optional when not using Model Specific layer
    for i, data in enumerate(val_dataset):
        validationStep_loss = []
        feature,target = data['data'].type(torch.FloatTensor),torch.tensor(data['target']).type(torch.FloatTensor)
        
        # Forward Pass
        outputs = model(feature)
        # Find the Loss
        validation_loss = criterion(outputs, target)
        # Calculate Loss
        validationStep_loss.append(validation_loss.item())
    validationEpoch_loss.append(np.array(validationStep_loss).mean())

你能让我知道我做的对不对吗?也请让我知道如何策划训练和验证损失?

fnatzsnv

fnatzsnv1#

您在trainingEpoch_lossvalidationEpoch_loss列表中收集历元损耗是正确的。现在,在训练之后,添加代码来绘制损耗:

from matplotlib import pyplot as plt
plt.plot(trainingEpoch_loss, label='train_loss')
plt.plot(validationEpoch_loss,label='val_loss')
plt.legend()
plt.show

请阅读matplotlib文档以获得更多有趣绘图功能。

相关问题