[
{
"content": "reasoning language: English\n\nYou are an intelligent assistant that can answer customer service queries",
"role": "system",
"thinking": null
},
{
"content": "Can you provide me with a list of the top-rated series currently on Netflix?",
"role": "user",
"thinking": null
},
{
"content": "Netflix does not publicly release real-time lists of its top-rated series, but you can find updated information from third-party platforms and review sites....",
"role": "assistant",
"thinking": "\nOkay, the user is asking for the top-rated series currently on Netflix. Let me start by recalling that Netflix doesn't publicly release their current top-rated lists in real-time....\n"
}
]
Last active
November 3, 2025 23:26
-
-
Save dhruvilp/ec5c62178c2bf22d6c0c6a552a674e8d to your computer and use it in GitHub Desktop.
granite-docling-258m inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """FastAPI app for document conversion using Granite Docling model.""" | |
| import base64 | |
| import html | |
| import io | |
| import re | |
| from pathlib import Path | |
| from typing import Optional, Literal | |
| import asyncio | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Form | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from PIL import Image, ImageDraw, ImageOps | |
| from docling_core.types.doc import DoclingDocument | |
| from docling_core.types.doc.document import DocTagsDocument | |
| from transformers import ( | |
| AutoProcessor, | |
| Idefics3ForConditionalGeneration, | |
| ) | |
# FastAPI application object with the service metadata.
app = FastAPI(
    version="1.0.0",
    title="Granite Docling Document Converter",
    description="Convert documents to markdown, doctags, and JSON using Granite Docling",
)

# Module-level model state; all three stay None until initialize_model()
# assigns them.
processor = model = device = None
| # Response models | |
class ConversionResponse(BaseModel):
    """Payload returned by the document conversion endpoints."""

    # Whether the requested output could be produced.
    success: bool
    # Name of the output format ("markdown", "doctags", ...).
    format: str
    # Converted document text; the endpoints send "" when conversion fails.
    content: str
    # Cleaned model output the content was derived from.
    raw_output: Optional[str] = Field(default=None)
    # Human-readable status or error description.
    message: Optional[str] = Field(default=None)
class BoundingBoxResponse(BaseModel):
    """Payload of the bounding-box visualization endpoint."""

    # Whether the request was processed.
    success: bool
    # Model output text.
    content: str
    # Base64-encoded annotated image; None when nothing was annotated.
    annotated_image: Optional[str] = Field(default=None)
    # True when location tags were found in the model output.
    has_bounding_boxes: bool
class HealthResponse(BaseModel):
    """Payload of the health check endpoint."""

    # "healthy" or "model_not_loaded".
    status: str
    # True once the model global has been assigned.
    model_loaded: bool
    # String form of the torch device, or "not_initialized".
    device: str
| # Utility functions | |
def clean_model_response(text: str) -> str:
    """Strip chat/special tokens from raw model output.

    Returns a fixed placeholder sentence when the input is empty, and a
    different one when nothing but special tokens and whitespace remains.
    """
    if not text:
        return "No response generated."

    # Removed one after another, in exactly this order.
    markers = (
        "<|end_of_text|>",
        "<|end|>",
        "<|assistant|>",
        "<|user|>",
        "<|system|>",
        "<pad>",
        "</s>",
        "<s>",
    )
    result = text
    for marker in markers:
        result = result.replace(marker, "")
    result = result.strip()

    if result:
        return result
    return "The model generated a response, but it appears to be empty."
def draw_bounding_boxes(image: Image.Image, response_text: str, is_doctag_response: bool = False) -> Image.Image:
    """Draw a rectangle for every ``<class><loc_a><loc_b><loc_c><loc_d>`` tag group.

    Args:
        image: Image to annotate. RGB images are drawn on in place; images in
            any other mode are converted to RGB first and the converted copy
            is annotated and returned.
        response_text: Model output containing the class/loc tags.
        is_doctag_response: When True each box is coloured by its class name
            (gray if the class is unknown); otherwise one fixed colour is used.

    Returns:
        The annotated image. Drawing is best-effort: if anything fails, the
        image is returned with whatever was drawn up to that point.
    """
    try:
        # Non-RGB modes (e.g. CMYK) cannot be PNG-encoded afterwards, so
        # normalise the mode before drawing.
        if image.mode != "RGB":
            image = image.convert("RGB")
        draw = ImageDraw.Draw(image)
        width, height = image.size

        # Color mapping for different classes
        class_colors = {
            "caption": "#FFCC99",
            "footnote": "#C8C8FF",
            "formula": "#C0C0C0",
            "list_item": "#9999FF",
            "page_footer": "#CCFFCC",
            "page_header": "#CCFFCC",
            "picture": "#FFCCA4",
            "chart": "#FFCCA4",
            "section_header": "#FF9999",
            "table": "#FFCCCC",
            "text": "#FFFF99",
            "title": "#FF9999",
            "document_index": "#DCDCDC",
            "code": "#7D7D7D",
            "paragraph": "#FFFF99",
        }

        # Closed elements are collected first, then bare opening tags; a
        # coordinate quadruple is only kept the first time it is seen.
        doctag_class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>[^<]*</[^>]+>"
        class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        seen_coords = set()
        all_class_matches = []
        for pattern in (doctag_class_pattern, class_pattern):
            for match in re.findall(pattern, response_text):
                coords = match[1:]
                if coords not in seen_coords:
                    seen_coords.add(coords)
                    all_class_matches.append(match)

        for class_name, xmin, ymin, xmax, ymax in all_class_matches:
            if is_doctag_response:
                color = class_colors.get(class_name.lower(), "#808080")
            else:
                color = "#E0115F"

            # loc values are scaled from a 0-500 range to pixel coordinates.
            x1 = int((int(xmin) / 500) * width)
            y1 = int((int(ymin) / 500) * height)
            x2 = int((int(xmax) / 500) * width)
            y2 = int((int(ymax) / 500) * height)

            # Order the corners: a reversed pair would make rectangle() raise
            # and the handler below would then drop all remaining boxes.
            left, right = sorted((x1, x2))
            top, bottom = sorted((y1, y2))
            draw.rectangle([left, top, right, bottom], outline=color, width=3)
        return image
    except Exception:
        # Deliberately best-effort: annotation must never fail the request.
        return image
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as PNG and return the bytes as base64 text."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
| # Model initialization | |
def initialize_model(model_path: str):
    """Load processor and model from a local directory into the module globals.

    The device is CUDA when available, otherwise MPS when available,
    otherwise CPU. Only local files are used for loading.
    """
    global processor, model, device

    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)

    processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_path,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    # Without CUDA the model is additionally moved to the selected device.
    if not torch.cuda.is_available():
        model = model.to(device)

    print(f"Model loaded successfully on {device}")
def generate_response(question: str, image: Image.Image) -> str:
    """Run the model on one image plus a text prompt and return cleaned text.

    Raises:
        HTTPException: status 500 when the model is not initialized or when
            any step of the generation fails.
    """
    if processor is None or model is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    try:
        image = image.convert("RGB")

        # Single user turn: the image placeholder followed by the question.
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
        prompt = processor.apply_chat_template([user_turn], add_generation_prompt=True)

        temperature = 0.0
        encoded = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=4096,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=processor.tokenizer.eos_token_id,
            )

        # Decode only the tokens generated after the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        decoded = processor.batch_decode(
            generated_ids[:, prompt_length:],
            skip_special_tokens=False,
        )
        return clean_model_response(decoded[0])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
async def generate_response_async(question: str, image: Image.Image) -> str:
    """Run generate_response in the default executor and await its result.

    Offloading keeps the event loop free while the blocking model call runs.
    Exceptions raised by generate_response propagate to the awaiting caller.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated in this context.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, generate_response, question, image)
| # API Endpoints | |
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded and which device is selected."""
    loaded = model is not None
    if loaded:
        status = "healthy"
    else:
        status = "model_not_loaded"
    device_name = "not_initialized" if device is None else str(device)
    return HealthResponse(status=status, model_loaded=loaded, device=device_name)
