Deep learning model error: the saved intermediate results have already been freed

Posted by iszxjhcz on 2021-08-20 in Java

I am currently getting this error while training an LSTM model:

RuntimeError: Trying to backward through the graph a second time, 
          but the saved intermediate results have already been freed. 
          Specify retain_graph=True when calling backward the first time.
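For context, a minimal sketch of what this message means (an illustrative reproduction, not code from the question): autograd frees the tensors it saved for the backward pass as soon as `backward()` finishes, so a second `backward()` through the same graph fails unless the first call passed `retain_graph=True`. The helper name `backward_twice` is hypothetical.

```python
import torch

def backward_twice(retain_first: bool) -> bool:
    """Return True if two backward passes through one graph both succeed."""
    w = torch.tensor(2.0, requires_grad=True)
    # exp saves its output for the backward pass, so this graph holds saved tensors
    y = torch.exp(w)
    y.backward(retain_graph=retain_first)
    try:
        y.backward()  # second pass through the same graph
        return True
    except RuntimeError:
        return False

print(backward_twice(retain_first=False))  # False: saved tensors were already freed
print(backward_twice(retain_first=True))   # True: the first call kept the graph alive
```

Note that `retain_graph=True` only helps when it is passed on the *first* backward through a given graph; passing it later cannot restore tensors that were already freed.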

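I first feed each element of x_train into a CNN model to train it. This produces an output of 1221 elements. I then feed these 1221 elements one by one into the LSTM model to train it (code below).
I have already specified retain_graph=True, as the error message suggests, but it still does not work.
Another solution I found online was to use the detach_() function. I tried that as well, but it still does not work. I am a bit stuck at this point, so any suggestions would be greatly appreciated!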
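A likely cause, assuming each element of `x_train` is a CNN output that was computed with gradients enabled and never detached: every `Input` then still points into the CNN's graph, and that graph was freed when the CNN itself was trained with `backward()`. Passing `retain_graph=True` to the LSTM's `backward()` cannot bring it back, and calling `detach_()` on the LSTM *output* cuts the graph at the wrong end. The sketch below uses hypothetical stand-ins (`cnn` and `head` are single `nn.Linear` layers, not the question's models) to show that detaching the *input* avoids the error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's models.
cnn = nn.Linear(8, 4)   # plays the role of the CNN feature extractor
head = nn.Linear(4, 2)  # plays the role of the LSTM that consumes CNN outputs

x = torch.randn(3, 8)
features = cnn(x)          # still attached to the CNN's graph
features.sum().backward()  # "training" the CNN frees that graph

def train_head(inp: torch.Tensor) -> bool:
    """Return True if a backward pass from the head succeeds."""
    loss = head(inp).sum()
    try:
        loss.backward(retain_graph=True)  # retain_graph does not help here
        return True
    except RuntimeError:
        return False

print(train_head(features))           # False: backward reaches the freed CNN graph
print(train_head(features.detach()))  # True: backward stops at the detached input
```

If the CNN does not need to be updated by the LSTM's loss, computing its outputs inside `torch.no_grad()` has the same effect as detaching them afterwards.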
The LSTM model I am using is shown below:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.lin = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # initial hidden and cell states: (num_layers, batch, hidden_dim)
        hidden0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        cell0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)

        out, (hn, cn) = self.lstm(x, (hidden0, cell0))
        hn = hn.view(-1, self.hidden_dim)  # (batch, hidden_dim) when num_layers == 1
        out = F.relu(self.lin(hn))

        return out

lstm = LSTM(120, 5, 1, 5)
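With `batch_first=True`, `nn.LSTM(120, 5, 1, ...)` expects each `Input` to be a 3-D tensor of shape `(batch, seq_len, 120)`, while the initial states and `hn` keep the shape `(num_layers, batch, hidden_dim)`. That is why `hn.view(-1, self.hidden_dim)` yields one row per sample when `num_layers` is 1. A quick shape check, assuming a batch of 1 and an arbitrary sequence length of 7 (neither is stated in the question):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=120, hidden_size=5, num_layers=1, batch_first=True)
lin = nn.Linear(5, 5)

x = torch.randn(1, 7, 120)         # (batch, seq_len, input_dim)
h0 = torch.zeros(1, x.size(0), 5)  # (num_layers, batch, hidden_dim), even with batch_first
c0 = torch.zeros(1, x.size(0), 5)

out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)  # torch.Size([1, 7, 5])
print(hn.shape)   # torch.Size([1, 1, 5])

logits = lin(hn.view(-1, 5))
print(logits.shape)  # torch.Size([1, 5])
```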

The training code is:

num_train = 4
num_epoch = len(x_train)  # equal to 1221 because x_train has 1221 elements
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lstm.parameters(), lr=1e-5)
outputListLSTM = []  # stores a list of all the outputs of the model

for trainEpoch in range(num_train):

    running_loss = 0.0
    for epoch in range(num_epoch):
        # input and label
        Input = x_train[epoch]
        Label = torch.tensor(y_train[epoch])

        outputsLSTM.detach_()
        outputsLSTM = lstm.forward(Input)
        optimizer.zero_grad()

        if trainEpoch == num_train - 1:
            outputListLSTM.append(outputsLSTM)  # add output to outputListLSTM

        loss = criterion(outputsLSTM, Label.unsqueeze(0).type(torch.LongTensor))

        loss.backward(retain_graph=True)

        optimizer.step()

        # print statistics
        running_loss += loss.item()
    print('Epoch: ', trainEpoch, ' Loss: ', running_loss / 2000)

print('Finished Training')
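A minimal sketch of a training loop that avoids the error, assuming `x_train` holds CNN outputs that no longer need gradients. Everything here is hypothetical: `cnn` is a single `nn.Linear` standing in for the CNN, and the data is 6 random samples instead of 1221. Compared with the loop above, each input is detached, `retain_graph=True` is dropped, the module is called directly instead of via `.forward()`, stored outputs are detached so they do not keep graphs alive, and the loss is averaged over the number of samples.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.lin = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # last layer's hidden state; equals hn.view(-1, hidden_dim) when num_layers == 1
        return F.relu(self.lin(hn[-1]))

# Hypothetical data: CNN outputs whose graphs were freed by the CNN's own training.
cnn = nn.Linear(120, 120)  # stand-in for the CNN
x_train = []
for _ in range(6):
    feat = cnn(torch.randn(1, 7, 120))  # (batch, seq_len, 120), attached to cnn's graph
    feat.sum().backward()               # simulates training the CNN; frees its graph
    x_train.append(feat)
y_train = [i % 5 for i in range(6)]

lstm = LSTM(120, 5, 1, 5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lstm.parameters(), lr=1e-5)

num_train = 4
num_samples = len(x_train)
epoch_losses = []
output_list = []

for train_epoch in range(num_train):
    running_loss = 0.0
    for i in range(num_samples):
        inp = x_train[i].detach()           # cut the link to the CNN's freed graph
        label = torch.tensor([y_train[i]])  # shape (1,), dtype int64
        optimizer.zero_grad()
        logits = lstm(inp)                  # call the module, not .forward()
        loss = criterion(logits, label)
        loss.backward()                     # a fresh graph per sample, so no retain_graph
        optimizer.step()
        if train_epoch == num_train - 1:
            output_list.append(logits.detach())  # store outputs without their graphs
        running_loss += loss.item()
    epoch_losses.append(running_loss / num_samples)
    print('Epoch:', train_epoch, 'Loss:', epoch_losses[-1])

print('Finished Training')
```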
