最新下载
热门教程
- 1
 - 2
 - 3
 - 4
 - 5
 - 6
 - 7
 - 8
 - 9
 - 10
 
解决Pytorch修改预训练模型时遇到key不匹配的情况方法
时间:2022-06-25 01:57:32 编辑:袖梨 来源:一聚教程网
本篇文章小编给大家分享一下解决Pytorch修改预训练模型时遇到key不匹配的情况方法,文章代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。
最近想着修改网络的预训练模型vgg.pth,但是发现当加载预训练模型权重到新建的模型并保存之后。
在使用新赋值的网络模型时出现了key不匹配的问题
#加载后保存(未修改网络) base_weights = torch.load(args.save_folder + args.basenet) ssd_net.vgg.load_state_dict(base_weights) torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights) 
此时会如下出错误:
Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)
…
RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.
说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”
我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。
现在的问题是因为自己定义保存的模型key参数多了一个前缀。
可以通过如下语句进行修改,并加载
from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict) 
此时就不会再出错了。
相关文章
- 原神杜林圣遗物选择推荐 11-04
 - 百度网盘SVIP激活码能用的有哪些 百度网盘vip免费领取 11-04
 - 打个螺丝兑换码能用的有哪些 2025最新有效兑换码汇总 11-04
 - 抓大鹅有效兑换码有哪些 2025最新可用兑换码大全 11-04
 - 密室出逃兑换码能用的有哪些 2025最新有效兑换码大全 11-04
 - 猪了个猪兑换码最新可用 2025最新有效兑换码汇总 11-04