python 当将pytorch模块移动到cuda时,具有缓冲区dict的pytorch模块出现意外行为

ivqmmu1c  于 2023-04-19  发布在  Python
关注(0)|答案(1)|浏览(216)

我有一个pytorch(版本1.10.0+cu111)模块,其中Tensor的dict被注册为缓冲区。起初一切看起来都很好,我可以通过dict或通过注册的名称访问该dict中的Tensor:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.tensor_buffer_dict = {
            'a': torch.randn(10),
        }

        for k, v in self.tensor_buffer_dict.items():
            self.register_buffer(f'buffer_{k}', v)

model = MyModel()
print(id(model.buffer_a))
print(id(model.tensor_buffer_dict['a']))
# output is
# 140361680762656
# 140361680762656

但模型移到cuda后,它们就不再一样了:

model.cuda()
print(id(model.buffer_a))
print(id(model.tensor_buffer_dict['a']))
# output is
# 140361680361968
# 140361680762656

既然model.buffer_amodel.tensor_buffer_dict['a']在迁移到cuda之前引用的是完全相同的Tensor,那么为什么它们在迁移到cuda之后会变得不同呢?

vatpfxk5

vatpfxk51#

当你调用model.cuda()时,与model.buffer_a关联的Tensor被移动到GPU内存中,并在GPU内存中创建一个新的Tensor。与model.tensor_buffer_dict['a']关联的Tensor仍然在CPU内存中。因此,在调用cuda()之后,这两个变量不再指向同一个Tensor对象。要解决这个问题,添加如下内容:

def forward(self, x):
    a = self.get_buffer('buffer_a')

相关问题