Created
April 12, 2024 19:29
-
-
Save imiraoui/8fa88654de7ed7e9ef6805a7cf814b73 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
from io import BytesIO
from typing import List

from fastapi import FastAPI, HTTPException, Request
from modal import Image, Mount, Stub, asgi_app, gpu, method
from PIL import Image as Image2
from surya.model.detection.segformer import (
    load_model as load_det_model,
    load_processor as load_det_processor,
)
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr
def download_models():
    """Load every Surya detection/recognition model and processor once.

    Passed to ``Image.run_function`` while building ``ocr_image``, so it runs
    at image-build time. The loaded objects are discarded; presumably the point
    is to populate surya's on-disk weight cache inside the image (TODO confirm
    where surya stores downloaded weights).
    """
    # Same loaders, in the same order, as Model.__enter__ calls them.
    loaders = (load_det_processor, load_det_model, load_rec_model, load_rec_processor)
    for load in loaders:
        load()
# Container image shared by the OCR worker and the web endpoint: apt packages,
# pinned web / Hugging Face stacks, surya itself, and model weights loaded at
# build time via download_models.
ocr_image = (
    Image.debian_slim()
    .apt_install(
        "libglib2.0-0", "libsm6", "libxrender1", "libxext6", "ffmpeg", "libgl1", "git"
    )
    .pip_install("fastapi==0.110.1")
    .pip_install(
        "transformers~=4.36.2",
        "accelerate~=0.23",
        "safetensors~=0.3",
    )
    # NOTE(review): surya-ocr is unpinned, but the surya.model.* import paths
    # used by this file are version-specific. Pin the version this was written
    # against so image rebuilds stay reproducible.
    .pip_install("surya-ocr")
    # NOTE(review): nothing visible in this file uses this checkout (surya is
    # installed from PyPI above). It is kept only to preserve image contents.
    .run_commands("git clone https://github.com/VikParuchuri/surya.git")
    # There is deliberately no `cd surya` step. Each run_commands call runs in
    # its own shell, so a standalone `cd` would not affect any later step.
    .run_function(download_models)
)
# FastAPI application; served by the `fastapi_app` Modal function and carries
# the /runOcr route.
web_app = FastAPI()
# Modal stub "run-ocr". The OCR class, the CLI entrypoint and the web function
# are all registered on it through its decorators.
stub = Stub("run-ocr")
@stub.cls(gpu=gpu.T4(), container_idle_timeout=180, image=ocr_image)
class Model:
    """Surya OCR worker running on a T4 GPU inside ``ocr_image``.

    The detection and recognition models are loaded in ``__enter__``. Every
    ``inference`` call served by the same container then reuses them.
    """

    def __enter__(self):
        # Container lifecycle hook (per Modal's class lifecycle docs): load the
        # models once per container instead of once per request.
        self.det_processor, self.det_model = load_det_processor(), load_det_model()
        self.rec_model, self.rec_processor = load_rec_model(), load_rec_processor()

    @method()
    def inference(self, langs, images):
        """Run OCR on ``images``, applying the same languages to each one.

        Args:
            langs: Language codes, e.g. ``["en"]`` or ``["en", "fr"]``. A
                comma-separated string such as ``"en,fr"`` is also accepted.
            images: List of PIL images.

        Returns:
            Whatever ``surya.ocr.run_ocr`` returns (presumably one result per
            image; see surya's documentation).
        """
        if isinstance(langs, str):
            # Callers such as the CLI pass a plain string. Split it into codes
            # rather than handing surya a bare string.
            langs = [code.strip() for code in langs.split(",") if code.strip()]
        # run_ocr expects one language list per image. A single `[langs]`
        # only lines up with a one-image batch.
        per_image_langs = [list(langs) for _ in images]
        return run_ocr(
            images,
            per_image_langs,
            self.det_model,
            self.det_processor,
            self.rec_model,
            self.rec_processor,
        )
@stub.local_entrypoint()
def main(image_base64: str, langs: str):
    """CLI entrypoint: OCR one base64-encoded image on the remote GPU worker.

    Args:
        image_base64: The image file's bytes, base64-encoded.
        langs: Comma-separated language codes, e.g. ``"en"`` or ``"en,fr"``.

    Returns:
        The predictions returned by ``Model.inference`` (also printed).
    """
    # The worker wants a list of language codes, not a raw CLI string.
    lang_list = [code.strip() for code in langs.split(",") if code.strip()]
    image = Image2.open(BytesIO(base64.b64decode(image_base64)))
    predictions = Model().inference.remote(lang_list, [image])
    # Surface the result; otherwise a CLI run produces no visible output.
    print(predictions)
    return predictions
@stub.function(container_idle_timeout=45, image=ocr_image)
@asgi_app()
def fastapi_app():
    """Expose ``web_app`` (including its /runOcr route) as a Modal ASGI endpoint.

    Runs in ``ocr_image`` with a 45-second container idle timeout.
    """
    return web_app
@web_app.post("/runOcr")
async def runOcr(request: Request):
    """OCR one base64-encoded image posted as JSON.

    Expected body: ``{"langs": ..., "image_base64": "..."}``. ``langs`` is a
    list of language codes (``["en"]``) or a comma-separated string
    (``"en,fr"``).

    Returns:
        The predictions produced by ``Model.inference``.

    Raises:
        HTTPException: 400 when the body is not a JSON object containing both
            keys, or when the payload does not decode to an image.
    """
    try:
        body = await request.json()
        langs = body["langs"]
        image_base64 = body["image_base64"]
    except (ValueError, KeyError, TypeError) as err:
        # ValueError: malformed JSON. KeyError: missing key.
        # TypeError: the body is not a JSON object.
        raise HTTPException(status_code=400, detail=f"Invalid request body: {err!r}") from err

    if isinstance(langs, str):
        langs = [code.strip() for code in langs.split(",") if code.strip()]

    try:
        image = Image2.open(BytesIO(base64.b64decode(image_base64)))
        # Decode now, so a corrupt upload is reported as a client error here
        # rather than failing later inside the remote call.
        image.load()
    except (ValueError, TypeError, OSError) as err:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {err!r}") from err

    # Use Modal's async call variant. A blocking .remote() inside an async
    # handler would stall the event loop for the whole OCR run.
    return await Model().inference.remote.aio(langs, [image])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment