pytorch 在GPU上保存模型并将其加载到CPU上

x6492ojm  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(201)

我有一个模型,我训练的基础上 Torch 在GPU。现在,我想上传它的CPU。我用这个代码来保存一个加载的模型。
下面是我的模型和训练阶段:

model              = VAE(input_size ,  lead,     hidden_dim,hidden_dim1,hidden_dimd,  latent_dim, device,  num_hidden= lead).to(device)

optimizer          = torch.optim.Adam(model.parameters(), lr =learning_rate )

losses, kl_loss, l_loss, validate_loss = trainv(model, device, epochs, train_iterator, optimizer, validate_iterator) 

model = VAE()
torch.save(model.state_dict(), "model.pt")

# load

device = torch.device('cpu')
model = VAE()
model.load_state_dict(torch.load(PATH, map_location=device))

错误如下:

TypeError: __init__() missing 8 required positional arguments:
wecizke3

wecizke31#

您是否在torch.save()之前定义了模型类?这段代码在Google协作中可以正常工作:

import torch
import torch.nn as nn

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()

        self.linear = nn.Linear(125, 1)

    def forward(self, src):

        output = self.linear(src)

        return

model = TheModelClass()

torch.save(model.state_dict(), "model.pt") # you need to define your model before
device = torch.device('cpu')
model = TheModelClass() # you even don't need to redefine your model
model.load_state_dict(torch.load('/content/model.pt', map_location=device))
z9smfwbn

z9smfwbn2#

错误显示为TheModelClass is not defined,这意味着您的模型的类在当前文件中不可用。您通常在一个单独的 * 模块 * 中定义一个类,例如,您的模型TheModelClass的类,需要首先导入。
因此,由于您没有提供有关代码的更多信息,我建议您在项目中的所有文件中搜索TheModelClass的定义,并将其从相应的模块导入到当前文件中。

相关问题