Skip to content

Instantly share code, notes, and snippets.

@bczhc
Last active April 29, 2023 03:42
Show Gist options
  • Select an option

  • Save bczhc/514d5d3fe4e7875c6fbd8a1f9880dc2d to your computer and use it in GitHub Desktop.

Select an option

Save bczhc/514d5d3fe4e7875c6fbd8a1f9880dc2d to your computer and use it in GitHub Desktop.
PaddleOCR协作脚本
#!/usr/bin/env python3
import os.path
import shutil
import sys
from pathlib import Path
import sqlite3
import json
import cv2
from paddleocr import PaddleOCR
# Root directory for everything this script produces (databases, temp files, logs).
OUTPUT_DIR = './pp-ocr'
# Unreadable images (cv2.imread returned None) are copied here.
INVALID_DIR = '{}/ignored/invalid'.format(OUTPUT_DIR)
# Plain-text file that collects, one per line, the paths that failed.
ERR_LIST_FILE = '{}/errs'.format(OUTPUT_DIR)
# Scratch directory that receives the pieces of images that had to be split.
TMP_DIR = '{}/tmp'.format(OUTPUT_DIR)
class Colors:
    """ANSI escape sequences used to colorize terminal output.

    Concatenate one of the color/style codes before a message and
    ``ENDC`` after it to reset the terminal attributes.
    """

    # Foreground colors.
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'

    # Reset sequence: ends any color/style started earlier.
    ENDC = '\033[0m'

    # Text styles.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
