PaddleOCR Text Detection Model Training

2024-12-11 19:25:00
pjd
Summary: ppocrv3

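Text detection

Based on PP-OCRv3, we train and evaluate the detection model with three approaches:

Approach 1: PP-OCRv3 Chinese ultra-lightweight detection pretrained model

Approach 2: PP-OCRv3 Chinese ultra-lightweight detection pretrained model + padding on the validation set

Approach 3: PP-OCRv3 Chinese ultra-lightweight detection pretrained model + finetune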


# Model download: https://aistudio.baidu.com/modelsdetail/17/download


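1. Chinese text detection models

Model name | Description | Inference model size | Download
ch_PP-OCRv3_det | Original ultra-lightweight model; supports Chinese, English, and multilingual text detection | 3.80M | Trained model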

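2. Text recognition models

Model name | Description | Inference model size | Download
ch_PP-OCRv3_rec | Original ultra-lightweight model; supports Chinese, English, and digit recognition | 12.4M | Trained model

Note: The trained model is obtained by finetuning the pretrained model on real data and vertical synthetic text data, and it performs better in real application scenarios. The pretrained model is trained directly on the full set of real and synthetic data, and is better suited for finetuning on your own dataset.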

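3. Text direction classification model

Model name | Description | Inference model size | Download
ch_ppocr_mobile_v2.0_cls | Original classifier model; classifies the text angle of detected text lines | 1.38M | Inference model / Trained model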



We use approach 3: pretrained model + finetune.
Training and evaluation use 1,500 images: 1,200 for training (train) and 300 for validation (val).
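
The 1,200 / 300 split could be produced from a single label file along the following lines. This is only a minimal sketch: split_label_lines is a hypothetical helper, the placeholder lines stand in for real label lines (assumed here to hold one image per line), and the fixed seed is an arbitrary choice for reproducibility.

```python
import random


def split_label_lines(lines, val_count, seed=0):
    """Shuffle label lines reproducibly and hold out val_count of them for validation."""
    if not 0 <= val_count <= len(lines):
        raise ValueError("val_count must be between 0 and the number of lines")
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    # The first val_count shuffled lines become the validation set, the rest the training set.
    return shuffled[val_count:], shuffled[:val_count]


# Placeholder lines (hypothetical); a real label file would be read from disk instead.
all_lines = [f"img/{i}.jpg\t[]" for i in range(1500)]
train_lines, val_lines = split_label_lines(all_lines, val_count=300)
print(len(train_lines), len(val_lines))
```

The two lists can then be written to det_gt_train.txt and det_gt_val.txt, the file names used in the config fields of this post.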
Modify the following fields in the config file configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml:

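Global.epoch_num: set to 1 here so the pipeline runs through quickly; in practice, adjust it to the amount of data
Global.save_model_dir: path where the model is saved
Global.pretrained_model: path to the pretrained model, './pretrain_models/en_PP-OCRv3_det_distill_train/student.pdparams' (the run logged below loads the Chinese model instead: ./work/pretrain_models/ch_PP-OCRv3_det_distill_train/student)
Optimizer.lr.learning_rate: learning rate; set to 0.0005 in this experiment
Train.dataset.data_dir: directory containing the training images, '/home/aistudio/dataset'
Train.dataset.label_file_list: training label file, '/home/aistudio/dataset/det_gt_train.txt'
Train.dataset.transforms.EastRandomCropData.size: training size, changed to [480,64]
Eval.dataset.data_dir: directory containing the validation images, '/home/aistudio/dataset/'
Eval.dataset.label_file_list: validation label file, '/home/aistudio/dataset/det_gt_val.txt'
Eval.dataset.transforms.DetResizeForTest: evaluation size; add the following parameters
    limit_side_len: 64
    limit_type: 'min'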
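
The dotted field names above are paths into the nested YAML config. As a rough illustration of how such paths map onto the nested structure, here is a minimal sketch; apply_overrides and the trimmed-down base dictionary are hypothetical and are not PaddleOCR's own code.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of config with dotted-key overrides applied."""
    result = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Walk down one level per path component, creating it if missing.
            node = node.setdefault(name, {})
        node[leaf] = value
    return result


# Hypothetical, heavily trimmed stand-in for the YAML config file.
base = {
    "Global": {"epoch_num": 500, "save_model_dir": "./output/"},
    "Optimizer": {"lr": {"name": "Cosine", "learning_rate": 0.001}},
}
overrides = {
    "Global.epoch_num": 1,
    "Global.pretrained_model": "./pretrain_models/en_PP-OCRv3_det_distill_train/student.pdparams",
    "Optimizer.lr.learning_rate": 0.0005,
}
config = apply_overrides(base, overrides)
print(config["Global"]["epoch_num"], config["Optimizer"]["lr"]["learning_rate"])
```

The transforms entries are a list in the YAML file, so a path such as Train.dataset.transforms.EastRandomCropData.size is shorthand for the matching list item; this sketch only covers plain nested keys.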

# Make sure PaddlePaddle is installed
conda install paddlepaddle-gpu==2.6.2 cudatoolkit=11.6 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge 
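
After installing, it may be worth checking which build Python actually picks up, since the training log in this post reports device Place(cpu). The following is a small sketch; describe_paddle_install is a hypothetical helper name, and it simply reports that paddle is missing when the import fails.

```python
def describe_paddle_install():
    """Summarize the installed paddle build, or report that it is missing."""
    try:
        import paddle
    except ImportError:
        return "paddle is not installed in this environment"
    # True only for GPU builds; a CPU-only wheel returns False.
    cuda_build = paddle.device.is_compiled_with_cuda()
    # The device paddle currently selects, e.g. 'cpu' or 'gpu:0'.
    device = paddle.device.get_device()
    return f"paddle {paddle.__version__}, CUDA build: {cuda_build}, device: {device}"


print(describe_paddle_install())
```

For a fuller self-test, paddle.utils.run_check() can be run in the same environment.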

Run the following commands to start training:

cd ~/PaddleOCR/
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml




(paddle_ocr) PS ~\PaddleOCR> python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student_cimc.yml
[2024/12/11 13:35:07] ppocr INFO: Architecture : 
[2024/12/11 13:35:07] ppocr INFO:     Backbone : 
[2024/12/11 13:35:07] ppocr INFO:         disable_se : True
[2024/12/11 13:35:07] ppocr INFO:         model_name : large
[2024/12/11 13:35:07] ppocr INFO:         name : MobileNetV3
[2024/12/11 13:35:07] ppocr INFO:         scale : 0.5
[2024/12/11 13:35:07] ppocr INFO:     Head : 
[2024/12/11 13:35:07] ppocr INFO:         k : 50
[2024/12/11 13:35:07] ppocr INFO:         name : DBHead
[2024/12/11 13:35:07] ppocr INFO:     Neck : 
[2024/12/11 13:35:07] ppocr INFO:         name : RSEFPN
[2024/12/11 13:35:07] ppocr INFO:         out_channels : 96
[2024/12/11 13:35:07] ppocr INFO:         shortcut : True
[2024/12/11 13:35:07] ppocr INFO:     Transform : None
[2024/12/11 13:35:07] ppocr INFO:     algorithm : DB
[2024/12/11 13:35:07] ppocr INFO:     model_type : det
[2024/12/11 13:35:07] ppocr INFO: Eval : 
[2024/12/11 13:35:07] ppocr INFO:     dataset :
[2024/12/11 13:35:07] ppocr INFO:         data_dir : ./work/train_data
[2024/12/11 13:35:07] ppocr INFO:         label_file_list : ['./work/train_data/img/Label.txt']
[2024/12/11 13:35:07] ppocr INFO:         name : SimpleDataSet
[2024/12/11 13:35:07] ppocr INFO:         transforms :
[2024/12/11 13:35:07] ppocr INFO:             DecodeImage :
[2024/12/11 13:35:07] ppocr INFO:                 channel_first : False
[2024/12/11 13:35:07] ppocr INFO:                 img_mode : BGR
[2024/12/11 13:35:07] ppocr INFO:             DetLabelEncode : None
[2024/12/11 13:35:07] ppocr INFO:             DetResizeForTest : None
[2024/12/11 13:35:07] ppocr INFO:             NormalizeImage :
[2024/12/11 13:35:07] ppocr INFO:                 mean : [0.485, 0.456, 0.406]
[2024/12/11 13:35:07] ppocr INFO:                 order : hwc
[2024/12/11 13:35:07] ppocr INFO:                 scale : 1./255.
[2024/12/11 13:35:07] ppocr INFO:                 std : [0.229, 0.224, 0.225]
[2024/12/11 13:35:07] ppocr INFO:             ToCHWImage : None
[2024/12/11 13:35:07] ppocr INFO:             KeepKeys :
[2024/12/11 13:35:07] ppocr INFO:                 keep_keys : ['image', 'shape', 'polys', 'ignore_tags']
[2024/12/11 13:35:07] ppocr INFO:     loader :
[2024/12/11 13:35:07] ppocr INFO:         batch_size_per_card : 1
[2024/12/11 13:35:07] ppocr INFO:         drop_last : False
[2024/12/11 13:35:07] ppocr INFO:         num_workers : 2
[2024/12/11 13:35:07] ppocr INFO:         shuffle : False
[2024/12/11 13:35:07] ppocr INFO: Global :
[2024/12/11 13:35:07] ppocr INFO:     cal_metric_during_train : False
[2024/12/11 13:35:07] ppocr INFO:     checkpoints : None
[2024/12/11 13:35:07] ppocr INFO:     debug : True
[2024/12/11 13:35:07] ppocr INFO:     distributed : False
[2024/12/11 13:35:07] ppocr INFO:     epoch_num : 500
[2024/12/11 13:35:07] ppocr INFO:     eval_batch_step : [0, 400]
[2024/12/11 13:35:07] ppocr INFO:     infer_img : doc/imgs_en/img_10.jpg
[2024/12/11 13:35:07] ppocr INFO:     log_smooth_window : 20
[2024/12/11 13:35:07] ppocr INFO:     pretrained_model : ./work/pretrain_models/ch_PP-OCRv3_det_distill_train/student
[2024/12/11 13:35:07] ppocr INFO:     print_batch_step : 1
[2024/12/11 13:35:07] ppocr INFO:     save_epoch_step : 100
[2024/12/11 13:35:07] ppocr INFO:     save_inference_dir : None
[2024/12/11 13:35:07] ppocr INFO:     save_model_dir : ./work/output/ch_PP-OCR_V3_det/
[2024/12/11 13:35:07] ppocr INFO:     save_res_path : ./work/checkpoints/det_db/predicts_db.txt
[2024/12/11 13:35:07] ppocr INFO:     use_gpu : False
[2024/12/11 13:35:07] ppocr INFO:     use_visualdl : False
[2024/12/11 13:35:07] ppocr INFO: Loss :
[2024/12/11 13:35:07] ppocr INFO:     alpha : 5
[2024/12/11 13:35:07] ppocr INFO:     balance_loss : True
[2024/12/11 13:35:07] ppocr INFO:     beta : 10
[2024/12/11 13:35:07] ppocr INFO:     main_loss_type : DiceLoss
[2024/12/11 13:35:07] ppocr INFO:     name : DBLoss
[2024/12/11 13:35:07] ppocr INFO:     ohem_ratio : 3
[2024/12/11 13:35:07] ppocr INFO: Metric :
[2024/12/11 13:35:07] ppocr INFO:     main_indicator : hmean
[2024/12/11 13:35:07] ppocr INFO:     name : DetMetric
[2024/12/11 13:35:07] ppocr INFO: Optimizer :
[2024/12/11 13:35:07] ppocr INFO:     beta1 : 0.9
[2024/12/11 13:35:07] ppocr INFO:     beta2 : 0.999
[2024/12/11 13:35:07] ppocr INFO:     lr :
[2024/12/11 13:35:07] ppocr INFO:         learning_rate : 0.0005
[2024/12/11 13:35:07] ppocr INFO:         name : Cosine
[2024/12/11 13:35:07] ppocr INFO:         warmup_epoch : 2
[2024/12/11 13:35:07] ppocr INFO:     name : Adam
[2024/12/11 13:35:07] ppocr INFO:     regularizer :
[2024/12/11 13:35:07] ppocr INFO:         factor : 5e-05
[2024/12/11 13:35:07] ppocr INFO:         name : L2
[2024/12/11 13:35:07] ppocr INFO: PostProcess :
[2024/12/11 13:35:07] ppocr INFO:     box_thresh : 0.6
[2024/12/11 13:35:07] ppocr INFO:     max_candidates : 1000
[2024/12/11 13:35:07] ppocr INFO:     name : DBPostProcess
[2024/12/11 13:35:07] ppocr INFO:     thresh : 0.3
[2024/12/11 13:35:07] ppocr INFO:     unclip_ratio : 1.5
[2024/12/11 13:35:07] ppocr INFO: Train :
[2024/12/11 13:35:07] ppocr INFO:     dataset :
[2024/12/11 13:35:07] ppocr INFO:         data_dir : ./work/train_data/
[2024/12/11 13:35:07] ppocr INFO:         label_file_list : ['./work/train_data/img/Label.txt']
[2024/12/11 13:35:07] ppocr INFO:         name : SimpleDataSet
[2024/12/11 13:35:07] ppocr INFO:         ratio_list : [1.0]
[2024/12/11 13:35:07] ppocr INFO:         transforms :
[2024/12/11 13:35:07] ppocr INFO:             DecodeImage :
[2024/12/11 13:35:07] ppocr INFO:                 channel_first : False
[2024/12/11 13:35:07] ppocr INFO:                 img_mode : BGR
[2024/12/11 13:35:07] ppocr INFO:             DetLabelEncode : None
[2024/12/11 13:35:07] ppocr INFO:             IaaAugment :
[2024/12/11 13:35:07] ppocr INFO:                 augmenter_args :
[2024/12/11 13:35:07] ppocr INFO:                     args :
[2024/12/11 13:35:07] ppocr INFO:                         p : 0.5
[2024/12/11 13:35:07] ppocr INFO:                     type : Fliplr
[2024/12/11 13:35:07] ppocr INFO:                     args :
[2024/12/11 13:35:07] ppocr INFO:                         rotate : [-10, 10]
[2024/12/11 13:35:07] ppocr INFO:                     type : Affine
[2024/12/11 13:35:07] ppocr INFO:                     args :
[2024/12/11 13:35:07] ppocr INFO:                         size : [0.5, 3]
[2024/12/11 13:35:07] ppocr INFO:                     type : Resize
[2024/12/11 13:35:07] ppocr INFO:             EastRandomCropData :
[2024/12/11 13:35:07] ppocr INFO:                 keep_ratio : True
[2024/12/11 13:35:07] ppocr INFO:                 max_tries : 50
[2024/12/11 13:35:07] ppocr INFO:                 size : [960, 960]
[2024/12/11 13:35:07] ppocr INFO:             MakeBorderMap :
[2024/12/11 13:35:07] ppocr INFO:                 shrink_ratio : 0.4
[2024/12/11 13:35:07] ppocr INFO:                 thresh_max : 0.7
[2024/12/11 13:35:07] ppocr INFO:                 thresh_min : 0.3
[2024/12/11 13:35:07] ppocr INFO:             MakeShrinkMap :
[2024/12/11 13:35:07] ppocr INFO:                 min_text_size : 8
[2024/12/11 13:35:07] ppocr INFO:                 shrink_ratio : 0.4
[2024/12/11 13:35:07] ppocr INFO:             NormalizeImage :
[2024/12/11 13:35:07] ppocr INFO:                 mean : [0.485, 0.456, 0.406]
[2024/12/11 13:35:07] ppocr INFO:                 order : hwc
[2024/12/11 13:35:07] ppocr INFO:                 scale : 1./255.
[2024/12/11 13:35:07] ppocr INFO:                 std : [0.229, 0.224, 0.225]
[2024/12/11 13:35:07] ppocr INFO:             ToCHWImage : None
[2024/12/11 13:35:07] ppocr INFO:             KeepKeys :
[2024/12/11 13:35:07] ppocr INFO:                 keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']
[2024/12/11 13:35:07] ppocr INFO:     loader :
[2024/12/11 13:35:07] ppocr INFO:         batch_size_per_card : 8
[2024/12/11 13:35:07] ppocr INFO:         drop_last : False
[2024/12/11 13:35:07] ppocr INFO:         num_workers : 4
[2024/12/11 13:35:07] ppocr INFO:         shuffle : True
[2024/12/11 13:35:07] ppocr INFO: profiler_options : None
[2024/12/11 13:35:07] ppocr INFO: train with paddle 2.6.2 and device Place(cpu)
[2024/12/11 13:35:07] ppocr INFO: Initialize indexs of datasets:['./work/train_data/img/Label.txt']
[2024/12/11 13:35:07] ppocr INFO: Initialize indexs of datasets:['./work/train_data/img/Label.txt']
[2024/12/11 13:35:08] ppocr INFO: train dataloader has 2 iters
[2024/12/11 13:35:08] ppocr INFO: valid dataloader has 9 iters
[2024/12/11 13:35:08] ppocr INFO: load pretrain successful from ./work/pretrain_models/ch_PP-OCRv3_det_distill_train/student
[2024/12/11 13:35:08] ppocr INFO: During the training process, after the 0th iteration, an evaluation is run every 400 iterations
[2024/12/11 13:35:55] ppocr INFO: epoch: [1/500], global_step: 1, lr: 0.000000, loss: 2.071318, loss_shrink_maps: 1.280234, loss_threshold_maps: 0.534262, loss_binary_maps: 0.256823, loss_cbn: 0.000000, avg_reader_cost: 2.17457 s, avg_batch_cost: 46.90998 s, avg_samples: 8.0, ips: 0.17054 samples/s, eta: 13:01:03,
[2024/12/11 13:35:55] ppocr INFO: save model in ./work/output/ch_PP-OCR_V3_det/latest
[2024/12/11 13:36:41] ppocr INFO: epoch: [2/500], global_step: 2, lr: 0.000062, loss: 2.127443, loss_shrink_maps: 1.316309, loss_threshold_maps: 0.547244, loss_binary_maps: 0.263890, loss_cbn: 0.000000, avg_reader_cost: 2.32840 s, avg_batch_cost: 46.07263 s, avg_samples: 8.0, ips: 0.17364 samples/s, eta: 12:52:31,
[2024/12/11 13:36:41] ppocr INFO: save model in ./work/output/ch_PP-OCR_V3_det/latest
[2024/12/11 13:37:26] ppocr INFO: epoch: [3/500], global_step: 3, lr: 0.000125, loss: 2.183568, loss_shrink_maps: 1.352385, loss_threshold_maps: 0.549175, loss_binary_maps: 0.270958, loss_cbn: 0.000000, avg_reader_cost: 2.12764 s, avg_batch_cost: 45.30531 s, avg_samples: 8.0, ips: 0.17658 samples/s, eta: 12:44:25,
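
To track loss or throughput across a long run, the per-step lines in this log can be parsed into key/value pairs. The sketch below is one possible way to do it, assuming the "key: value, key: value," layout seen in the lines above; parse_train_log_line is a hypothetical helper, not part of PaddleOCR.

```python
def parse_train_log_line(line):
    """Parse one 'epoch: ..., global_step: ...' log line into a dict."""
    # Drop the timestamp and logger prefix.
    _, _, body = line.partition("ppocr INFO: ")
    metrics = {}
    for field in body.split(","):
        key, sep, value = field.strip().partition(": ")
        if not sep:
            continue  # empty trailing field after the final comma
        try:
            # Keep only the number and drop units such as 's' or 'samples/s'.
            metrics[key] = float(value.split()[0])
        except ValueError:
            metrics[key] = value  # non-numeric fields such as epoch or eta
    return metrics


sample = (
    "[2024/12/11 13:35:55] ppocr INFO: epoch: [1/500], global_step: 1, lr: 0.000000, "
    "loss: 2.071318, loss_shrink_maps: 1.280234, loss_threshold_maps: 0.534262, "
    "loss_binary_maps: 0.256823, loss_cbn: 0.000000, avg_reader_cost: 2.17457 s, "
    "avg_batch_cost: 46.90998 s, avg_samples: 8.0, ips: 0.17054 samples/s, eta: 13:01:03,"
)
metrics = parse_train_log_line(sample)
print(metrics["loss"], metrics["avg_samples"] / metrics["avg_batch_cost"])
```

In the lines logged here, ips agrees with avg_samples divided by avg_batch_cost, which puts this CPU run at roughly 46 seconds per batch of 8 images.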
                    


# Export the detection model
python tools/export_model.py  -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student_cimc.yml  -o Global.pretrained_model="./work/output/ch_PP-OCR_V3_det/best_accuracy"  Global.save_inference_dir="./work/inference_model/ch_PP-OCR_V3_det/" 
# Run prediction
python tools/infer/predict_system.py --image_dir="./work/train_data/img/1.jpg" --det_model_dir="./work/inference_model/ch_PP-OCR_V3_det"  --det_limit_side_len=48 --det_limit_type='min' --det_db_unclip_ratio=2.5 --rec_model_dir="./work/model/ch_PP-OCRv3_rec_infer"  --rec_image_shape="3, 48, 320" --draw_img_save_dir=./work/det_rec_infer/ --use_space_char=False --use_angle_cls=False --use_gpu=false
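
The --det_limit_side_len=48 and --det_limit_type='min' options control how the input image is resized before detection. The sketch below reflects my understanding of that rule (upscale when the shorter side is below the limit, then snap both sides to multiples of 32); det_resize_shape is a hypothetical helper, and the exact rounding in PaddleOCR may differ between versions.

```python
def det_resize_shape(height, width, limit_side_len=48, limit_type="min"):
    """Estimate the (height, width) an image is resized to before detection."""
    if limit_type == "min":
        # Upscale only when the shorter side is below the limit.
        short = min(height, width)
        ratio = limit_side_len / short if short < limit_side_len else 1.0
    elif limit_type == "max":
        # Downscale only when the longer side is above the limit.
        long_side = max(height, width)
        ratio = limit_side_len / long_side if long_side > limit_side_len else 1.0
    else:
        raise ValueError(f"unsupported limit_type: {limit_type}")
    # Snap each side to the nearest multiple of 32, with a floor of 32.
    resized_h = max(int(round(int(height * ratio) / 32) * 32), 32)
    resized_w = max(int(round(int(width * ratio) / 32) * 32), 32)
    return resized_h, resized_w


print(det_resize_shape(24, 100))   # a short text strip is scaled up
print(det_resize_shape(600, 800))  # a large image is only snapped to multiples of 32
```

With limit_type 'min', small text crops are enlarged rather than shrunk, which is presumably why this setting is paired with a small limit_side_len here.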


Recognition result: