input: dynamic input is missing dimensions in profile

x33g5p2x  于2022-07-22 转载在 其他  
字(1.9k)|赞(0)|评价(0)|浏览(869)

input: dynamic input is missing dimensions in profile

onnx2trt代码报错:

  1. import numpy as np
  2. import tensorrt as trt
  3. import os
  4. import pycuda.driver as cuda
  5. import argparse
  6. def GiB(val):
  7. return val * 1 << 30
  8. def ONNX_build_engine(onnx_file_path, write_engine=True):
  9. # :return: engine
  10. G_LOGGER = trt.Logger(trt.Logger.WARNING)
  11. # 1、动态输入第一点必须要写的
  12. explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
  13. batch_size = 8 # trt推理时最大支持的batchsize
  14. with trt.Builder(G_LOGGER) as builder, builder.create_network(explicit_batch) as network, trt.OnnxParser(network,
  15. G_LOGGER) as parser:
  16. builder.max_batch_size = batch_size
  17. config = builder.create_builder_config()
  18. config.max_workspace_size = GiB(2)
  19. config.set_flag(trt.BuilderFlag.FP16)
  20. print('Loading ONNX file from path {}...'.format(onnx_file_path))
  21. with open(onnx_file_path, 'rb') as model:
  22. print('Beginning ONNX file parsing')
  23. parser.parse(model.read())
  24. print('Completed parsing of ONNX file')
  25. print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
  26. # 重点
  27. profile = builder.create_optimization_profile() # 动态输入时候需要 分别为最小输入、常规输入、最大输入
  28. # 有几个输入就要写几个profile.set_shape 名字和转onnx的时候要对应
  29. # tensorrt6以后的版本是支持动态输入的,需要给每个动态输入绑定一个profile,用于指定最小值,常规值和最大值,如果超出这个范围会报异常。
  30. profile.set_shape("input", (1, 3, 128, 128), (4, 3, 128, 128), (16, 3, 128, 128))
  31. config.add_optimization_profile(profile)
  32. engine = builder.build_engine(network, config)
  33. print("Completed creating Engine")
  34. # 保存engine文件
  35. if write_engine:
  36. engine_file_path = 'efficientnet_b1.trt'
  37. with open(engine_file_path, "wb") as f:
  38. f.write(engine.serialize())
  39. return engine
  40. onnx_file_path = r'skipnet_0712.onnx'
  41. onnx_file_path = r'model2.onnx'
  42. onnx_file_path = r'skip_simp2.onnx'
  43. # onnx_file_path = r'mobileone_0713.onnx'
  44. write_engine = True
  45. engine = ONNX_build_engine(onnx_file_path, write_engine)

原错误代码:

profile.set_shape("inputs", (1, 3, 240, 240), (8, 3, 240, 240), (16, 3, 480, 480))

改之后代码:

profile.set_shape("inputs", (1, 3, 128, 128), (8, 3, 128, 128), (16, 3, 128, 128))

相关文章