PyTorch: how to add a custom block to ConvNeXt small

zhte4eai posted on 2024-01-09 in Other

I added a custom block of layers to ConvNeXt small. I want to train it starting from the pretrained weights, but loading them fails with this error:

Traceback (most recent call last):
  File "C:\Users\Ali\PycharmProjects\pythonProject1\ConvNext_Custom_FromGitHubCode.py", line 187, in <module>
    model = convnext_small(pretrained=True, in_22k=False)
  File "C:\Users\Ali\PycharmProjects\pythonProject1\ConvNext_Custom_FromGitHubCode.py", line 184, in convnext_small
    model.load_state_dict(checkpoint["model"])
  File "C:\Users\Ali\anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ConvNeXt:
    Missing key(s) in state_dict: "custom_block.conv1.weight", "custom_block.conv1.bias", "custom_block.conv2.weight", "custom_block.conv2.bias", "custom_block.multihead_attention.in_proj_weight", "custom_block.multihead_attention.in_proj_bias", "custom_block.multihead_attention.out_proj.weight", "custom_block.multihead_attention.out_proj.bias", "custom_block.linear.weight", "custom_block.linear.bias".


Answer #1 by h7appiyu

You can pass strict=False to disable the key check when loading the weights:

model.load_state_dict(checkpoint["model"], strict=False)

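This disables the checks for missing and unexpected keys while loading the state dict. The pretrained layers receive their weights, and the custom_block parameters simply keep their random initialization.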
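A minimal sketch of what strict=False does, using a hypothetical Backbone class as a stand-in for ConvNeXt and a hypothetical BackboneWithBlock subclass that adds a custom_block (neither name comes from the ConvNeXt code). load_state_dict returns a named tuple with missing_keys and unexpected_keys, so you can confirm that only the new block was left uninitialized.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical stand-in for the pretrained ConvNeXt."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, 10)

    def forward(self, x):
        x = self.stem(x).mean(dim=(2, 3))  # global average pooling
        return self.head(x)


class BackboneWithBlock(Backbone):
    """Same network plus a new block the checkpoint knows nothing about."""

    def __init__(self):
        super().__init__()
        self.custom_block = nn.Conv2d(8, 8, kernel_size=1)

    def forward(self, x):
        x = self.custom_block(self.stem(x)).mean(dim=(2, 3))
        return self.head(x)


# In the real code this would be checkpoint["model"] from load_state_dict_from_url
pretrained_state = Backbone().state_dict()

model = BackboneWithBlock()
result = model.load_state_dict(pretrained_state, strict=False)

print(result.missing_keys)     # only the custom_block parameters
print(result.unexpected_keys)  # nothing left over from the checkpoint
```

Note that strict=False only relaxes the missing/unexpected key checks. A key that exists in both dicts but with a different shape (for example a classifier head with another num_classes) still raises a size-mismatch RuntimeError.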

Answer #2 by a64a0gku

You can update the convnext_small function as follows:

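import torch

# ConvNeXt and model_urls are defined earlier in the same file, as in the question
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model_dict = model.state_dict()

        # 1. keep only pretrained weights whose name and shape exist in the model;
        #    the weights are stored under checkpoint["model"], not at the top level
        pretrained_dict = {k: v for k, v in checkpoint["model"].items()
                           if k in model_dict and v.shape == model_dict[k].shape}
        # 2. overwrite those entries in the model's own state dict
        #    (custom_block.* is not in the checkpoint and keeps its random init)
        model_dict.update(pretrained_dict)
        # 3. load the merged state dict; every key is present, so strict loading works
        model.load_state_dict(model_dict)
    return model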

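The code above drops the checkpoint entries that don't match the model, merges the remaining pretrained weights into the model's own state_dict and loads the result. The pretrained layers get their weights, while the custom block keeps its fresh initialization.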
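The same filtering can be packaged as a reusable helper. This is a sketch, not part of the ConvNeXt code: load_matching_weights, PretrainedNet and CustomNet are hypothetical names, and the toy models stand in for ConvNeXt. It also compares shapes, so a replaced classifier head is skipped instead of raising a size-mismatch error, and it returns the keys that were left at their fresh initialization.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, pretrained_state: dict) -> list:
    """Copy every pretrained tensor whose name and shape match the model.

    Returns the model keys that were NOT filled from the checkpoint,
    i.e. the ones that keep their fresh initialization.
    """
    model_state = model.state_dict()
    matched = {
        k: v
        for k, v in pretrained_state.items()
        if k in model_state and v.shape == model_state[k].shape
    }
    model_state.update(matched)
    model.load_state_dict(model_state)  # strict=True is safe: every key is present
    return [k for k in model_state if k not in matched]


class PretrainedNet(nn.Module):
    """Hypothetical stand-in for the original network (1000-class head)."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, num_classes)


class CustomNet(PretrainedNet):
    """Modified network: an extra block and a smaller head."""

    def __init__(self, num_classes=5):
        super().__init__(num_classes)
        self.custom_block = nn.Linear(8, 8)


# In the real code this would be checkpoint["model"]
pretrained_state = PretrainedNet().state_dict()
model = CustomNet()
fresh = load_matching_weights(model, pretrained_state)
print(fresh)  # head.* (shape mismatch) and custom_block.* (new layer)
```

Returning the skipped keys makes it easy to check that only the layers you intended to train from scratch were left out.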