@app.post("/initialize")
async def initialize(model_path: str = Form(...)):
    """Load the model from the given local path; any failure becomes a 500."""
    try:
        initialize_model(model_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to initialize model: {str(e)}")
    return {"success": True, "message": "Model initialized successfully"}
@app.post("/convert/markdown", response_model=ConversionResponse)
async def convert_to_markdown_sync(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling.")
):
    """Synchronous endpoint to convert document to markdown.

    The uploaded image is passed to the model and the resulting doctags are
    converted to markdown. If only the doctags-to-markdown step fails, a
    response with ``success=False`` and the raw model output is returned.

    Raises:
        HTTPException: status 500 when reading the upload or generation fails.
    """
    try:
        # Read and process image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))

        # generate_response is called directly, so inference runs on the
        # event loop thread for the duration of the call.
        raw_output = generate_response(prompt, image)

        # Convert to markdown
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            markdown_output = doc.export_to_markdown()
            return ConversionResponse(
                success=True,
                format="markdown",
                content=markdown_output,
                raw_output=raw_output,
                message="Successfully converted to markdown"
            )
        except Exception as e:
            return ConversionResponse(
                success=False,
                format="markdown",
                content="",
                raw_output=raw_output,
                message=f"Error converting to markdown: {str(e)}"
            )
    except HTTPException:
        # Already a well-formed HTTP error; re-wrapping it through str()
        # would mangle its detail message.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/markdown/async", response_model=ConversionResponse)
async def convert_to_markdown_async_endpoint(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling.")
):
    """Asynchronous endpoint to convert document to markdown.

    Same contract as the synchronous markdown endpoint, but generation is
    awaited through generate_response_async.

    Raises:
        HTTPException: status 500 when reading the upload or generation fails.
    """
    try:
        # Read and process image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))

        # Generate response asynchronously
        raw_output = await generate_response_async(prompt, image)

        # Convert to markdown
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            markdown_output = doc.export_to_markdown()
            return ConversionResponse(
                success=True,
                format="markdown",
                content=markdown_output,
                raw_output=raw_output,
                message="Successfully converted to markdown"
            )
        except Exception as e:
            return ConversionResponse(
                success=False,
                format="markdown",
                content="",
                raw_output=raw_output,
                message=f"Error converting to markdown: {str(e)}"
            )
    except HTTPException:
        # Already a well-formed HTTP error; re-wrapping it through str()
        # would mangle its detail message.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/doctags", response_model=ConversionResponse)
async def convert_to_doctags(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling.")
):
    """Synchronous endpoint to convert document to doctags.

    Returns the cleaned model output as both ``content`` and ``raw_output``.

    Raises:
        HTTPException: status 500 when reading the upload or generation fails.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        # Direct call: inference runs on the event loop thread.
        raw_output = generate_response(prompt, image)
        return ConversionResponse(
            success=True,
            format="doctags",
            content=raw_output,
            raw_output=raw_output,
            message="Successfully generated doctags"
        )
    except HTTPException:
        # Keep the error built by generate_response intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/json")
async def convert_to_json(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling.")
):
    """Synchronous endpoint to convert document to JSON dict.

    On success ``content`` holds the exported document dict. If only the
    doctags-to-document step fails, ``success`` is False and ``content`` is
    an empty dict, with the raw model output still included.

    Raises:
        HTTPException: status 500 when reading the upload or generation fails.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        # Direct call: inference runs on the event loop thread.
        raw_output = generate_response(prompt, image)

        # Try to convert to DoclingDocument and export to dict
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            json_output = doc.export_to_dict()
            return {
                "success": True,
                "format": "json",
                "content": json_output,
                "raw_output": raw_output,
                "message": "Successfully converted to JSON"
            }
        except Exception as e:
            return {
                "success": False,
                "format": "json",
                "content": {},
                "raw_output": raw_output,
                "message": f"Error converting to JSON: {str(e)}"
            }
    except HTTPException:
        # Keep the error built by generate_response intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/json/async")
async def convert_to_json_async(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling.")
):
    """Asynchronous endpoint to convert document to JSON dict.

    Same contract as the synchronous JSON endpoint, but generation is
    awaited through generate_response_async.

    Raises:
        HTTPException: status 500 when reading the upload or generation fails.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        raw_output = await generate_response_async(prompt, image)

        # Try to convert to DoclingDocument and export to dict
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            json_output = doc.export_to_dict()
            return {
                "success": True,
                "format": "json",
                "content": json_output,
                "raw_output": raw_output,
                "message": "Successfully converted to JSON"
            }
        except Exception as e:
            return {
                "success": False,
                "format": "json",
                "content": {},
                "raw_output": raw_output,
                "message": f"Error converting to JSON: {str(e)}"
            }
    except HTTPException:
        # Keep the error built by generate_response intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/with-boxes", response_model=BoundingBoxResponse)
async def convert_with_bounding_boxes(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling.")
):
    """Convert document and return with bounding boxes visualization.

    ``has_bounding_boxes`` is True when the model output contains four
    consecutive loc tags; in that case an annotated copy of the upload is
    returned as base64-encoded PNG.

    Raises:
        HTTPException: status 500 when reading the upload, generation, or
            image encoding fails.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        # Direct call: inference runs on the event loop thread.
        raw_output = generate_response(prompt, image)

        # Check for location tags. A class-tagged box also contains four
        # consecutive loc tags, so this one pattern covers both forms.
        has_doctag = "<doctag>" in raw_output
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        has_loc_tags = re.search(loc_only_pattern, raw_output) is not None

        annotated_image_b64 = None
        if has_loc_tags:
            # Annotate a copy so the uploaded image object stays untouched.
            annotated_image = draw_bounding_boxes(image.copy(), raw_output, has_doctag)
            annotated_image_b64 = image_to_base64(annotated_image)

        return BoundingBoxResponse(
            success=True,
            content=raw_output,
            annotated_image=annotated_image_b64,
            has_bounding_boxes=has_loc_tags
        )
    except HTTPException:
        # Keep the error built by generate_response intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/query")
async def query_document(
    file: UploadFile = File(...),
    question: str = Form(...)
):
    """General purpose endpoint to query a document with any question.

    Returns a dict with ``success``, the original ``question`` and the
    model's ``answer``.

    Raises:
        HTTPException: status 500 when reading the upload or generation fails.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        # Direct call: inference runs on the event loop thread.
        response = generate_response(question, image)
        return {
            "success": True,
            "question": question,
            "answer": response
        }
    except HTTPException:
        # Keep the error built by generate_response intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/query/async")
async def query_document_async(
    file: UploadFile = File(...),
    question: str = Form(...)
):
    """Async general purpose endpoint to query a document.

    Same contract as the synchronous query endpoint, but generation is
    awaited through generate_response_async.

    Raises:
        HTTPException: status 500 when reading the upload or generation fails.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        response = await generate_response_async(question, image)
        return {
            "success": True,
            "question": question,
            "answer": response
        }
    except HTTPException:
        # Keep the error built by generate_response intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_response_streaming(question: str, image: Image.Image):
    """Generate a response and yield it as Server-Sent Events.

    Each generated text chunk is emitted as one ``data:`` event, followed by
    a final ``data: [DONE]`` event. Errors, including an uninitialized
    model, are emitted as a ``data: Error: ...`` event instead of raising.
    """
    if model is None or processor is None:
        # Framed like every other event so SSE clients actually receive it.
        yield "data: Error: Model not initialized\n\n"
        return

    try:
        from transformers import TextIteratorStreamer
        from threading import Thread

        # Removed from each chunk. Chunks are NOT stripped, so the spacing
        # between consecutive chunks survives, and no placeholder sentence
        # is substituted for chunks that end up empty.
        special_tokens = (
            "<|end_of_text|>",
            "<|end|>",
            "<|assistant|>",
            "<|user|>",
            "<|system|>",
            "<pad>",
            "</s>",
            "<s>",
        )

        image = image.convert("RGB")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        temperature = 0.0
        inputs = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
        generation_args = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=4096,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

        # Generation runs in a worker thread and feeds the streamer.
        thread = Thread(target=model.generate, kwargs=generation_args)
        thread.start()

        # Pull chunks through the executor: next() on the streamer blocks
        # until text is available and must not run on the event loop.
        loop = asyncio.get_running_loop()
        chunks = iter(streamer)
        end_of_stream = object()
        while True:
            new_text = await loop.run_in_executor(None, next, chunks, end_of_stream)
            if new_text is end_of_stream:
                break
            for token in special_tokens:
                new_text = new_text.replace(token, "")
            if not new_text:
                continue
            # A raw line break inside a data field would end the field early,
            # so every line of the chunk gets its own "data:" prefix.
            lines = re.split(r"\r\n|\r|\n", new_text)
            payload = "\n".join(f"data: {line}" for line in lines)
            yield f"{payload}\n\n"

        thread.join()
        yield "data: [DONE]\n\n"
    except Exception as e:
        yield f"data: Error: {str(e)}\n\n"
@app.post("/query/stream")
async def query_document_stream(
    file: UploadFile = File(...),
    question: str = Form(...)
):
    """Answer a document query as a text/event-stream response.

    Raises:
        HTTPException: status 500 when the upload cannot be read or opened.
    """
    try:
        upload_bytes = await file.read()
        image = Image.open(io.BytesIO(upload_bytes))
        event_source = generate_response_streaming(question, image)
        return StreamingResponse(event_source, media_type="text/event-stream")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    import os

    # Load the model up front only when MODEL_PATH is configured; otherwise
    # it has to be loaded later through the /initialize endpoint.
    model_path = os.getenv("MODEL_PATH")
    if not model_path:
        print("MODEL_PATH not set. Use /initialize endpoint to load model.")
    else:
        print(f"Initializing model from {model_path}")
        initialize_model(model_path)

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import html | |
| import os | |
| import random | |
| import re | |
| import time | |
| from pathlib import Path | |
| from threading import Thread | |
| import numpy as np | |
| import torch | |
| from docling_core.types.doc import DoclingDocument | |
| from docling_core.types.doc.document import DocTagsDocument | |
| from PIL import Image, ImageDraw, ImageOps | |
| from transformers import ( | |
| AutoProcessor, | |
| Idefics3ForConditionalGeneration, | |
| TextIteratorStreamer, | |
| ) | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.responses import JSONResponse | |
| import io | |
| import base64 | |
# Device preference: CUDA first, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Replace with your actual local path to the model
local_path = "./granite-docling-258M"

# Processor and model are loaded once, at import time.
processor = AutoProcessor.from_pretrained(local_path)
model = Idefics3ForConditionalGeneration.from_pretrained(
    local_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
# Without CUDA the model is additionally moved to the selected device.
if not torch.cuda.is_available():
    model = model.to(device)
def add_random_padding(image: Image.Image, min_percent: float = 0.1, max_percent: float = 0.10) -> Image.Image:
    """Return an RGB copy of the image with a border added on every side.

    The horizontal and vertical border sizes are drawn independently as a
    fraction of width/height between ``min_percent`` and ``max_percent``.
    The border is filled with the colour of the top-left pixel.
    """
    rgb = image.convert("RGB")
    width, height = rgb.size

    # Two separate draws, horizontal first.
    horizontal = int(width * random.uniform(min_percent, max_percent))
    vertical = int(height * random.uniform(min_percent, max_percent))

    fill_colour = rgb.getpixel((0, 0))  # Top-left corner
    border = (horizontal, vertical, horizontal, vertical)
    return ImageOps.expand(rgb, border=border, fill=fill_colour)
def draw_bounding_boxes(image: Image.Image, response_text: str, is_doctag_response: bool = False) -> Image.Image:
    """Draw bounding boxes on the image based on loc tags and return the annotated image.

    Args:
        image: Source image; an RGB conversion of it is annotated and returned.
        response_text: Model output containing ``<class><loc_..>x4`` groups
            and/or bare ``<loc_..>x4`` groups.
        is_doctag_response: When True, boxes are coloured by class name
            (exact match, then substring match, then gray) and bare loc
            groups are ignored. When False, class-tagged boxes use one fixed
            colour and bare loc groups are drawn in gray.

    Returns:
        The annotated image. Drawing is best-effort: on any error the image
        is returned with whatever was drawn up to that point.
    """
    try:
        image = image.convert("RGB")
        draw = ImageDraw.Draw(image)

        # Get image dimensions
        width, height = image.size

        def draw_box(xmin, ymin, xmax, ymax, color):
            # loc values are scaled from a 0-500 range to pixel coordinates.
            x1 = int((int(xmin) / 500) * width)
            y1 = int((int(ymin) / 500) * height)
            x2 = int((int(xmax) / 500) * width)
            y2 = int((int(ymax) / 500) * height)
            # Order the corners: a reversed pair would make rectangle()
            # raise and abort every remaining box.
            left, right = sorted((x1, x2))
            top, bottom = sorted((y1, y2))
            draw.rectangle([left, top, right, bottom], outline=color, width=3)

        # Color mapping for different classes (RGB values converted to hex)
        class_colors = {
            "caption": "#FFCC99",  # (255, 204, 153)
            "footnote": "#C8C8FF",  # (200, 200, 255)
            "formula": "#C0C0C0",  # (192, 192, 192)
            "list_item": "#9999FF",  # (153, 153, 255)
            "page_footer": "#CCFFCC",  # (204, 255, 204)
            "page_header": "#CCFFCC",  # (204, 255, 204)
            "picture": "#FFCCA4",  # (255, 204, 164)
            "chart": "#FFCCA4",  # (255, 204, 164)
            "section_header": "#FF9999",  # (255, 153, 153)
            "table": "#FFCCCC",  # (255, 204, 204)
            "text": "#FFFF99",  # (255, 255, 153)
            "title": "#FF9999",  # (255, 153, 153)
            "document_index": "#DCDCDC",  # (220, 220, 220)
            "code": "#7D7D7D",  # (125, 125, 125)
            "checkbox_selected": "#FFB6C1",  # (255, 182, 193)
            "checkbox_unselected": "#FFB6C1",  # (255, 182, 193)
            "form": "#C8FFFF",  # (200, 255, 255)
            "key_value_region": "#B7410E",  # (183, 65, 14)
            "paragraph": "#FFFF99",  # (255, 255, 153)
            "reference": "#B0E0E6",  # (176, 224, 230)
            "grading_scale": "#FFCCCC",  # (255, 204, 204)
            "handwritten_text": "#CCFFCC",  # (204, 255, 204)
            "empty_value": "#DCDCDC",  # (220, 220, 220)
        }

        # Closed elements are collected first, then bare opening tags; a
        # coordinate quadruple is only kept the first time it is seen.
        doctag_class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>[^<]*</[^>]+>"
        class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        seen_coords = set()
        all_class_matches = []
        for pattern in (doctag_class_pattern, class_pattern):
            for match in re.findall(pattern, response_text):
                coords = match[1:]
                if coords not in seen_coords:
                    seen_coords.add(coords)
                    all_class_matches.append(match)

        for class_name, xmin, ymin, xmax, ymax in all_class_matches:
            if is_doctag_response:
                lowered = class_name.lower()
                color = class_colors.get(lowered)
                if color is None:
                    # Fall back to the first known class whose name contains,
                    # or is contained in, the tag name; gray if none does.
                    color = next(
                        (value for key, value in class_colors.items() if lowered in key or key in lowered),
                        "#808080",
                    )
            else:
                color = "#E0115F"
            draw_box(xmin, ymin, xmax, ymax, color)

        if not is_doctag_response:
            loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
            for coords in re.findall(loc_only_pattern, response_text):
                # Every class-tagged box also matches this pattern; skip it
                # so its colour is not painted over in gray.
                if coords in seen_coords:
                    continue
                seen_coords.add(coords)
                draw_box(*coords, "#808080")

        return image
    except Exception:
        # Deliberately best-effort: annotation must never fail the request.
        return image
def clean_model_response(text: str) -> str:
    """Remove special tokens from model output and trim surrounding whitespace.

    Returns a fixed placeholder sentence when the input is empty, and a
    different one when nothing is left after the tokens are removed.
    """
    if not text:
        return "No response generated."

    # Removed one after another, in exactly this order.
    markers = (
        "<|end_of_text|>",
        "<|end|>",
        "<|assistant|>",
        "<|user|>",
        "<|system|>",
        "<pad>",
        "</s>",
        "<s>",
    )
    result = text
    for marker in markers:
        result = result.replace(marker, "")
    result = result.strip()

    if result:
        return result
    return "The model generated a response, but it appears to be empty or contain only special tokens."
def generate_with_model(question: str, image: Image.Image, apply_padding: bool = False) -> str:
    """Answer a question about an image with the Granite Docling model.

    When the NO_LLM environment variable is set, a canned response is
    returned after a two second pause instead of running the model.
    Failures are not raised; they are returned as an
    ``"Error processing image: ..."`` string.
    """
    if os.environ.get("NO_LLM"):
        time.sleep(2)
        return "This is a simulated response from the Granite Docling model."

    try:
        if apply_padding:
            image = add_random_padding(image)

        # Single user turn: the image placeholder followed by the question.
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
        prompt = processor.apply_chat_template([user_turn], add_generation_prompt=True)

        temperature = 0.0
        encoded = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=4096,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=processor.tokenizer.eos_token_id,
            )

        # Decode only the tokens generated after the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        decoded = processor.batch_decode(
            generated_ids[:, prompt_length:],
            skip_special_tokens=False,
        )
        return clean_model_response(decoded[0])
    except Exception as e:
        return f"Error processing image: {e!s}"
