Last active
September 8, 2024 09:59
-
-
Save vadimkantorov/2296f3f54b490fb39037cc4603e02d57 to your computer and use it in GitHub Desktop.
Example of running the captcha breaker from https://huggingface.co/spaces/docparser/Text_Captcha_breaker without PyTorch as a dependency
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# before running download captcha.onnx from https://huggingface.co/spaces/docparser/Text_Captcha_breaker
# python -m pip install numpy pillow onnxruntime --user --break-system-packages
"""Break a text captcha with the ONNX model from the docparser/Text_Captcha_breaker space.

The exported model is run with onnxruntime and its output is decoded with numpy
only, so PyTorch is not required.

Usage:
    python <script> -i example_2A5Z.png [--model-path captcha.onnx] [--raw]
"""
import argparse

import numpy

DEFAULT_CHARSET = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

# Special tokens of the output vocabulary; build_vocab fixes their positions.
EOS = '[E]'
BOS = '[B]'
PAD = '[P]'


def build_vocab(charset):
    """Return the id -> token table used to decode the model's output ids.

    Layout: EOS first, then every character of ``charset + '[UNK]'``, then BOS and PAD.

    Args:
        charset: the characters the model can emit, in id order.

    Returns:
        A tuple of token strings indexed by output id.
    """
    # NOTE(review): tuple() splits '[UNK]' into the five single characters
    # '[', 'U', 'N', 'K', ']' instead of adding one '[UNK]' token. This is kept
    # on purpose: output ids are positions in this exact table, presumably
    # matching the tokenizer the model was trained with -- confirm before changing.
    return (EOS,) + tuple(charset + '[UNK]') + (BOS, PAD)


def load_image(path, height_width):
    """Load ``path`` as RGB and resize it bicubically to ``height_width``.

    Args:
        path: image file path.
        height_width: (height, width) sequence of the model's expected input size.

    Returns:
        A uint8 numpy array of shape (height, width, 3).
    """
    # Imported lazily so the numpy-only helpers work without Pillow installed.
    import PIL.Image

    height, width = height_width
    # `with` closes the file handle deterministically; convert() returns an
    # independent image, so it stays usable after the source image is closed.
    with PIL.Image.open(path) as im:
        resized = im.convert('RGB').resize((width, height), PIL.Image.BICUBIC)
    return numpy.asarray(resized)


def to_model_input(img):
    """Normalize a uint8 RGB image to [-1, 1] and reshape it to (1, 3, H, W) float32.

    Args:
        img: uint8 numpy array of shape (H, W, 3).

    Returns:
        float32 array of shape (1, 3, H, W), computed as (img / 255 - 0.5) / 0.5.

    Raises:
        ValueError: if ``img`` is not a uint8 (H, W, 3) array.
    """
    if not (img.dtype == numpy.uint8 and img.ndim == 3 and img.shape[-1] == 3):
        raise ValueError('expected a uint8 array of shape (H, W, 3), got %s %r' % (img.dtype, img.shape))
    normalized = (numpy.divide(img, 255.0, dtype=numpy.float32) - 0.5) / 0.5
    # HWC -> CHW, then add a leading batch axis of size 1.
    return numpy.moveaxis(normalized, -1, 0)[None, ...]


def decode(dist, itos, raw=False):
    """Greedy-decode per-position scores into text.

    Args:
        dist: 2-D array of scores, one row per output position; the argmax over
            the last axis is taken as the token id (assumes row layout
            (seq_len, num_classes) -- not verified here).
        itos: id -> token table, as returned by build_vocab.
        raw: if True, return every predicted token (no EOS truncation) as a list.

    Returns:
        The decoded string, truncated before the first EOS token; or, when
        ``raw`` is True, the list of all predicted tokens.
    """
    ids = numpy.argmax(dist, -1).tolist()  # greedy selection
    if not raw:
        # Truncate at the first EOS; if none was predicted, keep everything.
        eos_id = itos.index(EOS)
        if eos_id in ids:
            ids = ids[:ids.index(eos_id)]
    tokens = [itos[i] for i in ids]
    return tokens if raw else ''.join(tokens)


def main(argv=None):
    """Parse CLI arguments, run the ONNX model on one image and print the prediction."""
    # Imported lazily so the numpy-only helpers work without onnxruntime installed.
    import onnxruntime

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', default='captcha.onnx')
    parser.add_argument('--input-image-path', '-i', default='example_2A5Z.png')
    parser.add_argument('--charset', default=DEFAULT_CHARSET)
    parser.add_argument('--image-height-width', type=int, nargs=2, default=[32, 128])
    parser.add_argument('--raw', action='store_true')
    args = parser.parse_args(argv)

    ort_session = onnxruntime.InferenceSession(args.model_path)
    x = to_model_input(load_image(args.input_image_path, args.image_height_width))
    # First output of the graph; its first row is the (single) batch item.
    token_dists = ort_session.run(None, {ort_session.get_inputs()[0].name: x})[0]
    print(decode(token_dists[0], build_vocab(args.charset), raw=args.raw))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment