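I want to run CML knowledge-distillation training on the DB++ algorithm. The Teacher's pretrained model is ResNet50_dcn_asf_synthtext_pretrained; which pretrained model should I choose for the Student? I first tried MobileNetV3_large_x0_5_pretrained, but at startup it prints a long series of
The pretrained params backbone.** not in model
and similar warnings, so I suspect the pretrained model I picked is not suitable.
Also, how should the training parameters be configured? Below is my config. I am not sure whether it is correct and would appreciate any guidance!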
Global:
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_db++_cml/
  save_epoch_step: 25
  eval_batch_step:
  - 0
  - 2000
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
  checkpoints: null
  save_inference_dir:
  use_visualdl: True
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
  name: DistillationModel
  model_type: det
  algorithm: Distillation
  # algorithm: DB++
  Models:
    Student:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      model_type: det
      algorithm: DB++
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      model_type: det
      algorithm: DB++
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Teacher:
      freeze_params: true
      pretrained: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
      return_all_feats: false
      model_type: det
      algorithm: DB++
      Backbone:
        name: ResNet
        layers: 50
        dcn_stage: [False, True, True, True]
      Neck:
        name: DBFPN
        out_channels: 256
        use_asf: True
      Head:
        name: DBHead
        kernel_list: [7,2,2]
        k: 50
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: BCELoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      # act: None
      model_name_pairs: ["Student", "Student2"]
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      # key: maps
      # name: DBLoss
      balance_loss: true
      main_loss_type: BCELoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: DecayLearningRate
    learning_rate: 0.007
    epochs: 500
    factor: 0.9
    end_lr: 0
  weight_decay: 0.0001
PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.5
  max_candidates: 1000
  unclip_ratio: 1.5
Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./dataset/
    label_file_list:
    - ./dataset/Label.txt
    ratio_list:
    - 1.0
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 640
        - 640
        max_tries: 10
        keep_ratio: true
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 4
    num_workers: 8
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./dataset/
    label_file_list:
    - ./dataset/Label_eval.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - DetResizeForTest:
        image_shape:
        - 736
        - 736
        keep_ratio: True
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - shape
        - polys
        - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2
profiler_options: null
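One part of a config like this can be checked mechanically: whether every model name used under Loss, PostProcess and Metric is actually defined under Architecture.Models. The following is a minimal sketch, assuming the config has already been loaded into a Python dict (for example with a YAML parser). `undefined_model_refs` is a hypothetical helper, not part of PaddleOCR, and it only looks at the keys that appear in the config above.

```python
from typing import Any, Dict, List, Set


def undefined_model_refs(config: Dict[str, Any]) -> List[str]:
    """Return model names referenced outside Architecture.Models that are not defined there."""
    defined: Set[str] = set(config["Architecture"]["Models"])
    referenced: Set[str] = set()

    for entry in config["Loss"]["loss_config_list"]:
        for params in entry.values():
            # model_name_pairs may be a list of pairs or a single flat pair.
            for item in params.get("model_name_pairs", []):
                if isinstance(item, str):
                    referenced.add(item)
                else:
                    referenced.update(item)
            referenced.update(params.get("model_name_list", []))

    referenced.update(config["PostProcess"].get("model_name", []))
    metric_key = config["Metric"].get("key")
    if metric_key:
        referenced.add(metric_key)

    return sorted(referenced - defined)


# Trimmed-down copy of the config above, keeping only the fields the check reads.
config = {
    "Architecture": {"Models": {"Student": {}, "Student2": {}, "Teacher": {}}},
    "Loss": {
        "loss_config_list": [
            {"DistillationDilaDBLoss": {
                "model_name_pairs": [["Student", "Teacher"], ["Student2", "Teacher"]]}},
            {"DistillationDMLLoss": {"model_name_pairs": ["Student", "Student2"]}},
            {"DistillationDBLoss": {"model_name_list": ["Student", "Student2"]}},
        ]
    },
    "PostProcess": {"model_name": ["Student"]},
    "Metric": {"key": "Student"},
}

print(undefined_model_refs(config))
```

An empty list means every referenced name is defined. This only catches naming mistakes; it says nothing about whether the loss weights or learning-rate settings are sensible.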
2 answers
wn9m85ua1#
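"The pretrained params backbone.** not in model"
"and similar warnings, so I suspect the pretrained model I picked is not suitable."
This warning is normal. If the vast majority of the parameters fail to load, use the pretrained-model download link from the latest documentation: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
Your config uses mv3_large_0.5, so using the mv3_large_0.5 pretrained weights is fine. For the DB++-specific parameters, refer to this config file: https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/det/det_r50_db%2B%2B_icdar15.yml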
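To judge whether "the vast majority of the parameters fail to load", it can help to count rather than read warnings. The sketch below compares parameter names and shapes from a checkpoint against those of a model and reports the fraction of model parameters that would be filled. It works on plain dicts of name to shape, so it runs without any framework. `pretrained_coverage` and the sample parameter names are hypothetical and are not PaddleOCR's own loading code.

```python
from typing import List, Mapping, Sequence, Tuple


def pretrained_coverage(
    checkpoint_shapes: Mapping[str, Sequence[int]],
    model_shapes: Mapping[str, Sequence[int]],
) -> Tuple[float, List[str], List[str]]:
    """Return (fraction of model params loadable, unused checkpoint keys, model keys not loaded)."""
    loaded = {
        name
        for name, shape in model_shapes.items()
        if name in checkpoint_shapes and list(checkpoint_shapes[name]) == list(shape)
    }
    # Checkpoint entries with no same-named parameter in the model.
    unused = sorted(k for k in checkpoint_shapes if k not in model_shapes)
    # Model parameters that keep their initial values (missing or shape mismatch).
    not_loaded = sorted(k for k in model_shapes if k not in loaded)
    ratio = len(loaded) / len(model_shapes) if model_shapes else 0.0
    return ratio, unused, not_loaded


# Hypothetical parameter names and shapes, for illustration only.
checkpoint = {
    "backbone.conv1.weight": [8, 3, 3, 3],
    "backbone.se.fc1.weight": [4, 8],
    "backbone.conv2.weight": [16, 8, 3, 3],
}
model = {
    "backbone.conv1.weight": [8, 3, 3, 3],
    "backbone.conv2.weight": [16, 8, 3, 3],
    "neck.conv.weight": [96, 16, 1, 1],
}

ratio, unused, not_loaded = pretrained_coverage(checkpoint, model)
print(f"loaded {ratio:.0%} of model parameters")
print("in checkpoint but not in model:", unused)
print("in model but not loaded:", not_loaded)
```

With real weights, the two dicts could presumably be built from the `.shape` of each entry in `paddle.load(...)` and in the model's `state_dict()`; that wiring is an assumption and is not shown here. A backbone-only checkpoint will never cover the neck and head, so a ratio well below 100% is expected for an ImageNet-pretrained backbone.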
kqhtkvqz2#
Understood, thanks!