def pil_to_bytes(image: Image.Image) -> bytes:
    """Serialize a PIL image to PNG and return the encoded bytes."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return png_bytes
| app = FastAPI() | |
def _doctags_to_outputs(answer: str, img: Image.Image):
    """Parse doctags into a document; return (markdown, JSON string or None)."""
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([answer], [img])
    doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
    markdown_output = doc.export_to_markdown()
    # Assuming DoclingDocument is Pydantic-based or has model_dump_json
    try:
        json_output = doc.model_dump_json()
    except AttributeError:
        json_output = None  # If not available, skip
    return markdown_output, json_output


def _run_generation(prompt: str, image_bytes: bytes, apply_padding: bool) -> JSONResponse:
    """Run the model on the uploaded bytes and build the endpoint response.

    Shared by the sync and async endpoints. Every failure is reported as a
    JSON body with an ``error`` key rather than raised.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        answer = generate_with_model(prompt, img, apply_padding)

        markdown_output = None
        json_output = None
        annotated_image_base64 = None

        # A class-tagged box also contains four consecutive loc tags, so
        # this one pattern detects both tagged and bare boxes.
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        has_loc_tags = re.search(loc_only_pattern, answer) is not None
        has_doctag = "<doctag>" in answer

        lowered_prompt = prompt.lower()
        if "convert this page to docling" in lowered_prompt or ("convert" in lowered_prompt and "otsl" in lowered_prompt):
            try:
                markdown_output, json_output = _doctags_to_outputs(answer, img)
            except Exception as e:
                return JSONResponse({"error": f"Error creating markdown output: {str(e)}"})
        elif "convert formula to latex" in lowered_prompt:
            try:
                markdown_output, json_output = _doctags_to_outputs(answer, img)
                # Wrap the first $$...$$ formula in an aligned environment.
                if markdown_output.count("$$") >= 2:
                    parts = markdown_output.split("$$", 2)
                    formula = parts[1].strip()
                    wrapped = f"$$\n\\begin{{aligned}}\n{formula}\n\\end{{aligned}}\n$$"
                    markdown_output = parts[0] + wrapped + parts[2]
            except Exception as e:
                return JSONResponse({"error": f"Error creating LaTeX output: {str(e)}"})

        if has_loc_tags:
            try:
                annotated_image = draw_bounding_boxes(img, answer, is_doctag_response=has_doctag)
                annotated_image_base64 = base64.b64encode(pil_to_bytes(annotated_image)).decode("utf-8")
            except Exception:
                # The annotation is optional; the response is sent without it.
                annotated_image_base64 = None

        return JSONResponse({
            "response": answer,
            "markdown": markdown_output,
            "json": json_output,
            "annotated_image_base64": annotated_image_base64
        })
    except Exception as e:
        return JSONResponse({"error": str(e)})


