pytorch 加载不对齐预训练

x33g5p2x  于2021-11-15 转载在 其他  
字(0.6k)|赞(0)|评价(0)|浏览(308)

以前改变网络通道数,需要重新从头训练,无法加载预训练,今天研究了一下如何改变网络通道后,还有预训练模型可用,这样可以减少980%的训练时间,提供训练效率。

废话不说,直接上代码:

这个代码加载预训练模型后,再训练无效果:

  1. backbone = MobileFace_83_w(256,l_size=[2,6,8,4]).to(0)
  2. backbone_pth = os.path.join("/data/408800_net.pth")
  3. state_dict=torch.load(backbone_pth, map_location=torch.device(0))
  4. # backbone.load_state_dict(state_dict,strict=False)
  5. bone_dict=backbone.state_dict()
  6. # model_end=Model_end(256).to(0)
  7. new_state_dict = OrderedDict()
  8. for k, v in state_dict.items():
  9. head = k[:7]
  10. if head == 'module.':
  11. tmp_name = k[7:] # remove `module.`
  12. else:
  13. tmp_name = k
  14. # continue
  15. need_v= bone_dict[tmp_name]
  16. if len(need_v.size())==1:
  17. if need_v.size(0)>v.size(0):

相关文章