PaddleOCR ONNX 转换与推理

x33g5p2x  于2022-02-07 转载在 其他  
字(5.0k)|赞(0)|评价(0)|浏览(393)

paddle 文字识别验证代码:

onnx c++推理:

python onnx识别部分推理示例:

感谢博客:

PaddleOCR转ONNX模型(推理部分)_favorxin的博客-CSDN博客_paddleocr转onnx

paddle 文字识别验证代码:

infer_rec.py

参见博客:

https://blog.csdn.net/jacke121/article/details/122647698

onnx c++推理:

参考项目:

A cross platform OCR Library based on PaddleOCR & OnnxRuntime

https://github.com/RapidAI/RapidOCR

库下载:

https://gitee.com/benjaminwan/ocr-lite-onnx/releases/v1.0

开源库:

https://github.com/RapidAI/RapidOCR/tree/main/cpp

c++推理示例代码:

#include <stdlib.h>
#include <stdio.h>
#include <iostream>
#include <memory.h>
#include <string>
#include "../include/rapidocr_api.h"

// Model / dictionary file names; each is appended verbatim to the model directory.
#define BPOCR_DET_MODEL "ch_ppocr_mobile_v2.0_det_infer.onnx"
#define BPOCR_CLS_MODEL "ch_ppocr_mobile_v2.0_cls_infer.onnx"
#define BPOCR_REC_MODEL "ch_ppocr_mobile_v2.0_rec_infer.onnx"
#define BPOCR_KEY_PATH  "ppocr_keys_v1.txt"

// Thread count passed to BPOcrInit.
#define THREAD_NUM   3
// Windows headers may already define MAX_PATH; only define it when missing to
// avoid a macro-redefinition clash.
#ifndef MAX_PATH
#define MAX_PATH    260
#endif

// Defaults used when the program is run without arguments. The model directory
// must end with a path separator because file names are concatenated to it.
#ifdef WIN32
const char* DEFAULT_MODEL_DIR = "E:\\bai-piao-ocr\\cpp\\BaiPiaoOcrOnnx\\models\\";
const char* DEFAULT_TEST_IMG = "E:\\bai-piao-ocr\\cpp\\BaiPiaoOcrOnnx\\images\\long1.jpg";
#else
const char* DEFAULT_MODEL_DIR = "/data/workprj/RapidOCR/models/";
const char* DEFAULT_TEST_IMG = "/data/workprj/RapidOCR/images/1.jpg";
#endif

// Writes `dir` followed by `file` into `out` (capacity `outSize` bytes).
// Returns false when the joined path does not fit, instead of overflowing the
// buffer as strcpy/strcat would.
static bool JoinPath(char* out, size_t outSize, const char* dir, const char* file)
{
	int n = snprintf(out, outSize, "%s%s", dir, file);
	return n >= 0 && (size_t)n < outSize;
}

/*
 * Demo entry point: initialise the RapidOCR engine from the det/cls/rec ONNX
 * models and the key file, run OCR on one image, and print the result text.
 *
 * Usage: prog                      (use the built-in default paths)
 *        prog <model_dir> <image>  (model_dir must end with a path separator)
 *
 * Returns 0 on success, -1 on bad arguments, init failure or OCR failure.
 */
int main(int argc, char* argv[])
{
	const char* szModelDir = DEFAULT_MODEL_DIR;
	const char* szImagePath = DEFAULT_TEST_IMG;

	if (argc == 2)
	{
		// With a single argument argv[2] is NULL; reject instead of passing NULL on.
		fprintf(stderr, "usage: %s [model_dir image_path]\n", argv[0]);
		return -1;
	}
	if (argc >= 3)
	{
		szModelDir = argv[1];
		szImagePath = argv[2];
	}

	char szDetModelPath[MAX_PATH] = { 0 };
	char szClsModelPath[MAX_PATH] = { 0 };
	char szRecModelPath[MAX_PATH] = { 0 };
	char szKeylPath[MAX_PATH] = { 0 };

	if (!JoinPath(szDetModelPath, sizeof(szDetModelPath), szModelDir, BPOCR_DET_MODEL) ||
		!JoinPath(szClsModelPath, sizeof(szClsModelPath), szModelDir, BPOCR_CLS_MODEL) ||
		!JoinPath(szRecModelPath, sizeof(szRecModelPath), szModelDir, BPOCR_REC_MODEL) ||
		!JoinPath(szKeylPath, sizeof(szKeylPath), szModelDir, BPOCR_KEY_PATH))
	{
		fprintf(stderr, "model directory path is too long: %s\n", szModelDir);
		return -1;
	}

	BPHANDLE Handle = BPOcrInit(szDetModelPath, szClsModelPath, szRecModelPath, szKeylPath, THREAD_NUM);
	if (!Handle)
	{
		printf("cannot initialize the OCR Engine.\n");
		return -1;
	}

	int nExitCode = 0;
	RAPIDOCR_PARAM Param = { 0 };
	// NOTE(review): the meaning of the two boolean flags is defined in
	// rapidocr_api.h (not visible here) — confirm before changing them.
	BOOL bRet = BPOcrDoOcr(Handle, szImagePath, true, false, &Param);
	if (bRet)
	{
		int nLen = BPOcrGetLen(Handle);
		if (nLen > 0)
		{
			// One extra zero byte keeps the string terminated whether or not
			// nLen already counts the terminating NUL.
			char* szInfo = (char*)calloc((size_t)nLen + 1, 1);
			if (szInfo)
			{
				if (BPOcrGetResult(Handle, szInfo, nLen))
				{
					// Never use recognised text as a printf format string: any
					// '%' in the OCR output would be interpreted as a directive.
					fputs(szInfo, stdout);
				}
				free(szInfo);
			}
		}
	}
	else
	{
		fprintf(stderr, "OCR failed for image: %s\n", szImagePath);
		nExitCode = -1;
	}

	// Handle is known to be non-null here (checked after BPOcrInit).
	BPOcrDeinit(Handle);
	return nExitCode;
}

python onnx识别部分推理示例:

import math

import cv2
import onnxruntime
import numpy as np
# Exported PaddleOCR recognition model. get_img_res feeds it inputs whose width
# varies with the crop's aspect ratio, so the export must use a dynamic width axis.
small_rec_file = "vc_rec_dynamic.onnx"

# Since onnxruntime 1.9, builds with several execution providers (e.g. the GPU
# package) refuse to create a session unless `providers` is given explicitly.
# Prefer CUDA when this build has it, and always fall back to the CPU provider.
_available_providers = onnxruntime.get_available_providers()
onet_rec_session = onnxruntime.InferenceSession(
    small_rec_file,
    providers=[p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
               if p in _available_providers],
)