@app.post("/generate/sync")
def generate_sync(
    prompt: str = Form(...),
    image: UploadFile = File(...),
    apply_padding: bool = Form(False)
):
    """Run the model on an uploaded image and return text, markdown, JSON and annotation.

    Errors are returned as ``{"error": ...}`` JSON bodies.
    """
    try:
        image_bytes = image.file.read()
    except Exception as e:
        return JSONResponse({"error": str(e)})
    return _run_generation(prompt, image_bytes, apply_padding)


@app.post("/generate/async")
async def generate_async(
    prompt: str = Form(...),
    image: UploadFile = File(...),
    apply_padding: bool = Form(False)
):
    """Async variant of generate_sync with the same response shape.

    The blocking model call is handed to the default executor so the event
    loop is not blocked while it runs.
    """
    import asyncio

    try:
        image_bytes = await image.read()
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run_generation, prompt, image_bytes, apply_padding)
    except Exception as e:
        return JSONResponse({"error": str(e)})
if __name__ == "__main__":
    # Executed as a script: serve the FastAPI application with uvicorn on
    # host 0.0.0.0, port 8000. Nothing happens here on a plain import.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import time | |
| from datetime import timedelta | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from transformers.image_utils import load_image | |
| from pathlib import Path | |
# Select the compute backend: CUDA first, then MPS, falling back to CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
print(f"USING DEVICE: {DEVICE}")

# Reference point for every elapsed-time report below.
start_time = time.time()


def _report(stage):
    """Print *stage* followed by the seconds elapsed since ``start_time``."""
    elapsed = timedelta(seconds=time.time() - start_time).total_seconds()
    print("==> {}: {}s".format(stage, elapsed))


# Load model and tokenizer
model_path = "/Users/ghart/models/ibm-granite/granite-docling-258M-text-only"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(DEVICE)
_report("Done loading model")

# Prepare inputs.
# No page image is downloaded here: this script only loads a tokenizer, so
# nothing could consume an image, and the fetch would only add a network
# dependency and distort the "preparing inputs" timing.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Convert this page to docling."},
        ],
    },
]
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
_report("Done preparing inputs")

# Generate outputs
generated_ids = model.generate(**inputs, max_new_tokens=512, use_cache=True)
_report("Done generating")

# Drop the prompt tokens so that only the newly generated text is decoded.
prompt_length = inputs.input_ids.shape[1]
trimmed_generated_ids = generated_ids[:, prompt_length:]
result = tokenizer.batch_decode(
    trimmed_generated_ids,
    skip_special_tokens=False,
)[0].lstrip()
_report("Done")

# Show the decoded text; previously it was computed and then discarded.
print(f"Result:\n{result}\n")
| """ | |
| # config.json | |
| { | |
| "architectures": [ | |
| "LlamaForCausalLM" | |
| ], | |
| "attention_bias": false, | |
| "attention_dropout": 0.0, | |
| "bos_token_id": 100264, | |
| "dtype": "bfloat16", | |
| "eos_token_id": 100257, | |
| "head_dim": 64, | |
| "hidden_act": "silu", | |
| "hidden_size": 576, | |
| "initializer_range": 0.02, | |
| "intermediate_size": 1536, | |
| "max_position_embeddings": 8192, | |
| "mlp_bias": false, | |
| "model_type": "llama", | |
| "num_attention_heads": 9, | |
| "num_hidden_layers": 30, | |
| "num_key_value_heads": 3, | |
| "pad_token_id": 100257, | |
| "pretraining_tp": 1, | |
| "rms_norm_eps": 1e-05, | |
| "rope_scaling": null, | |
| "rope_theta": 100000.0, | |
| "tie_word_embeddings": true, | |
| "use_cache": false, | |
| "vocab_size": 100352 | |
| } | |
| """ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from docling_core.types.doc import DoclingDocument | |
| from docling_core.types.doc.document import DocTagsDocument | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
| from transformers.image_utils import load_image | |
| from pathlib import Path | |
# Load model and processor; run on CUDA when it is available, else on CPU.
MODEL_ID = "ibm-granite/granite-docling-258M"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    torch_dtype=torch.bfloat16,
)
model = model.to(DEVICE)