def split_image(image, out_dir):
    """Cut an elongated image into pieces along its longer side and save them.

    The number of pieces is the smallest ``n`` for which
    ``(long_side // n) / short_side <= 3``. Pieces are written as
    ``<out_dir>/<index>.png``; because the names depend only on the index,
    pieces from a previous call with the same ``out_dir`` are overwritten.

    :param image: image array as returned by ``cv2.imread``; only
        ``image.shape[:2]`` (height, width) and slicing are used.
    :param out_dir: existing directory that receives the piece files.
    :return: list of the written file names, in piece order (top to bottom
        for tall images, left to right otherwise).
    """

    def find_pieces_number(long_side, short_side, ratio):
        # Smallest piece count whose piece length keeps the aspect ratio
        # at or below `ratio`.
        n = 1
        while (long_side // n) / short_side > ratio:
            n += 1
        return n

    def write_pieces(axis, num):
        # Split `image` into `num` slices along `axis` (0 = rows, 1 = columns).
        length = image.shape[axis]
        step = length // num
        out_list = []
        for i in range(num):
            start = step * i
            # The last piece absorbs the `length % num` leftover pixels;
            # cutting every piece at exactly `step` would silently drop
            # the trailing rows/columns of the image.
            stop = length if i == num - 1 else start + step
            if axis == 0:
                piece = image[start:stop, :]
            else:
                piece = image[:, start:stop]
            name = out_dir + '/' + str(i) + '.png'
            print(name)
            cv2.imwrite(name, piece)
            out_list.append(name)
        return out_list

    height, width = image.shape[:2]
    if height > width:
        # Tall image: stack of horizontal strips.
        return write_pieces(0, find_pieces_number(height, width, 3))
    # Wide (or square) image: row of vertical strips.
    return write_pieces(1, find_pieces_number(width, height, 3))
def result_to_content(result):
    """Flatten an OCR result into newline-separated text.

    For every entry of ``result`` the first element of the entry's last item
    is taken (``entry[-1][0]``) and the values are joined with ``'\\n'``.

    NOTE(review): presumably each entry is ``[box, (text, confidence)]`` as
    produced by ``PaddleOCR.ocr`` -- confirm against the installed version.

    :param result: iterable of OCR entries.
    :return: the joined text; an empty string for an empty result.
    """
    return '\n'.join(entry[-1][0] for entry in result)
def _record_error(path):
    """Append *path* as one line to the error list file."""
    with open(ERR_LIST_FILE, 'a') as f:
        f.write(path + '\n')


def _print_usage_and_exit():
    """Print the command-line usage and terminate with exit status 1."""
    print('Usage: command <models-dir> <OCR-type>')
    print('The type can be one of v2-ch, v3-ch and v3-en')
    sys.exit(1)


def main():
    """Run OCR over the image paths read from stdin and store the text in SQLite.

    Command line: ``<models-dir> <OCR-type>`` where the type is one of
    ``v2-ch``, ``v3-ch`` and ``v3-en``. Image paths are read from standard
    input, one per line. Results go to ``<OUTPUT_DIR>/<OCR-type>.db`` in the
    table ``ocr(path, content, json)``; paths already present in the table
    are skipped, so an interrupted run can be resumed.

    Per path:

    * a path that does not exist is appended to the error list file;
    * non-files and empty files are skipped;
    * a file ``cv2.imread`` cannot decode is copied to ``INVALID_DIR``;
    * an image whose aspect ratio exceeds 4 is split with ``split_image``
      and each piece is recognized separately (its JSON column is left
      empty);
    * any other image is recognized directly.
    """
    # Work on a copy so that sys.argv itself is not mutated.
    args = sys.argv[1:]
    # Both positional arguments are required; checking only for "no
    # arguments" would let a single argument crash on args[1] below.
    if len(args) < 2:
        _print_usage_and_exit()
    models_dir, ocr_type = args[0], args[1]

    configs = {
        'v2-ch': {
            'det': models_dir + '/ch_ppocr_server_v2.0_det_infer',
            'cls': models_dir + '/ch_ppocr_mobile_v2.0_cls_infer',
            'rec': models_dir + '/ch_ppocr_server_v2.0_rec_infer',
            'version': 'PP-OCRv2',
            'lang': 'ch'
        },
        'v3-ch': {
            'det': models_dir + '/ch_PP-OCRv3_det_infer',
            'cls': models_dir + '/ch_ppocr_mobile_v2.0_cls_infer',
            'rec': models_dir + '/ch_PP-OCRv3_rec_infer',
            'version': 'PP-OCRv3',
            'lang': 'ch'
        },
        'v3-en': {
            'det': models_dir + '/en_PP-OCRv3_det_infer',
            'cls': models_dir + '/ch_ppocr_mobile_v2.0_cls_infer',
            'rec': models_dir + '/en_PP-OCRv3_rec_infer',
            'version': 'PP-OCRv3',
            'lang': 'en'
        }
    }
    # An unknown type gets the usage text instead of a bare KeyError.
    if ocr_type not in configs:
        _print_usage_and_exit()
    config = configs[ocr_type]
    db_file = OUTPUT_DIR + '/' + ocr_type + '.db'

    ocr = PaddleOCR(
        use_gpu=True, det_model_dir=config['det'],
        cls_model_dir=config['cls'],
        rec_model_dir=config['rec'],
        ocr_version=config['version'], lang=config['lang']
    )

    Path(OUTPUT_DIR).mkdir(exist_ok=True)
    Path(INVALID_DIR).mkdir(parents=True, exist_ok=True)
    Path(TMP_DIR).mkdir(exist_ok=True)

    # Images more elongated than this (in either direction) are split first.
    aspect_ratio_limit = 4

    database = sqlite3.connect(db_file)
    try:
        database.execute("""CREATE TABLE IF NOT EXISTS ocr
        (
            `path` TEXT NOT NULL UNIQUE,
            content TEXT NOT NULL,
            `json` TEXT NOT NULL
        )""")

        # Read everything up front so the progress output can show a total.
        paths = [line.rstrip() for line in sys.stdin]
        total = len(paths)

        for count, path in enumerate(paths, start=1):
            if not path:
                # Blank input line: nothing to process or report.
                continue
            # The existence check must come before is_file(): is_file() is
            # False for missing paths too, which previously made this
            # error-recording branch unreachable.
            if not os.path.exists(path):
                _record_error(path)
                continue
            if not Path(path).is_file():
                continue
            if os.path.getsize(path) == 0:
                continue

            has_record = database.cursor().execute(
                'select 1 from ocr where `path` = ?', [path]
            ).fetchone() is not None
            if has_record:
                print(Colors.WARNING + "Skip: " + path + Colors.ENDC)
                continue

            filename = os.path.basename(path)
            print(path)

            img = cv2.imread(path)
            if img is None:
                # invalid image
                shutil.copy(path, INVALID_DIR)
                continue

            height, width, _ = img.shape
            if width / height > aspect_ratio_limit or height / width > aspect_ratio_limit:
                pieces = split_image(img, TMP_DIR)
                print(pieces)
                piece_texts = []
                for p in pieces:
                    result = ocr.ocr(p)
                    if result is None:
                        # Same handling as the unsplit path below: record
                        # the failure and move on instead of aborting the
                        # whole batch.
                        _record_error(path)
                        break
                    piece_texts.append(result_to_content(result) + '\n')
                    print(Colors.OKCYAN + p + Colors.ENDC)
                else:
                    # Reached only when every piece was recognized.
                    # TODO: workaround: empty JSON
                    database.execute("""insert into ocr (`path`, content, `json`)
                    values (?, ?, ?)""", [path, ''.join(piece_texts), ''])
                    database.commit()
                continue

            result = ocr.ocr(path)
            if result is None:
                _record_error(path)
                continue

            json_output = json.dumps(result)
            content = result_to_content(result)
            database.execute("""insert into ocr (`path`, content, `json`)
            values (?, ?, ?)""", [path, content, json_output])
            database.commit()
            print(Colors.OKCYAN + '[{}/{}] {}'.format(count, total, filename) + Colors.ENDC)
    finally:
        # Release the SQLite handle even if OCR or I/O raised.
        database.close()
# Run the batch OCR only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment