Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save pythonlessons/3802b45d254c2d917c93c4e301473f09 to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/3802b45d254c2d917c93c4e301473f09 to your computer and use it in GitHub Desktop.
handwriting_recognition_pytorch
import cv2
import typing
import numpy as np
from mltu.inferenceModel import OnnxInferenceModel
from mltu.utils.text_utils import ctc_decoder, get_cer
class ImageToWordModel(OnnxInferenceModel):
    """ONNX inference wrapper that turns a word image into decoded text.

    Relies on attributes supplied by ``OnnxInferenceModel`` (``model``,
    ``input_shape``, ``input_name`` and ``vocab``).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``OnnxInferenceModel``."""
        super().__init__(*args, **kwargs)

    def predict(self, image: np.ndarray):
        """Run the model on a single image and return the decoded prediction.

        Args:
            image: Image array as loaded by OpenCV.

        Returns:
            The first entry produced by ``ctc_decoder`` for the model output.
        """
        # cv2.resize takes its target size as (width, height); the first two
        # entries of input_shape are reversed to build it.
        # NOTE(review): presumably input_shape starts with (height, width) -- confirm.
        leading_dims = self.input_shape[:2]
        target_size = leading_dims[::-1]
        resized = cv2.resize(image, target_size)

        # Prepend a batch axis and cast to float32 before feeding the session.
        batch = resized[np.newaxis, ...].astype(np.float32)

        # Only the first output of the session is decoded.
        outputs = self.model.run(None, {self.input_name: batch})
        decoded = ctc_decoder(outputs[0], self.vocab)
        return decoded[0]
if __name__ == "__main__":
    # Evaluate the exported ONNX model on the validation split and report the
    # per-image and average character error rate (CER).
    import pandas as pd
    from tqdm import tqdm

    model = ImageToWordModel(model_path="Models/08_handwriting_recognition_torch/202303142139/model.onnx")

    # Read every field as text: with default type inference pandas would turn
    # numeric-looking labels into numbers ("007" -> 7) and labels such as "NA"
    # or "null" into NaN, corrupting the ground truth passed to get_cer.
    df = pd.read_csv(
        "Models/08_handwriting_recognition_torch/202303142139/val.csv",
        dtype=str,
        keep_default_na=False,
    ).values.tolist()

    accum_cer = []
    skipped = 0
    for image_path, label in tqdm(df):
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None (it does not raise) when the file is
            # missing or cannot be decoded; skip it instead of crashing later
            # inside cv2.resize with an opaque error.
            skipped += 1
            print(f"Warning: could not read image, skipping: {image_path}")
            continue

        prediction_text = model.predict(image)

        cer = get_cer(prediction_text, label)
        print(f"Image: {image_path}, Label: {label}, Prediction: {prediction_text}, CER: {cer}")

        accum_cer.append(cer)

    if skipped:
        # Make it visible that the average below does not cover every row.
        print(f"Skipped {skipped} unreadable image(s)")

    if accum_cer:
        print(f"Average CER: {np.average(accum_cer)}")
    else:
        # np.average of an empty list yields nan plus a RuntimeWarning.
        print("Average CER: n/a (no images were evaluated)")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment