Created
June 17, 2024 22:42
-
-
Save twilligon/f0e89bc98cd2aee803f207796c46ab29 to your computer and use it in GitHub Desktop.
Document OCR via GPT-4o
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| from concurrent.futures import ThreadPoolExecutor | |
| from io import BytesIO | |
| import asyncio | |
| import base64 | |
| import random | |
| import sys | |
| from openai import AsyncOpenAI, APIError, RateLimitError | |
| from openai.types.completion import CompletionUsage | |
| from PIL import Image | |
# Completion-token budget sent as max_tokens with every request; main() warns
# when a page's reported usage reaches it.
MAX_TOKENS = 4095
# Longest image edge, in pixels, before encode_image_url() thumbnails it down.
MAX_IMAGE_SIZE = 2000
# Largest PNG payload (in bytes, 20 MiB) that encode_image_url() will accept.
MAX_IMAGE_BYTES = 20 * 1024**2
# Retries allowed per page before transcribe_page() re-raises the API error.
MAX_RETRIES = 10
# Single-slot queue: transcribe_page() puts one item per finished page and
# status() takes one item per progress update.
page_completions = asyncio.Queue(maxsize=1)
def encode_image_url(path):
    """Return the image at *path* as a base64-encoded PNG ``data:`` URL.

    The image is shrunk in place (aspect ratio preserved) so that neither
    edge exceeds MAX_IMAGE_SIZE pixels, re-encoded as PNG, and wrapped in a
    ``data:image/png;base64,`` URL.

    Raises:
        ValueError: if the encoded PNG is larger than MAX_IMAGE_BYTES.
    """
    # Use the image as a context manager so the underlying file handle is
    # released deterministically, including when thumbnail()/save() raise.
    with Image.open(path) as img:
        if max(img.size) > MAX_IMAGE_SIZE:
            img.thumbnail((MAX_IMAGE_SIZE, MAX_IMAGE_SIZE))
        output = BytesIO()
        img.save(output, format="PNG")

    png = output.getvalue()
    if len(png) > MAX_IMAGE_BYTES:
        raise ValueError(f"Resized image size exceeds 20 MB: {path}")
    return "data:image/png;base64," + base64.b64encode(png).decode("utf-8")
async def transcribe_page(client, executor, path, queue, *, max_delay=60.0):
    """Stream a GPT-4o transcription of one page image into *queue*.

    The image is encoded on *executor* so the event loop is not blocked, then
    sent as a streaming chat completion. Items put on *queue*, in order, for
    each attempt:

    * ``()`` -- marks the start of an attempt,
    * ``str`` fragments -- the transcription text as it streams in,
    * a final item -- the usage object reported by the stream, or ``None``
      if the stream carried none.

    When the page is finished, one item is also put on the module-level
    ``page_completions`` queue.

    On an API error the queue is drained and the request is retried with
    jittered exponential backoff. After MAX_RETRIES retries the error is
    re-raised.

    Args:
        client: object exposing ``chat.completions.create`` (an AsyncOpenAI
            client in this script).
        executor: executor used to run encode_image_url() off the loop.
        path: path of the page image.
        queue: asyncio.Queue receiving the items described above.
        max_delay: cap, in seconds, on the base backoff delay; the actual
            sleep is between one and two times the capped value.
    """
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    image_url = await loop.run_in_executor(executor, encode_image_url, path)

    # The request payload is identical for every attempt, so build it once.
    messages = (
        {
            "role": "system",
            "content": "Transcribe the entire document verbatim. Output nothing but the transcription.",
        },
        {
            "role": "user",
            "content": (
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url,
                        "detail": "high",
                    },
                },
            ),
        },
    )

    retries = 0
    delay = 1
    while True:
        try:
            response = await client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                max_tokens=MAX_TOKENS,
                temperature=0,
                stream=True,
                stream_options={"include_usage": True},
            )
            await queue.put(())
            # Track usage explicitly instead of reading the loop variable
            # after the loop, which is unbound if the stream yields nothing.
            usage = None
            async for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    await queue.put(chunk.choices[0].delta.content)
                if chunk.usage is not None:
                    usage = chunk.usage
            await queue.put(usage)
            await page_completions.put("transcribe")
            return
        except (APIError, RateLimitError):
            # Discard whatever this attempt produced so the retry starts from
            # a clean queue.
            # TODO: the consumer may already have read part of this attempt;
            # main() detects that case through the repeated () marker.
            try:
                while True:
                    queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            retries += 1
            if retries > MAX_RETRIES:
                raise
            # Double the base delay but cap it: uncapped compounding would
            # reach sleeps of hours by the last retries. Jitter is applied
            # after the cap so concurrent pages do not retry in lockstep.
            delay = min(delay * 2, max_delay)
            await asyncio.sleep(delay * (1 + random.random()))
async def status():
    """Show a live "<done>/<total> pages transcribed" counter on stderr.

    The total is the number of command-line arguments after the script name.
    The counter is rewritten in place (carriage return, no newline) each time
    an item arrives on ``page_completions``, and a newline is emitted once
    every page has been counted.
    """
    pages = len(sys.argv[1:])

    def show(transcribed):
        print(f"\r{transcribed}/{pages} pages transcribed", end="", file=sys.stderr)

    show(0)
    for finished in range(1, pages + 1):
        await page_completions.get()
        show(finished)
    print(file=sys.stderr)
async def main():
    """Transcribe every image named on the command line, printing in order.

    Reads the API key from the file ``api_key`` in the current directory,
    starts one transcribe_page() task per argument plus the status() task,
    then prints each page's streamed text to stdout in argument order,
    followed by a newline per page. A warning goes to stderr when a page's
    completion-token count reaches MAX_TOKENS.

    Raises:
        AssertionError: if a page's queue does not begin with the ``()``
            start marker.
        RuntimeError: if a page's request was retried while its output was
            already being printed.
    """
    with open("api_key") as f:
        api_key = f.read().strip()
    client = AsyncOpenAI(api_key=api_key)
    pages = sys.argv[1:]
    with ThreadPoolExecutor() as executor:
        async with asyncio.TaskGroup() as tg:
            tg.create_task(status())
            queues = []
            for page in pages:
                queue = asyncio.Queue()
                tg.create_task(transcribe_page(client, executor, page, queue))
                queues.append(queue)
            for page, queue in zip(pages, queues):
                # Consume the start-of-attempt marker with a real statement:
                # an `assert await queue.get() == ()` is removed entirely
                # under `python -O`, leaving the marker to be misread below
                # as a mid-stream retry.
                marker = await queue.get()
                if marker != ():
                    raise AssertionError(
                        f"expected start marker for {page}, got {marker!r}"
                    )
                while True:
                    event = await queue.get()
                    if isinstance(event, CompletionUsage):
                        if event.completion_tokens >= MAX_TOKENS:
                            print(
                                f"warning: {page} reached the maximum token limit and may be truncated",
                                file=sys.stderr,
                            )
                        break
                    elif event is None:
                        # Stream finished without reporting usage.
                        break
                    elif event == ():
                        # A second start marker means the request was retried
                        # after part of its output was already printed.
                        raise RuntimeError("had to retry query while streaming from it")
                    else:
                        print(event, end="")
                print()
if __name__ == "__main__":
    # Script entry point: run main() to completion on a fresh event loop.
    asyncio.run(main())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Rasterize each page of the PDF to a PNG named with the "page" prefix,
# printing progress as pages are generated.
pdftoppm -png -progress church.pdf page
# Transcribe the page images in glob order; the text goes to stdout
# (redirected here) while the progress counter goes to stderr.
./4ocr.py page*.png > transcription.txt
# remember to check for hallucinations!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
License: CC0-1.0