Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active September 8, 2024 09:59
Show Gist options
  • Save vadimkantorov/2296f3f54b490fb39037cc4603e02d57 to your computer and use it in GitHub Desktop.
Save vadimkantorov/2296f3f54b490fb39037cc4603e02d57 to your computer and use it in GitHub Desktop.
Example of running the captcha breaker from https://huggingface.co/spaces/docparser/Text_Captcha_breaker without PyTorch as a dependency
# Before running, download captcha.onnx from https://huggingface.co/spaces/docparser/Text_Captcha_breaker
# python -m pip install numpy pillow onnxruntime --user --break-system-packages
import argparse
import PIL.Image
import numpy
import onnxruntime
# Special tokens. EOS is class index 0; BOS and PAD are appended after the charset.
EOS, BOS, PAD = '[E]', '[B]', '[P]'

# The 94 printable ASCII characters: digits, lowercase, uppercase, punctuation.
# NOTE: deliberately NOT a raw string. In r"...!\"#...[\\]..." the backslashes are
# kept literally, which yields 96 characters (three copies of '\') and shifts the
# class index of every punctuation character after '!'.
DEFAULT_CHARSET = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"


def build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', default = 'captcha.onnx')
    parser.add_argument('--input-image-path', '-i', default = 'example_2A5Z.png')
    parser.add_argument('--charset', default = DEFAULT_CHARSET)
    parser.add_argument('--image-height-width', type = int, nargs = 2, default = [32, 128])
    parser.add_argument('--raw', action = 'store_true')
    return parser


def load_image(path, height_width):
    """Load the image at *path* as RGB, resized (bicubic) to *height_width*.

    Returns a uint8 numpy array of shape (height, width, 3).
    """
    height, width = height_width
    # The context manager closes the underlying file; convert() returns a new,
    # independent image, so it stays usable after the file is closed.
    with PIL.Image.open(path) as im:
        # PIL's resize takes (width, height): the reverse of the CLI order.
        rgb = im.convert('RGB').resize((width, height), PIL.Image.BICUBIC)
    return numpy.asarray(rgb)


def to_model_input(img):
    """Turn a uint8 (H, W, 3) array into a float32 (1, 3, H, W) batch scaled to [-1, 1].

    Raises:
        ValueError: if *img* is not a uint8 array of shape (H, W, 3).
    """
    if not (img.dtype == numpy.uint8 and img.ndim == 3 and img.shape[-1] == 3):
        raise ValueError(f'expected a uint8 array of shape (H, W, 3), got {img.dtype} {img.shape}')
    # Scale to [0, 1] in float32, then normalize with mean 0.5 and std 0.5.
    x = (numpy.divide(img, 255.0, dtype = numpy.float32) - 0.5) / 0.5
    # HWC -> CHW, then add a leading batch axis of size 1.
    return numpy.moveaxis(x, -1, 0)[None, ...]


def build_vocab(charset):
    """Return the class-index -> token table (EOS, *charset, '[', 'U', 'N', 'K', ']', BOS, PAD)."""
    # NOTE(review): tuple(charset + '[UNK]') splits '[UNK]' into five one-character
    # entries rather than one token. Charset indices are unaffected; it is kept
    # as-is because the table must match the exported model's class layout
    # -- TODO confirm against the model's number of output classes.
    return (EOS,) + tuple(charset + '[UNK]') + (BOS, PAD)


def decode_ids(ids, itos, eos_id, raw = False):
    """Map predicted class ids to tokens.

    Args:
        ids: sequence (list or 1-D numpy array) of integer class ids.
        itos: class-index -> token table, e.g. from build_vocab().
        eos_id: class id of the end-of-sequence token.
        raw: if True, return the untruncated list of tokens, one per id.

    Returns:
        The list of tokens when *raw* is True. Otherwise, the tokens before the
        first *eos_id*, joined into a string.
    """
    if raw:
        return [itos[i] for i in ids]
    ids = [int(i) for i in ids]
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]  # truncate at (and drop) the first EOS
    return ''.join(itos[i] for i in ids)


def main(argv = None):
    """Run the ONNX captcha model on one image and print the decoded prediction."""
    args = build_parser().parse_args(argv)
    ort_session = onnxruntime.InferenceSession(args.model_path)
    x = to_model_input(load_image(args.input_image_path, args.image_height_width))
    token_dists = ort_session.run(None, {ort_session.get_inputs()[0].name: x})[0]
    # First output, first (only) batch item. Presumably (seq_len, num_classes)
    # scores -- TODO confirm against the exported model.
    dist = token_dists[0]
    ids = numpy.argmax(dist, -1)  # greedy selection per position
    itos = build_vocab(args.charset)
    print(decode_ids(ids, itos, itos.index(EOS), raw = args.raw))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment