Calamari OCR wrapper
'''
Calamari OCR wrapper for in-code usage.
It assumes that images are already loaded into memory as NumPy arrays.
Extended with a prediction postprocessing API; for now, character whitelisting is implemented.
Calamari OCR: https://github.com/Calamari-OCR/calamari/
License: Apache 2.0
'''
import abc
import logging
import os
from typing import Dict, Tuple

import numpy as np

from calamari_ocr.ocr import Predictor
from calamari_ocr.ocr.datasets import DataSetMode, InputDataset, RawDataSet

logger = logging.getLogger(__name__)

class PredictionPostprocessor(abc.ABC):
    """
    Base class for postprocessing a single Calamari prediction.

    `prediction` is a protocol buffer with the fields:
        .sentence
        .positions
            .chars
                .char
                .probability
    """

    @abc.abstractmethod
    def process_prediction(self, fid: str, prediction, **kwargs):
        pass

    def __call__(self, fid: str, prediction, **kwargs):
        self.process_prediction(fid, prediction, **kwargs)

class CalamariWrapper:
    def __init__(self, model_ckpt_path: str, lazy_load: bool = False):
        self.model_path = model_ckpt_path
        self.predictor = None
        if not lazy_load:
            self.load_model()

    def load_model(self):
        if not self.predictor:
            self.predictor = Predictor(checkpoint=self.model_path,
                                       batch_size=1,
                                       auto_update_checkpoints=False,
                                       processes=os.cpu_count())
            logger.info("Model %s loaded.", self.model_path)

    def _predict_dataset(self, dataset):
        """ Predict a complete dataset. Based on calamari_ocr.predictor.Predictor.predict_dataset(...)

        Parameters
        ----------
        dataset : Dataset
            Dataset to predict

        Yields
        -------
        PredictionResult
            Single PredictionResult
        """
        self.load_model()
        input_dataset = InputDataset(dataset, self.predictor.data_preproc, None)
        return self.predictor.predict_input_dataset(input_dataset, progress_bar=False)

    def ocr_batch(self, batch: Dict[str, np.ndarray],
                  prediction_postprocessor: PredictionPostprocessor = None,
                  **kwargs) -> Tuple[Dict[str, str], Dict[str, float]]:
        """
        Predicts a batch of images already loaded into memory.
        Based on calamari_ocr.scripts.predict.run(...)

        :param batch: mapping: text line id -> text line image already loaded as np.ndarray
        :param prediction_postprocessor: optional implementation of PredictionPostprocessor
        :param kwargs: arguments passed on to the prediction postprocessor
        :return: two dictionaries: 1) mapping: text line id -> recognized text,
                 2) mapping: text line id -> confidence
        """
        images = list(batch.values())
        dataset = RawDataSet(DataSetMode.PREDICT, images, None)
        logger.info("Found %s images in the dataset", len(dataset))
        if len(dataset) == 0:
            raise ValueError("Empty dataset provided.")
        predictions = self._predict_dataset(dataset)
        txt_results = {}
        conf = {}
        for fid, result in zip(batch.keys(), predictions):
            prediction = result.prediction
            if prediction_postprocessor:
                prediction_postprocessor(fid, prediction, **kwargs)
            txt_results[fid] = prediction.sentence
            conf[fid] = prediction.avg_char_probability * 100
        return txt_results, conf

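# Usage sketch (illustrative, not part of the original wrapper): the checkpoint
# path and the dummy image below are placeholders for a real Calamari model and
# a real text-line crop.
#
#   ocr = CalamariWrapper("/path/to/calamari_model.ckpt", lazy_load=True)
#   lines = {"line_001": np.zeros((48, 320), dtype=np.uint8)}  # id -> text-line image
#   texts, confidences = ocr.ocr_batch(lines)
#   # texts["line_001"] is the recognized string; confidences["line_001"] is avg char probability * 100
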
class WhitelistPostProc(PredictionPostprocessor):
    """
    Postprocesses a Calamari prediction by choosing, at each character position, the
    alternative that is present in the whitelist and has the highest confidence.
    """

    def __init__(self, whitelist: str):
        super().__init__()
        self.whitelist = set(whitelist)
        self.whitelist.add('')  # keep the empty alternative so a position may resolve to no character

    def choose_char(self, chrs):
        chosen_prob = 0.0
        chosen = ''
        for ch in chrs:
            if ch.char not in self.whitelist:
                continue
            if ch.probability > chosen_prob:
                chosen_prob = ch.probability
                chosen = ch.char
        return chosen, chosen_prob

    def process_prediction(self, fid: str, prediction, **kwargs):
        if not prediction.positions:
            return
        chars, prob = zip(*[self.choose_char(pos.chars) for pos in prediction.positions])
        prediction.sentence = "".join(chars)
        prediction.avg_char_probability = float(np.mean(prob))

class BatchPredictionPostProcMapper(PredictionPostprocessor):
    """
    Aggregates other prediction postprocessors, so that each text line in a batch
    may have its own postprocessor.
    """

    def __init__(self, postprocessors: Dict[str, PredictionPostprocessor] = None):
        super().__init__()
        # Avoid a mutable default argument; fall back to an empty mapping.
        self.postprocessors = postprocessors if postprocessors is not None else {}

    def set_postprocessor(self, fid: str, postprocessor: PredictionPostprocessor):
        self.postprocessors[fid] = postprocessor

    def process_prediction(self, fid: str, prediction, **kwargs):
        p = self.postprocessors.get(fid)
        if p:
            p.process_prediction(fid, prediction, **kwargs)

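# Example wiring (a sketch, not part of the original wrapper): the checkpoint path,
# line ids and images below are placeholders, and the digits-only whitelist is only
# meant to illustrate the postprocessor API defined above.
if __name__ == "__main__":
    digits_only = WhitelistPostProc("0123456789")

    # Route different postprocessors to different text lines within one batch.
    per_line = BatchPredictionPostProcMapper()
    per_line.set_postprocessor("amount_line", digits_only)

    ocr = CalamariWrapper("/path/to/calamari_model.ckpt", lazy_load=True)  # placeholder checkpoint
    batch = {
        "amount_line": np.zeros((48, 320), dtype=np.uint8),  # placeholder text-line image
        "name_line": np.zeros((48, 320), dtype=np.uint8),    # placeholder text-line image
    }
    texts, confidences = ocr.ocr_batch(batch, prediction_postprocessor=per_line)
    for fid, text in texts.items():
        print(fid, repr(text), confidences[fid])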