我有一个模型,我训练的基础上 Torch 在GPU。现在,我想上传它的CPU。我用这个代码来保存一个加载的模型。
下面是我的模型和训练阶段:
model = VAE(input_size , lead, hidden_dim,hidden_dim1,hidden_dimd, latent_dim, device, num_hidden= lead).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr =learning_rate )
losses, kl_loss, l_loss, validate_loss = trainv(model, device, epochs, train_iterator, optimizer, validate_iterator)
model = VAE()
torch.save(model.state_dict(), "model.pt")
# load
device = torch.device('cpu')
model = VAE()
model.load_state_dict(torch.load(PATH, map_location=device))
错误如下:
TypeError: __init__() missing 8 required positional arguments:
2条答案
按热度按时间wecizke31#
您是否在
torch.save()
之前定义了模型类?这段代码在Google协作中可以正常工作:z9smfwbn2#
错误显示为
TheModelClass is not defined
,这意味着您的模型的类在当前文件中不可用。您通常在一个单独的 * 模块 * 中定义一个类,例如,您的模型TheModelClass
的类,需要首先导入。因此,由于您没有提供有关代码的更多信息,我建议您在项目中的所有文件中搜索
TheModelClass
的定义,并将其从相应的模块导入到当前文件中。