# Prepare inputs: one user turn made of an image slot plus the instruction.
image = load_image(
    "https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"
)
user_turn = {
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Convert this page to docling."},
    ],
}
messages = [user_turn]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

# Generate outputs, then slice off the prompt tokens before decoding so that
# only the newly generated text remains.
generated_ids = model.generate(**inputs, max_new_tokens=8192)
n_prompt_tokens = inputs.input_ids.shape[1]
new_token_ids = generated_ids[:, n_prompt_tokens:]
decoded_batch = processor.batch_decode(new_token_ids, skip_special_tokens=False)
doctags = decoded_batch[0].lstrip()
print(f"DocTags: \n{doctags}\n")

# Populate document: pair the DocTags text with its source image, then build
# a DoclingDocument from that pair.
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
markdown_text = doc.export_to_markdown()
print(f"Markdown:\n{markdown_text}\n")

## export as any format.
# Path("out/").mkdir(parents=True, exist_ok=True)
# HTML:
# output_path_html = Path("out/") / "example.html"
# doc.save_as_html(output_path_html)
# Markdown:
# output_path_md = Path("out/") / "example.md"
# doc.save_as_markdown(output_path_md)
from IPython.display import Markdown, display
from docling_core.types.doc.document import DoclingDocument

# Render the converted page as Markdown through IPython's display machinery.
# The decoded model output is bound to `doctags` earlier in this script; the
# former `doc_tags` spelling was never defined and raised NameError.
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
document = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
extracted_text_markdown = document.export_to_markdown()
display(Markdown(extracted_text_markdown))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment