最近在学习的过程中,遇到了一些场景,需要查看 pytorch 模型 结构,对权重文件进行修改,然后再加载,这里记录一下 处理流程以及相关的知识
查看保存的pth信息
有些pth文件只保存了 模型权重
有些pth文件保存的信息更多,有 model,optimizer,scheduler
因此在加载前,需要搞明白这个pth 保存了什么
weight = torch.load(PATH_TO_WEIGHT)
print(weight.keys())
打印权重文件信息
# 只保存了 model
for key,value in weight.items():
print(key,value)
# 保存了 model optimizer scheduler
for key,value in weight['model'].items():
print(key,value)
查看模型结构
# 定义模型
resnet
# 查看模型
# 1 直接打印
print(resnet)
# 2 api
# 不同的api 有所区别 注意区分 ,一般使用 children 相关即可
for m in resnet.children():
print(m)
for name, m in resnet.named_children():
print(name, " >>> ", m)
# 类似一种递归返回 可以进行实验查看
for m in resnet.modules():
print(m)
for name, m in resnet.named_modules():
print(name, " >>> ", m)
场景1:去除分类输出层并重新加载
有时候我们需要去加载模型的预训练权重,但是由于使用的数据集和任务,与预训练时的不一致,如预训练时分类数 80,但在我们的任务上 分类数为2,这个时候直接加载,就会出现 size dismatch 问题,无法正常训练;还有可能存在 部分组件 不存在的某些少见情况
我们处理的一个过程
对比 使用的模型 和 预训练权重的 结构
一般模型体积比较大, 这里建议 将 两个结构情况分别输出到文件里 便于对比查看
import torch
import sys
file1 = open('./model_structure.txt','w')
file2 = open('./weight_structure.txt','w')
# 定义模型
model
# 加载权重
pretrained_weight = torch.load(PATH_TO_WEIGHT)
# 定向print到文件
sys.stdout = file1
# 打印model structure信息
for name,module in model.named_children():
print(name,module)
sys.stdout = file2
# 打印 pretrained weight 信息
for k,v in pretrained_weight.items():
print(k,v.shape)
我们得到了 这两个 结构信息的 文件,进行对比,常见如分类数不一致 ,不加载最后一层输出的权重
我们就需要 将 预训练权重 这一部分给去除掉
filter_weight = {key:value for key,value in pretrained_weigths.items() if key in model.state_dict() and '最后一层名,如 classificationModel.output' not in key}
这样我们就得到了经过处理后的 预训练权重 然后去加载
model.load_state_dict(filter_weight,strict=False)
其他的需求场景
甚至还有可能只要 预训练权重中 的某一组件的权重,这个时候,就需要去细致处理,基于这个预训练权重 重新 制作一份 权重【就是使用 if 进行筛选】
如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:
如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度 学习率learning rate不一样,
这些场景 大家可以自行查找方法,进行解决。