pytorch 当我使用torch.nn.DataParallel()时如何访问类对象?

e0bqpujr  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(265)

我想使用PyTorch和多个GPU来训练我的模型。

model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

然后,我尝试访问在模型定义中定义的优化器:

G_opt = model.module.optimizer_G

然而,我得到了一个错误:
属性错误:'DataParallel'对象没有属性optimizer_G
我认为这与我的模型定义中优化器的定义有关。当我使用单个GPU而不使用torch.nn.DataParallel时,它可以工作。但它不适用于多个GPU,即使我使用module调用,也无法找到解决方案。
模型定义如下:

class MyModel(torch.nn.Module):
    ...
   self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

如果你想看完整的代码,我在GitHub中使用了Pix2PixHD实现。
谢谢你,贝斯特。
编辑:我通过使用model.module.module.optimizer_G解决了这个问题。

aurhwmvo

aurhwmvo1#

使用model.module,但在运行模型之前的某个时间由于某种原因它不工作,此时使用model.module.module
祝你好运

相关问题