## 根据推理结果解码识别结果
class process_pred(object):
    """CTC decoder that turns recognition-model output into (text, confidence) pairs.

    The class table is ``['blank'] + characters``. Index 0 is the CTC blank token
    and is never emitted. The characters are every character of every line of the
    dictionary file, in file order, optionally followed by a space.
    """

    def __init__(self, character_dict_path=None, character_type='ch', use_space_char=False):
        """Load the character dictionary.

        Args:
            character_dict_path: path to a UTF-8 dictionary file (normally one
                character per line).
            character_type: accepted for interface compatibility; unused here.
            use_space_char: if True, append ' ' as an extra class after the
                dictionary characters.
        """
        with open(character_dict_path, 'rb') as fin:
            # Remove only line-ending characters. Other whitespace is kept
            # because it can be a genuine dictionary entry.
            # join() avoids the quadratic cost of repeated string +=.
            self.character_str = ''.join(
                line.decode('utf-8').strip('\n').strip('\r\n') for line in fin)
        if use_space_char:
            self.character_str += ' '
        dict_character = self.add_special_char(list(self.character_str))
        # char -> class index. For duplicate characters, the last index wins.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def add_special_char(self, dict_character):
        """Prepend the CTC blank token so that it occupies class index 0."""
        dict_character = ['blank'] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Map class-index sequences to strings.

        Args:
            text_index: per-batch sequences of class indices.
            text_prob: optional per-batch sequences of probabilities aligned with
                ``text_index``. When omitted, every kept character counts as 1.
            is_remove_duplicate: collapse repeated indices (CTC greedy decoding).
                A blank between two equal indices keeps both occurrences.

        Returns:
            A list with one ``(text, mean_confidence)`` tuple per batch item.
            The confidence is 0.0 when no character survives decoding.
        """
        result_list = []
        ignored_tokens = [0]
        for batch_idx, seq in enumerate(text_index):
            char_list = []
            conf_list = []
            for idx, token in enumerate(seq):
                if token in ignored_tokens:
                    continue
                # Compare with the raw previous index, which may be a blank.
                if is_remove_duplicate and idx > 0 and seq[idx - 1] == token:
                    continue
                char_list.append(self.character[int(token)])
                conf_list.append(text_prob[batch_idx][idx] if text_prob is not None else 1)
            # An all-blank prediction would otherwise give np.mean([]) -> nan plus
            # a RuntimeWarning. Report zero confidence for empty text instead.
            conf = float(np.mean(conf_list)) if conf_list else 0.0
            result_list.append((''.join(char_list), conf))
        return result_list

    def __call__(self, preds, label=None):
        """Greedy-decode model output, and optionally ground-truth labels.

        Args:
            preds: array-like of class scores. It must be 3-D because the argmax is
                taken over axis 2; presumably (batch, time_steps, num_classes).
            label: optional per-batch index sequences, decoded without
                duplicate removal.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when ``label``
            is given.
        """
        if not isinstance(preds, np.ndarray):
            preds = np.array(preds)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

# Shared CTC decoder built from PaddleOCR's Chinese key file. The path is resolved
# against the current working directory. A trailing space class is appended
# (use_space_char=True).
postprocess_op = process_pred(
    character_dict_path='../ppocr/utils/ppocr_keys_v1.txt',
    character_type='ch',
    use_space_char=True,
)
def resize_norm_img(img, max_wh_ratio, rec_image_shape=(3, 32, 100)):
    """Resize a text-line crop to the recognizer's input height and normalize it.

    The crop is scaled to height ``img_h`` while keeping its aspect ratio. Its
    width is capped at ``int(img_h * max_wh_ratio)``. Pixel values are mapped from
    [0, 255] to [-1, 1], the layout is converted HWC -> CHW, and the result is
    zero-padded on the right up to that width.

    Args:
        img: HxWxC image array, where C must equal ``rec_image_shape[0]``.
        max_wh_ratio: width/height ratio that sets the padded output width.
        rec_image_shape: (channels, height, base_width) of the model input. The
            base width is not used here; the output width comes from
            ``max_wh_ratio``.

    Returns:
        float32 array of shape (channels, img_h, img_w).

    Raises:
        ValueError: if ``img`` is not a 3-D array with the expected channel count.
    """
    img_c, img_h, _ = [int(v) for v in rec_image_shape]
    # Use an explicit exception rather than assert, which is stripped under -O.
    if img.ndim != 3 or img.shape[2] != img_c:
        raise ValueError("expected an HxWx{} image, got shape {}".format(img_c, img.shape))
    # Derive the width from img_h (it was hard-coded as 32). Keep at least one
    # column so that a very tall, narrow crop cannot request a zero-width resize.
    img_w = max(1, int(img_h * max_wh_ratio))
    h, w = img.shape[:2]
    resized_w = min(img_w, int(math.ceil(img_h * (w / float(h)))))
    resized_image = cv2.resize(img, (resized_w, img_h)).astype('float32')
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    padding_im = np.zeros((img_c, img_h, img_w), dtype=np.float32)
    padding_im[:, :, :resized_w] = resized_image
    return padding_im

def get_img_res(onnx_model, img, process_op, min_wh_ratio=100.0 / 32):
    """Run text recognition on one cropped text-line image.

    Args:
        onnx_model: onnxruntime session for the recognition model. Its first
            input receives the preprocessed image.
        img: HxWx3 image, e.g. as returned by ``cv2.imread``.
        process_op: decoder applied to the model's first output, e.g. a
            ``process_pred`` instance.
        min_wh_ratio: lower bound on the width/height ratio used for padding.
            The default (100/32) mirrors PaddleOCR's predict_rec, which never pads
            a crop below its base input width. Pass 0 to size the input purely by
            the crop's own aspect ratio.

    Returns:
        The output of ``process_op``. For ``process_pred`` this is a list of
        ``(text, confidence)`` tuples.

    Raises:
        ValueError: if ``img`` is None, which is what ``cv2.imread`` returns for a
            missing or unreadable file.
    """
    if img is None:
        raise ValueError("img is None; check that the image file exists and is readable")
    h, w = img.shape[:2]
    max_wh_ratio = max(min_wh_ratio, w * 1.0 / h)
    # Add a leading batch dimension of size 1.
    batch = resize_norm_img(img, max_wh_ratio)[np.newaxis, :]
    inputs = {onnx_model.get_inputs()[0].name: batch}
    outs = onnx_model.run(None, inputs)
    return process_op(outs[0])

if __name__ == "__main__":
    pic = cv2.imread(r"F:\project\jushi\data\shuini\ocr_crop\image\093807_148562_1.jpg")
    # cv2.imread does not raise on a bad path; it returns None.
    if pic is None:
        raise SystemExit("failed to read the test image; check the path")
    # The image is passed in the BGR order returned by cv2.imread. An RGB
    # conversion, cv2.cvtColor(pic, cv2.COLOR_BGR2RGB), was deliberately left
    # disabled here.
    res = get_img_res(onet_rec_session, pic, postprocess_op)
    print(res)

相关文章