How to load a .pth file in PyTorch?

Posted by qij5mzcb on 2022-11-09 in Other

I got a PyTorch model from [this repository][1] and I have to convert it to TFLite. The code is as follows:
Up to this point everything runs fine. But when I run the following cell:

```python
torch.onnx.export(
        model=torch_model,
        args=sample_input,
        f=ONNX_FILE,
        verbose=False,
        export_params=True,
        do_constant_folding=False,  # True would fold constant values for optimization
        input_names=['input'],
        opset_version=10,
        output_names=['output']
)

onnx_model = onnx.load(ONNX_FILE)

onnx.checker.check_model(onnx_model)
```

Full error log:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-33-15df717ec276> in <module>
      8         input_names=['input'],
      9         opset_version=10,
---> 10         output_names=['output']
     11 )
     12 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    274                         do_constant_folding, example_outputs,
    275                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276                         custom_opsets, enable_onnx_checker, use_external_data_format)
    277 
    278 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     92             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
     93             custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94             use_external_data_format=use_external_data_format)
     95 
     96 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
    677         _set_opset_version(opset_version)
    678         _set_operator_export_type(operator_export_type)
--> 679         with select_model_mode_for_export(model, training):
    680             val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
    681                                                              operator_export_type,

~\anaconda3\envs\py36\lib\contextlib.py in __enter__(self)
     79     def __enter__(self):
     80         try:
---> 81             return next(self.gen)
     82         except StopIteration:
     83             raise RuntimeError("generator didn't yield") from None

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in select_model_mode_for_export(model, mode)
     36 def select_model_mode_for_export(model, mode):
     37     if not isinstance(model, torch.jit.ScriptFunction):
---> 38         is_originally_training = model.training
     39 
     40         if mode is None:

AttributeError: 'collections.OrderedDict' object has no attribute 'training'
```

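This error occurs when I call torch.onnx.export().
Please let me know what is going wrong here. Am I not loading the weights correctly? If not, how should I load the model? I don't know the details of the class or the architecture, so how can I use model.load_state_dict()?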
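The traceback fails at model.training, so one way to confirm what kind of object is being passed as model is to inspect it directly. Below is a minimal sketch, assuming the .pth file was produced with torch.save(model.state_dict(), ...); the nn.Linear layer and the temporary weights.pth file are hypothetical stand-ins used only to make the snippet self-contained:

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical stand-in: create a .pth file the way many repositories do,
# by saving only the state_dict (parameter names -> tensors).
path = os.path.join(tempfile.mkdtemp(), "weights.pth")
torch.save(nn.Linear(4, 2).state_dict(), path)

loaded = torch.load(path, map_location="cpu")

# The loaded object is a plain mapping, not an nn.Module, so it has no
# `training` attribute; that is exactly what torch.onnx.export tripped over.
print(type(loaded))                   # <class 'collections.OrderedDict'>
print(isinstance(loaded, OrderedDict))
print(isinstance(loaded, nn.Module))  # False
print(hasattr(loaded, "training"))    # False

# Listing parameter names and tensor shapes gives a rough picture of the
# architecture the weights belong to.
for name, tensor in loaded.items():
    print(name, tuple(tensor.shape))
```

If the printed keys look like layer names mapped to tensors, the file holds only weights, and the matching model class is needed before the model can be exported.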

[1]: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Answer by kmbjn2e3:


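The .pth binary file you loaded does not store the model, only its trained weights: the object you passed as model is a collections.OrderedDict (a state_dict mapping parameter names to tensors), which is why it has no training attribute. You need to import the class (a subclass of torch.nn.Module) that implements the model's *functionality*. Once you have the functionality, you can load the trained weights into it to get a specific instance of the model.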
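A minimal sketch of that approach, assuming the weights were saved with torch.save(model.state_dict(), ...). TinyNet is a hypothetical stand-in for the real architecture; for the repository in the question you would instead build the network from the repository's own model code (with the configuration the weights were trained with), because the parameter names and shapes in the state_dict must match the class exactly.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real architecture class."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Conv2d(8, 4, kernel_size=1)

    def forward(self, x):
        return self.head(torch.relu(self.conv(x)))


# Simulate the downloaded file: only the trained weights are saved.
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(TinyNet().state_dict(), path)

# 1. Build the model from its class (this supplies the functionality).
model = TinyNet()

# 2. Load the trained weights into that instance. With the default
#    strict=True, mismatched keys raise an error instead of loading silently.
state_dict = torch.load(path, map_location="cpu")
result = model.load_state_dict(state_dict)
print(result.missing_keys, result.unexpected_keys)

# 3. Switch to inference mode; `model` now has a `training` attribute.
model.eval()
print(type(model).__name__, model.training)

# `model` (not `state_dict`) is the object to pass to torch.onnx.export.
sample_input = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    output = model(sample_input)
print(tuple(output.shape))
```

With the real network, the same pattern is: instantiate the architecture class, call load_state_dict() with the loaded OrderedDict, call eval(), and then pass that module as model= to torch.onnx.export() as in the question.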
