pytorch 将从GPU获得的pickle文件加载到CPU

vlf7wbxs  于 2024-01-09  发布在  其他
关注(0)|答案(2)|浏览(227)

我在加载pickle文件到CPU的时候遇到了一个问题,我在网上搜索了一下,他们说我需要添加map_location参数,但是我添加了这个参数之后,问题仍然存在。
守则如下:

torch.__version__
torch.load('featurs.pkl',map_location='cpu')

>>>

'1.0.1.post2'
Attempting to deserialize object on a CUDA device 
but torch.cuda.is_available() is False. If you are running 
on a CPU-only machine, please use torch.load with map_location='cpu' 
to map your storages to the CPU.

字符串
我知道这是因为不同的设备,但我使用错误信息中的指令,所以我不知道下一步如何解决它。
提前感谢!

xn1cxnb4

xn1cxnb41#

错误消息建议使用map_location=torch.device('cpu'),但即使这样也不起作用。一种解决方法是使用pickle库并实现一个自定义的unpickler。

import pickle
import io

class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)

#contents = pickle.load(f) becomes...
contents = CPU_Unpickler(f).load()

字符串
来源:Github

qq24tv8q

qq24tv8q2#

试试这个:

torch.load('featurs.pkl',map_location=torch.device('cpu'))

字符串

相关问题