How to apply a learning rate during backpropagation in PyTorch

rwqw0loc · asked 2023-02-12 · 1 answer · 138 views

I've heard that during backpropagation the weights are updated using the learning rate and the partial derivatives.
But I can't tell where the learning_rate parameter should go in the backpropagation code, and I'd also like to know what the default learning rate is if none is set.
Here is the code I'm working with.

import torch
import torch.nn as nn
import torch.optim as optim

class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super(MyNeuralNetwork, self).__init__()
        layer_1=nn.Linear(in_features=2, out_features=2, bias=False)
        weight_1 = torch.tensor([[.3,.25],[.4, .35]])
        
        layer_1.weight = nn.Parameter(weight_1)
        self.layer1 = nn.Sequential(
            layer_1,
            nn.Sigmoid()
        )
        
        layer_2 = nn.Linear(in_features=2, out_features=2, bias=False)
        weight_2 = torch.tensor([[.45, .4],[.7, .6]])
        
        layer_2.weight = nn.Parameter(weight_2)
        self.layer2 = nn.Sequential(
            layer_2,
            nn.Sigmoid()
        )
    
    def forward(self, input):
        output = self.layer1(input)
        output = self.layer2(output)
        
        return output
    
model = MyNeuralNetwork().to("cpu")
print(model)

input = torch.tensor([0.1,0.2]).reshape(1,-1)
target = torch.tensor([0.4,0.6]).reshape(1,-1)

out = model(input)
print(f"output value : {out}")

criterion = nn.MSELoss()
loss = criterion(out, target)
print(f"loss value : {loss}")

model.zero_grad() 
print('↓ weights before backward propagation ↓')
print(model._modules['layer1']._modules['0'].weight)
print(model._modules['layer2']._modules['0'].weight)
print()

loss.backward() # where can I put the learning rate in back propagation?
print('↓ weights after backward propagation ↓')
print(model._modules['layer1']._modules['0'].weight)
print(model._modules['layer2']._modules['0'].weight)

The point of my question is how to add a learning rate for training this model.
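To illustrate what I mean, my understanding is that the update should look roughly like the sketch below, applied by hand to the gradients that loss.backward() fills in (the learning_rate value of 0.5 here is purely illustrative, not part of my code):

learning_rate = 0.5  # hypothetical value, only for illustration

with torch.no_grad():
    for param in model.parameters():
        # gradient descent by hand: w <- w - lr * dLoss/dw
        param -= learning_rate * param.grad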

xuo3flqw · Answer #1

The learning rate is applied when you call the optimizer's step function to update the model's weights. Specifically, when you create an optimizer you specify a learning rate like this:

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

Then, in your training loop, you call the optimizer's step function after the backward pass, like this:

optimizer.step()

The step function updates the model's weights using the learning rate you specified. There is no single built-in default: for optim.SGD you normally pass lr explicitly (0.01 is a common starting point, and optimizers such as Adam default to 0.001), and you can change it to suit your particular problem.
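Putting this together with the model from the question, a minimal training step might look like the sketch below (the learning_rate value of 0.5 is just an illustrative choice, not taken from the original post):

learning_rate = 0.5  # illustrative value; tune it for your own problem
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

out = model(input)                 # forward pass
loss = criterion(out, target)      # compute MSE loss

optimizer.zero_grad()              # clear old gradients
loss.backward()                    # compute dLoss/dw for every parameter
optimizer.step()                   # w <- w - lr * dLoss/dw

print(model._modules['layer1']._modules['0'].weight)
print(model._modules['layer2']._modules['0'].weight)

Repeating these steps in a loop over your data is what trains the model; the learning rate never appears in loss.backward() itself, only in the optimizer that consumes the gradients.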
