OCR 类型共三种:v2-ch、v3-ch 和 v3-en。
使用的 Docker 镜像是 paddlecloud/paddleocr:2.5-gpu-cuda11.2-cudnn8-efbb0a。
| #!/usr/bin/env python3 | |
| import os.path | |
| import shutil | |
| import sys | |
| from pathlib import Path | |
| import sqlite3 | |
| import json | |
| import cv2 | |
| from paddleocr import PaddleOCR | |
# Root directory for everything this script writes (databases, error list, ...).
OUTPUT_DIR = './pp-ocr'
# Files that cv2.imread() cannot decode are copied here.
INVALID_DIR = '{}/ignored/invalid'.format(OUTPUT_DIR)
# Plain-text file; one failed input path is appended per line.
ERR_LIST_FILE = '{}/errs'.format(OUTPUT_DIR)
# Scratch directory that receives the pieces of split images.
TMP_DIR = '{}/tmp'.format(OUTPUT_DIR)
class Colors:
    """ANSI escape sequences used to colour terminal output.

    Prefix a message with one of the codes and append ``ENDC`` to reset the
    terminal afterwards, e.g. ``Colors.WARNING + text + Colors.ENDC``.
    """

    # Foreground colours.
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'

    # Reset all attributes.
    ENDC = '\033[0m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
