PaddleOCR DB++ knowledge distillation

Posted by bakd9h0s on 2022-12-31 in Other

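I want to run knowledge-distillation training with CML (Collaborative Mutual Learning) on the DB++ algorithm. The teacher's pretrained model is ResNet50_dcn_asf_synthtext_pretrained; how should I choose the student's pretrained model? I first tried MobileNetV3_large_x0_5_pretrained, but at startup it prints a whole series of warnings like
The pretrained params backbone.** not in model
so I suspect the pretrained model I chose is not suitable.

Also, how should the training parameters be configured? Below is my configuration. I'm not sure whether it is correct and would appreciate an answer!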

Global:
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_db++_cml/
  save_epoch_step: 25
  eval_batch_step:
  - 0
  - 2000
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
  checkpoints: null
  save_inference_dir: 
  use_visualdl: True
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt

Architecture:
  name: DistillationModel
  model_type: det
  algorithm: Distillation
  # algorithm: DB++
  Models:
    Student:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      model_type: det
      algorithm: DB++
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      model_type: det
      algorithm: DB++
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Teacher:
      freeze_params: true
      pretrained: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
      return_all_feats: false
      model_type: det
      algorithm: DB++
      Backbone:
        name: ResNet
        layers: 50
        dcn_stage: [False, True, True, True]
      Neck:
        name: DBFPN
        out_channels: 256
        use_asf: True
      Head:
        name: DBHead
        kernel_list: [7,2,2]
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: BCELoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      # act: None
      model_name_pairs: ["Student", "Student2"]
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      # key: maps
      # name: DBLoss
      balance_loss: true
      main_loss_type: BCELoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: DecayLearningRate
    learning_rate: 0.007
    epochs: 500
    factor: 0.9
    end_lr: 0
  weight_decay: 0.0001
PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.5
  max_candidates: 1000
  unclip_ratio: 1.5
Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./dataset/
    label_file_list:
    - ./dataset/Label.txt
    ratio_list:
    - 1.0
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 640
        - 640
        max_tries: 10
        keep_ratio: true
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 4
    num_workers: 8
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./dataset/
    label_file_list:
    - ./dataset/Label_eval.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - DetResizeForTest:
        image_shape:
        - 736
        - 736
        keep_ratio: True
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - shape
        - polys
        - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2
profiler_options: null
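A note on the `Optimizer.lr` block in the config above. Our understanding, which is an assumption to verify against the PaddleOCR source rather than something stated in this thread, is that `DecayLearningRate` wraps a polynomial decay: over T = `epochs` x (iterations per epoch) steps, the rate falls from `learning_rate` to `end_lr` with exponent `factor`, i.e. lr(t) = (learning_rate - end_lr) * (1 - t/T)^factor + end_lr. The minimal sketch below evaluates that formula in plain Python; the function name `poly_decay_lr` and the value `steps_per_epoch = 100` are hypothetical.

```python
def poly_decay_lr(step: int, base_lr: float, total_steps: int,
                  power: float = 0.9, end_lr: float = 0.0) -> float:
    """Polynomial decay without cycling: base_lr -> end_lr over total_steps.

    Hypothetical helper illustrating the assumed DecayLearningRate behavior.
    """
    step = min(step, total_steps)      # the rate stays at end_lr after the horizon
    frac = 1.0 - step / total_steps    # fraction of the schedule still remaining
    return (base_lr - end_lr) * frac ** power + end_lr


# Values from the question's Optimizer.lr block.
base_lr, epochs, factor, end_lr = 0.007, 500, 0.9, 0.0
# Hypothetical: the real value depends on dataset size, batch_size_per_card
# and the number of GPUs.
steps_per_epoch = 100
total = epochs * steps_per_epoch

for epoch in (0, 250, 450, 500):
    lr = poly_decay_lr(epoch * steps_per_epoch, base_lr, total, factor, end_lr)
    print(f"epoch {epoch:3d}: lr = {lr:.6f}")
```

Under this formula, if `lr.epochs` is larger than `Global.epoch_num`, training stops before the rate reaches `end_lr`; in the question both are 500, so the decay spans the whole run.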
Answer #1 (wn9m85ua):

> The pretrained params backbone.** not in model
> and similar warnings, so I suspect the pretrained model I chose is not suitable.

This warning is normal. If most of the parameters fail to load, use the pretrained-model download link from the latest documentation: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
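To see where warnings like the one quoted above come from, here is a minimal pure-Python sketch that mimics, but does not call, the kind of name-and-shape comparison a checkpoint loader performs: a checkpoint entry is used only if the model has a parameter with exactly the same name (prefix included) and the same shape. State dicts are modeled as plain `dict`s from parameter name to shape; `match_pretrained` and all parameter names below are hypothetical. It also reports the fraction of entries that would load, which is the number to look at when judging whether "most of the parameters fail to load".

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def match_pretrained(
    pretrained: Dict[str, Shape], model: Dict[str, Shape]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: split checkpoint keys into loadable keys and warnings.

    A key loads only if the model has a parameter with exactly the same
    name (including any prefix) and exactly the same shape.
    """
    loaded: List[str] = []
    warnings: List[str] = []
    for name, shape in pretrained.items():
        if name not in model:
            warnings.append(f"The pretrained params {name} not in model")
        elif model[name] != shape:
            # Illustrative wording, not PaddleOCR's exact message.
            warnings.append(
                f"shape mismatch for {name}: model {model[name]} vs checkpoint {shape}"
            )
        else:
            loaded.append(name)
    return loaded, warnings


# Hypothetical parameter names and shapes, for illustration only.
ckpt = {
    "backbone.conv.weight": (8, 3, 3, 3),
    "backbone.fc.weight": (1280, 1000),  # classifier head with no detector counterpart
}
model = {
    "backbone.conv.weight": (8, 3, 3, 3),
    "neck.ins_conv.weight": (96, 8, 1, 1),
}

loaded, warnings = match_pretrained(ckpt, model)
ratio = len(loaded) / len(ckpt)
print(f"loaded {len(loaded)}/{len(ckpt)} ({ratio:.0%})")
for w in warnings:
    print(w)
```

A few unmatched entries (for example, an ImageNet classifier head that a detection backbone does not have) are typically harmless; a ratio near zero suggests that the checkpoint's naming or architecture does not line up with the model.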

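Your config uses mv3_large_0.5 (MobileNetV3 large, scale 0.5), so using the mv3_large_0.5 pretrained weights is fine. For the DB++-specific parameters, refer to this config file: https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/det/det_r50_db%2B%2B_icdar15.yml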
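Beyond the DB++ parameter values in the reference config, a distillation config also has to keep its sub-model names consistent: each name used in a loss's `model_name_pairs` or `model_name_list`, in `PostProcess.model_name` and in `Metric.key` should be one of the keys under `Architecture.Models` (here `Student`, `Student2`, `Teacher`). The sketch below checks that on a trimmed dict transcribed by hand from the config in the question; `referenced_names` and `check_model_names` are hypothetical helpers, not PaddleOCR functions, and in practice the dict would come from a YAML parser.

```python
from typing import Any, Dict, List


def referenced_names(cfg: Dict[str, Any]) -> List[str]:
    """Collect every sub-model name that the losses, post-process and metric refer to."""
    names: List[str] = []
    for loss in cfg["Loss"]["loss_config_list"]:
        for params in loss.values():
            pairs = params.get("model_name_pairs", [])
            # A single pair may be written flat, e.g. ["Student", "Student2"].
            if pairs and isinstance(pairs[0], str):
                pairs = [pairs]
            for pair in pairs:
                names.extend(pair)
            names.extend(params.get("model_name_list", []))
    names.extend(cfg["PostProcess"].get("model_name", []))
    names.append(cfg["Metric"]["key"])
    return names


def check_model_names(cfg: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return referenced names that are not defined sub-models."""
    defined = set(cfg["Architecture"]["Models"])
    return sorted({n for n in referenced_names(cfg) if n not in defined})


# Trimmed, hand-transcribed subset of the config in the question.
cfg = {
    "Architecture": {"Models": {"Student": {}, "Student2": {}, "Teacher": {}}},
    "Loss": {"loss_config_list": [
        {"DistillationDilaDBLoss": {
            "model_name_pairs": [["Student", "Teacher"], ["Student2", "Teacher"]]}},
        {"DistillationDMLLoss": {"model_name_pairs": ["Student", "Student2"]}},
        {"DistillationDBLoss": {"model_name_list": ["Student", "Student2"]}},
    ]},
    "PostProcess": {"model_name": ["Student"]},
    "Metric": {"key": "Student"},
}

# An empty list means every referenced name resolves to a defined sub-model.
print(check_model_names(cfg))
```

The `DistillationDMLLoss` entry in the question sets `model_name_pairs` twice, once nested and once flat; this transcription keeps the later, flat form, which is why the helper accepts both shapes.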

Answer #2 (kqhtkvqz):


Understood, thanks!
