I got a PyTorch model from [this repository][1] and I need to convert it to TFLite. Here is the code:
Up to this point everything runs smoothly. But when I run the following cell:
```python
import onnx
import torch

# torch_model, sample_input and ONNX_FILE are defined in an earlier cell
torch.onnx.export(
    model=torch_model,
    args=sample_input,
    f=ONNX_FILE,
    verbose=False,
    export_params=True,
    do_constant_folding=False,  # True would fold constant values for optimization
    input_names=['input'],
    opset_version=10,
    output_names=['output']
)

onnx_model = onnx.load(ONNX_FILE)
onnx.checker.check_model(onnx_model)
```
Full error log:
```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-33-15df717ec276> in <module>
8 input_names=['input'],
9 opset_version=10,
---> 10 output_names=['output']
11 )
12
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
274 do_constant_folding, example_outputs,
275 strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276 custom_opsets, enable_onnx_checker, use_external_data_format)
277
278
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
92 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
93 custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94 use_external_data_format=use_external_data_format)
95
96
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
677 _set_opset_version(opset_version)
678 _set_operator_export_type(operator_export_type)
--> 679 with select_model_mode_for_export(model, training):
680 val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
681 operator_export_type,
~\anaconda3\envs\py36\lib\contextlib.py in __enter__(self)
79 def __enter__(self):
80 try:
---> 81 return next(self.gen)
82 except StopIteration:
83 raise RuntimeError("generator didn't yield") from None
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in select_model_mode_for_export(model, mode)
36 def select_model_mode_for_export(model, mode):
37 if not isinstance(model, torch.jit.ScriptFunction):
---> 38 is_originally_training = model.training
39
40 if mode is None:
AttributeError: 'collections.OrderedDict' object has no attribute 'training'
```
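The error is raised by the `torch.onnx.export()` call.

Please let me know what is going wrong here. Am I not loading the weights correctly? If not, how should I load the model? I don't know the details of the class or the architecture, so how can I use `model.load_state_dict()`?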
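One way to check what the downloaded file actually contains is to inspect the object `torch.load` returns. This is a minimal sketch under stated assumptions: instead of the repository's real `.pth` file it fakes a weights-only checkpoint in memory from a small `nn.Linear` layer, and the helpers `describe_checkpoint` and `list_parameters` are hypothetical names for illustration.

```python
import io

import torch
import torch.nn as nn


def describe_checkpoint(obj) -> str:
    """Say whether a torch.load() result is a whole module or only weights."""
    if isinstance(obj, nn.Module):
        return "module"
    # collections.OrderedDict is a dict subclass; a dict here is most likely
    # a state_dict (or a checkpoint dict that wraps one).
    if isinstance(obj, dict):
        return "state_dict"
    return type(obj).__name__


def list_parameters(state_dict) -> list:
    """Return (name, shape) pairs; the names hint at how the layers are laid out."""
    return [(name, tuple(tensor.shape)) for name, tensor in state_dict.items()]


if __name__ == "__main__":
    # Hypothetical stand-in for the downloaded .pth: a weights-only checkpoint
    # of a small Linear layer, written to memory instead of disk.
    buffer = io.BytesIO()
    torch.save(nn.Linear(4, 2).state_dict(), buffer)
    buffer.seek(0)

    loaded = torch.load(buffer, map_location="cpu")
    print(type(loaded))
    print(describe_checkpoint(loaded))
    print(list_parameters(loaded))
    print(hasattr(loaded, "training"))  # False: the attribute torch.onnx.export reads
```

On a real state_dict the parameter names and shapes are the only hint about the architecture; they help match a class definition, but they do not replace it.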
[1]: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
1 answer
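The `.pth` binary file you downloaded does not store the model, only its trained weights: `torch.load` returns a `state_dict`, a `collections.OrderedDict` of parameter tensors, which is why it has no `training` attribute. You need to import the `class` (a class derived from `torch.nn.Module`) that implements the model's *functionality*. Once you have the functionality, you can load the trained weights into it to get a specific instance of the model.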