def split_image(image, out_dir):
    """Cut an elongated image into pieces along its longer side and save them.

    The number of pieces is the smallest ``n`` for which
    ``(long_side // n) / short_side`` does not exceed 3. Every piece is
    written to ``<out_dir>/<index>.png``; because the names depend only on
    the index, files left by an earlier call with the same ``out_dir`` are
    overwritten.

    Args:
        image: Array indexed as ``[row, column, ...]`` whose ``shape[:2]`` is
            ``(height, width)``, e.g. the value returned by ``cv2.imread``.
        out_dir: Directory that receives the piece files. It is not created
            here.

    Returns:
        The paths of the written pieces, ordered top-to-bottom for images
        that are taller than wide and left-to-right otherwise.

    Raises:
        ValueError: If the image has a zero-sized dimension.
        OSError: If ``cv2.imwrite`` reports that a piece was not written.
    """
    max_ratio = 3

    def find_pieces_number(long_side, short_side, ratio):
        # Smallest piece count that brings a piece's aspect ratio within `ratio`.
        n = 1
        while (long_side // n) / short_side > ratio:
            n += 1
        return n

    def piece_bounds(length, num):
        # (start, stop) offsets of `num` consecutive pieces covering `length`.
        step = length // num
        bounds = [(step * i, step * (i + 1)) for i in range(num)]
        # The floor division can leave up to num - 1 pixels uncovered; let
        # the last piece absorb them so no image content is dropped.
        bounds[-1] = (bounds[-1][0], length)
        return bounds

    height, width = image.shape[:2]
    if height == 0 or width == 0:
        raise ValueError(
            'cannot split an empty image ({}x{})'.format(width, height))

    # Taller than wide: stack pieces vertically; otherwise side by side.
    vertical = height > width
    long_side, short_side = (height, width) if vertical else (width, height)
    num = find_pieces_number(long_side, short_side, max_ratio)

    out_list = []
    for i, (start, stop) in enumerate(piece_bounds(long_side, num)):
        if vertical:
            piece = image[start:stop, 0:width]
        else:
            piece = image[0:height, start:stop]
        name = out_dir + '/' + str(i) + '.png'
        print(name)
        # A failed write would leave a stale file from a previous image in
        # place, which the caller would then OCR as if it were this piece.
        if not cv2.imwrite(name, piece):
            raise OSError('failed to write image piece: ' + name)
        out_list.append(name)
    return out_list
def result_to_content(result):
    """Join the text of every entry in *result* into one newline-separated string.

    For each entry the text is read as ``entry[-1][0]``, i.e. the first item
    of the entry's last element (presumably the ``(text, score)`` pair of a
    PaddleOCR detection; confirm against the installed PaddleOCR version).

    Args:
        result: Iterable of OCR entries.

    Returns:
        The texts joined with ``'\\n'``; an empty string for an empty result.
    """
    return '\n'.join(entry[-1][0] for entry in result)
def _print_usage():
    """Print the command-line usage text."""
    print('Usage: command <models-dir> <OCR-type>')
    print('The type can be one of v2-ch, v3-ch and v3-en')


def _model_configs(models_dir):
    """Return the model directories, PP-OCR version and language per OCR type."""
    # All three types share the same text-direction classifier model.
    cls_dir = models_dir + '/ch_ppocr_mobile_v2.0_cls_infer'
    return {
        'v2-ch': {
            'det': models_dir + '/ch_ppocr_server_v2.0_det_infer',
            'cls': cls_dir,
            'rec': models_dir + '/ch_ppocr_server_v2.0_rec_infer',
            'version': 'PP-OCRv2',
            'lang': 'ch',
        },
        'v3-ch': {
            'det': models_dir + '/ch_PP-OCRv3_det_infer',
            'cls': cls_dir,
            'rec': models_dir + '/ch_PP-OCRv3_rec_infer',
            'version': 'PP-OCRv3',
            'lang': 'ch',
        },
        'v3-en': {
            'det': models_dir + '/en_PP-OCRv3_det_infer',
            'cls': cls_dir,
            'rec': models_dir + '/en_PP-OCRv3_rec_infer',
            'version': 'PP-OCRv3',
            'lang': 'en',
        },
    }


def _log_failed_path(path):
    """Append *path* as one line to the error list file."""
    with open(ERR_LIST_FILE, 'a') as f:
        f.write(path + '\n')


def _recognize(ocr, path, img):
    """Run OCR on one image; return ``(content, json_text)`` or None on failure."""
    height, width = img.shape[:2]
    aspect_ratio_limit = 4
    if width / height <= aspect_ratio_limit and height / width <= aspect_ratio_limit:
        result = ocr.ocr(path)
        if result is None:
            return None
        json_output = json.dumps(result)
        return result_to_content(result), json_output

    # Very elongated image: OCR it piece by piece and concatenate the text.
    pieces = split_image(img, TMP_DIR)
    print(pieces)
    chunks = []
    for piece in pieces:
        result = ocr.ocr(piece)
        if result is None:
            return None
        chunks.append(result_to_content(result) + '\n')
        print(Colors.OKCYAN + piece + Colors.ENDC)
    # TODO: workaround: empty JSON
    return ''.join(chunks), ''


def main():
    """Batch-OCR the image paths read from stdin into a SQLite database.

    Command line: ``command <models-dir> <OCR-type>`` where the type is one
    of ``v2-ch``, ``v3-ch`` and ``v3-en``. With missing arguments or an
    unknown type the usage text is printed and the process exits with
    status 1.

    One path is read per stdin line. For every path:

    * blank lines, non-regular files and empty files are skipped;
    * paths that do not exist are appended to ``ERR_LIST_FILE``;
    * paths already present in the ``ocr`` table are skipped;
    * files that ``cv2.imread`` cannot decode are copied to ``INVALID_DIR``;
    * images with an aspect ratio above 4 are split with ``split_image`` and
      recognised piece by piece (their ``json`` column is stored empty);
    * if the OCR engine returns ``None`` the path is appended to
      ``ERR_LIST_FILE``.

    Results are inserted into ``<OUTPUT_DIR>/<OCR-type>.db`` and committed
    after every image, so an interrupted run can be resumed.
    """
    # Work on a copy so that sys.argv itself is left untouched.
    args = sys.argv[1:]
    if len(args) < 2:
        _print_usage()
        sys.exit(1)
    models_dir, ocr_type = args[0], args[1]

    configs = _model_configs(models_dir)
    if ocr_type not in configs:
        print('Unknown OCR type: ' + ocr_type)
        _print_usage()
        sys.exit(1)
    config = configs[ocr_type]
    db_file = OUTPUT_DIR + '/' + ocr_type + '.db'

    ocr = PaddleOCR(
        use_gpu=True, det_model_dir=config['det'],
        cls_model_dir=config['cls'],
        rec_model_dir=config['rec'],
        ocr_version=config['version'], lang=config['lang']
    )

    Path(OUTPUT_DIR).mkdir(exist_ok=True)
    Path(INVALID_DIR).mkdir(parents=True, exist_ok=True)
    Path(TMP_DIR).mkdir(exist_ok=True)

    database = sqlite3.connect(db_file)
    try:
        database.execute("""CREATE TABLE IF NOT EXISTS ocr
        (
        `path` TEXT NOT NULL UNIQUE,
        content TEXT NOT NULL,
        `json` TEXT NOT NULL
        )""")

        # Read everything up front so the progress line can show the total.
        paths = [line.rstrip() for line in sys.stdin]

        for count, path in enumerate(paths, 1):
            if not path:
                continue
            # Existence must be tested before the regular-file test,
            # otherwise missing files never reach the error list.
            if not os.path.exists(path):
                _log_failed_path(path)
                continue
            if not os.path.isfile(path) or os.path.getsize(path) == 0:
                continue

            has_record = database.execute(
                'select 1 from ocr where `path` = ?', [path]
            ).fetchone() is not None
            if has_record:
                print(Colors.WARNING + "Skip: " + path + Colors.ENDC)
                continue

            filename = os.path.basename(path)
            print(path)
            img = cv2.imread(path)
            if img is None:
                # invalid image
                shutil.copy(path, INVALID_DIR)
                continue

            recognized = _recognize(ocr, path, img)
            if recognized is None:
                _log_failed_path(path)
                continue
            content, json_output = recognized

            database.execute("""insert into ocr (`path`, content, `json`)
            values (?, ?, ?)""", [path, content, json_output])
            database.commit()
            print(Colors.OKCYAN + '[{}/{}] {}'.format(count, len(paths), filename) + Colors.ENDC)
    finally:
        database.close()
# Run the batch job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()