Skip to content

Instantly share code, notes, and snippets.

@dhruvilp
Last active November 3, 2025 23:26
Show Gist options
  • Select an option

  • Save dhruvilp/ec5c62178c2bf22d6c0c6a552a674e8d to your computer and use it in GitHub Desktop.

Select an option

Save dhruvilp/ec5c62178c2bf22d6c0c6a552a674e8d to your computer and use it in GitHub Desktop.
granite-docling-258m inference
[
  {
    "content": "reasoning language: English\n\nYou are an intelligent assistant that can answer customer service queries",
    "role": "system",
    "thinking": null
  },
  {
    "content": "Can you provide me with a list of the top-rated series currently on Netflix?",
    "role": "user",
    "thinking": null
  },
  {
    "content": "Netflix does not publicly release real-time lists of its top-rated series, but you can find updated information from third-party platforms and review sites....",
    "role": "assistant",
    "thinking": "\nOkay, the user is asking for the top-rated series currently on Netflix. Let me start by recalling that Netflix doesn't publicly release their current top-rated lists in real-time....\n"
  }
]
"""FastAPI app for document conversion using Granite Docling model."""
import base64
import html
import io
import re
from pathlib import Path
from typing import Optional, Literal
import asyncio
import torch
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from PIL import Image, ImageDraw, ImageOps
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import (
AutoProcessor,
Idefics3ForConditionalGeneration,
)
# FastAPI application object that every route below is registered on.
app = FastAPI(
    version="1.0.0",
    title="Granite Docling Document Converter",
    description="Convert documents to markdown, doctags, and JSON using Granite Docling",
)
# Shared inference state. All three stay None until initialize_model() runs.
processor = model = device = None
# Response models
class ConversionResponse(BaseModel):
    """Response model for document conversion."""
    # Declared as the response_model of the /convert/markdown,
    # /convert/markdown/async and /convert/doctags routes.
    success: bool  # False when exporting the DocTags output raised
    format: str  # the routes in this file set "markdown" or "doctags"
    content: str  # converted text; empty string when the export failed
    raw_output: Optional[str] = None  # cleaned model output before export
    message: Optional[str] = None  # human-readable status or error text
class BoundingBoxResponse(BaseModel):
    """Response model with bounding boxes."""
    # Declared as the response_model of the /convert/with-boxes route.
    success: bool
    content: str  # cleaned model output
    annotated_image: Optional[str] = None # Base64 encoded
    has_bounding_boxes: bool  # True when the model output contained <loc_N> tag groups
class HealthResponse(BaseModel):
    """Health check response."""
    # Returned by the /health route.
    status: str  # "healthy" or "model_not_loaded"
    model_loaded: bool  # whether the module-level model is set
    device: str  # str() of the torch device, or "not_initialized"
# Utility functions
def clean_model_response(text: str) -> str:
    """Remove chat-template marker tokens from raw model output.

    Returns a fixed placeholder sentence when *text* is falsy, and a second
    placeholder when nothing is left after the markers and the surrounding
    whitespace have been removed.
    """
    if not text:
        return "No response generated."
    markers = (
        "<|end_of_text|>",
        "<|end|>",
        "<|assistant|>",
        "<|user|>",
        "<|system|>",
        "<pad>",
        "</s>",
        "<s>",
    )
    # Markers are removed one after another, in this fixed order.
    remainder = text
    for marker in markers:
        remainder = remainder.replace(marker, "")
    remainder = remainder.strip()
    return remainder or "The model generated a response, but it appears to be empty."
def draw_bounding_boxes(image: Image.Image, response_text: str, is_doctag_response: bool = False) -> Image.Image:
    """Outline every ``<class><loc_..>x4`` region of *response_text* on *image*.

    The image is drawn on in place and returned. Each loc value is divided by
    500 and multiplied by the image width or height to get pixel coordinates.
    In doctag mode the outline colour comes from the class name (grey when the
    class is unknown); otherwise one fixed colour is used. Any exception is
    swallowed and the image is returned as it stands.
    """
    try:
        canvas = ImageDraw.Draw(image)
        img_w, img_h = image.size
        # Outline colour per doctag class.
        palette = {
            "caption": "#FFCC99",
            "footnote": "#C8C8FF",
            "formula": "#C0C0C0",
            "list_item": "#9999FF",
            "page_footer": "#CCFFCC",
            "page_header": "#CCFFCC",
            "picture": "#FFCCA4",
            "chart": "#FFCCA4",
            "section_header": "#FF9999",
            "table": "#FFCCCC",
            "text": "#FFFF99",
            "title": "#FF9999",
            "document_index": "#DCDCDC",
            "code": "#7D7D7D",
            "paragraph": "#FFFF99",
        }
        closed_tag_re = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>[^<]*</[^>]+>"
        open_tag_re = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        # Closed-tag hits come first so they win when the same box shows up twice.
        candidates = re.findall(closed_tag_re, response_text) + re.findall(open_tag_re, response_text)
        already_seen = set()
        regions = []
        for hit in candidates:
            box = hit[1:5]
            if box in already_seen:
                continue
            already_seen.add(box)
            regions.append(hit)
        for label, left, top, right, bottom in regions:
            outline = palette.get(label.lower(), "#808080") if is_doctag_response else "#E0115F"
            corners = [
                int((int(left) / 500) * img_w),
                int((int(top) / 500) * img_h),
                int((int(right) / 500) * img_w),
                int((int(bottom) / 500) * img_h),
            ]
            canvas.rectangle(corners, outline=outline, width=3)
        return image
    except Exception:
        return image
def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as PNG and return the bytes as a base64 text string."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
# Model initialization
def initialize_model(model_path: str):
    """Load the processor and model from *model_path* into the module globals.

    The device is CUDA when available, otherwise MPS, otherwise CPU. Both
    loaders are called with ``local_files_only=True``.
    """
    global processor, model, device
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_path,
        device_map=device,
        torch_dtype=torch.bfloat16,
        local_files_only=True,
    )
    # Without CUDA the model is additionally moved to the chosen device.
    if not cuda_ready:
        model = model.to(device)
    print(f"Model loaded successfully on {device}")
def generate_response(question: str, image: Image.Image) -> str:
    """Run one image + question turn through the model and return cleaned text.

    Raises:
        HTTPException: status 500 when the model or processor is not loaded,
            or when anything fails during preprocessing, generation or decoding.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=500, detail="Model not initialized")
    try:
        rgb_image = image.convert("RGB")
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = processor.apply_chat_template(chat, add_generation_prompt=True)
        temperature = 0.0
        encoded = processor(text=prompt, images=[rgb_image], return_tensors="pt")
        inputs = {}
        for name, tensor in encoded.items():
            inputs[name] = tensor.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=4096,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=processor.tokenizer.eos_token_id,
            )
        # Drop the prompt tokens so that only the newly generated part is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = processor.batch_decode(
            generated_ids[:, prompt_len:],
            skip_special_tokens=False,
        )
        return clean_model_response(decoded[0])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
async def generate_response_async(question: str, image: Image.Image) -> str:
    """Run generate_response in the default executor and await its result.

    The blocking model call runs on a worker thread, so the event loop keeps
    serving other requests. Exceptions raised by generate_response propagate
    to the awaiting caller.
    """
    # Inside a coroutine, get_running_loop() is the preferred way to obtain
    # the loop; get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, generate_response, question, image)
# API Endpoints
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    # Report whether the global model is set and which device was selected.
    is_loaded = model is not None
    if is_loaded:
        state = "healthy"
    else:
        state = "model_not_loaded"
    device_label = "not_initialized" if device is None else str(device)
    return HealthResponse(status=state, model_loaded=is_loaded, device=device_label)
@app.post("/initialize")
async def initialize(model_path: str = Form(...)):
    """Initialize the model from local path."""
    # Any loader failure is reported as a 500 carrying the loader's message.
    try:
        initialize_model(model_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to initialize model: {str(e)}")
    return {"success": True, "message": "Model initialized successfully"}
@app.post("/convert/markdown", response_model=ConversionResponse)
async def convert_to_markdown_sync(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling."),
):
    """Synchronous endpoint to convert document to markdown."""
    # Flow: read the upload, run the model inline, then export the DocTags
    # output to markdown. An export failure is reported in the response body
    # (success=False) rather than as an HTTP error; everything else is a 500.
    try:
        # Read and process image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        # Generate response
        raw_output = generate_response(prompt, image)
        # Convert to markdown
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            markdown_output = doc.export_to_markdown()
            return ConversionResponse(
                success=True,
                format="markdown",
                content=markdown_output,
                raw_output=raw_output,
                message="Successfully converted to markdown",
            )
        except Exception as e:
            return ConversionResponse(
                success=False,
                format="markdown",
                content="",
                raw_output=raw_output,
                message=f"Error converting to markdown: {str(e)}",
            )
    except HTTPException:
        # generate_response already raised a well-formed HTTP error; wrapping
        # it again would replace its detail with str() of the exception.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/markdown/async", response_model=ConversionResponse)
async def convert_to_markdown_async_endpoint(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling."),
):
    """Asynchronous endpoint to convert document to markdown."""
    # Flow: read the upload, await the model via generate_response_async, then
    # export the DocTags output to markdown. An export failure is reported in
    # the response body (success=False); everything else becomes a 500.
    try:
        # Read and process image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        # Generate response asynchronously
        raw_output = await generate_response_async(prompt, image)
        # Convert to markdown
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            markdown_output = doc.export_to_markdown()
            return ConversionResponse(
                success=True,
                format="markdown",
                content=markdown_output,
                raw_output=raw_output,
                message="Successfully converted to markdown",
            )
        except Exception as e:
            return ConversionResponse(
                success=False,
                format="markdown",
                content="",
                raw_output=raw_output,
                message=f"Error converting to markdown: {str(e)}",
            )
    except HTTPException:
        # Keep the status and detail chosen by generate_response instead of
        # re-wrapping them as str() of the exception.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/doctags", response_model=ConversionResponse)
async def convert_to_doctags(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling."),
):
    """Synchronous endpoint to convert document to doctags."""
    # The cleaned model output is returned as both content and raw_output;
    # no DocTags parsing happens on this route.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        raw_output = generate_response(prompt, image)
        return ConversionResponse(
            success=True,
            format="doctags",
            content=raw_output,
            raw_output=raw_output,
            message="Successfully generated doctags",
        )
    except HTTPException:
        # Pass generate_response's HTTP error through with its detail intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/json")
async def convert_to_json(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling."),
):
    """Synchronous endpoint to convert document to JSON dict."""
    # Flow: read the upload, run the model inline, then export the parsed
    # document with export_to_dict(). An export failure is reported in the
    # body (success=False, empty content); everything else becomes a 500.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        raw_output = generate_response(prompt, image)
        # Try to convert to DoclingDocument and export to dict
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            json_output = doc.export_to_dict()
            return {
                "success": True,
                "format": "json",
                "content": json_output,
                "raw_output": raw_output,
                "message": "Successfully converted to JSON",
            }
        except Exception as e:
            return {
                "success": False,
                "format": "json",
                "content": {},
                "raw_output": raw_output,
                "message": f"Error converting to JSON: {str(e)}",
            }
    except HTTPException:
        # Pass generate_response's HTTP error through with its detail intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/json/async")
async def convert_to_json_async(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling."),
):
    """Asynchronous endpoint to convert document to JSON dict."""
    # Same as the inline JSON route, except that inference is awaited through
    # generate_response_async.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        raw_output = await generate_response_async(prompt, image)
        # Try to convert to DoclingDocument and export to dict
        try:
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([raw_output], [image])
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            json_output = doc.export_to_dict()
            return {
                "success": True,
                "format": "json",
                "content": json_output,
                "raw_output": raw_output,
                "message": "Successfully converted to JSON",
            }
        except Exception as e:
            return {
                "success": False,
                "format": "json",
                "content": {},
                "raw_output": raw_output,
                "message": f"Error converting to JSON: {str(e)}",
            }
    except HTTPException:
        # Pass generate_response's HTTP error through with its detail intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/convert/with-boxes", response_model=BoundingBoxResponse)
async def convert_with_bounding_boxes(
    file: UploadFile = File(...),
    prompt: str = Form(default="Convert this page to docling."),
):
    """Convert document and return with bounding boxes visualization."""
    # The annotated PNG (base64) is attached only when the model output
    # contains a group of four consecutive <loc_N> tags.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        raw_output = generate_response(prompt, image)
        # Check for location tags
        has_doctag = "<doctag>" in raw_output
        class_loc_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        # Only presence matters here, so stop at the first hit.
        has_loc_tags = bool(
            re.search(class_loc_pattern, raw_output) or re.search(loc_only_pattern, raw_output)
        )
        annotated_image_b64 = None
        if has_loc_tags:
            # Draw on a copy so that the uploaded image object stays untouched.
            annotated_image = draw_bounding_boxes(image.copy(), raw_output, has_doctag)
            annotated_image_b64 = image_to_base64(annotated_image)
        return BoundingBoxResponse(
            success=True,
            content=raw_output,
            annotated_image=annotated_image_b64,
            has_bounding_boxes=has_loc_tags,
        )
    except HTTPException:
        # Pass generate_response's HTTP error through with its detail intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/query")
async def query_document(
    file: UploadFile = File(...),
    question: str = Form(...),
):
    """General purpose endpoint to query a document with any question."""
    # Returns the question together with the cleaned model answer.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        response = generate_response(question, image)
        return {
            "success": True,
            "question": question,
            "answer": response,
        }
    except HTTPException:
        # Pass generate_response's HTTP error through with its detail intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/query/async")
async def query_document_async(
    file: UploadFile = File(...),
    question: str = Form(...),
):
    """Async general purpose endpoint to query a document."""
    # Same as /query, except that inference is awaited through
    # generate_response_async.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))
        response = await generate_response_async(question, image)
        return {
            "success": True,
            "question": question,
            "answer": response,
        }
    except HTTPException:
        # Pass generate_response's HTTP error through with its detail intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_response_streaming(question: str, image: Image.Image):
    """Yield the model's answer incrementally as Server-Sent Events.

    Generation runs on a background thread and writes into a
    TextIteratorStreamer. Each text chunk is yielded as one SSE event, and a
    final ``data: [DONE]`` event marks the end. Errors are reported in-band as
    ``data: Error: ...`` events.
    """
    if model is None or processor is None:
        # Framed like every other event so that SSE clients can parse it.
        yield "data: Error: Model not initialized\n\n"
        return
    try:
        from transformers import TextIteratorStreamer
        from threading import Thread

        image = image.convert("RGB")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        temperature = 0.0
        inputs = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
        generation_args = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=4096,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
        thread = Thread(target=model.generate, kwargs=generation_args)
        thread.start()

        # Same marker tokens that clean_model_response removes. They are
        # stripped per chunk WITHOUT trimming whitespace: trimming would glue
        # words from neighbouring chunks together, and a chunk made only of
        # markers must be dropped rather than replaced by a placeholder.
        special_tokens = (
            "<|end_of_text|>",
            "<|end|>",
            "<|assistant|>",
            "<|user|>",
            "<|system|>",
            "<pad>",
            "</s>",
            "<s>",
        )
        loop = asyncio.get_running_loop()
        chunks = iter(streamer)
        exhausted = object()
        while True:
            # next() blocks until the generation thread produces text, so it
            # is awaited on a worker thread to keep the event loop responsive.
            # The default argument avoids raising StopIteration into a future.
            new_text = await loop.run_in_executor(None, next, chunks, exhausted)
            if new_text is exhausted:
                break
            for token in special_tokens:
                new_text = new_text.replace(token, "")
            if not new_text:
                continue
            # An SSE payload containing newlines needs one "data:" field per line.
            yield "".join(f"data: {line}\n" for line in new_text.split("\n")) + "\n"
        thread.join()
        yield "data: [DONE]\n\n"
    except Exception as e:
        yield f"data: Error: {str(e)}\n\n"
@app.post("/query/stream")
async def query_document_stream(
    file: UploadFile = File(...),
    question: str = Form(...),
):
    """Stream response for document query (Server-Sent Events)."""
    # Failures while reading or opening the upload become a 500. Once
    # streaming has started, errors are reported inside the event stream.
    try:
        payload = await file.read()
        page = Image.open(io.BytesIO(payload))
        event_source = generate_response_streaming(question, page)
        return StreamingResponse(event_source, media_type="text/event-stream")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import os

    import uvicorn

    # Load the model up front when MODEL_PATH is set; otherwise leave it to
    # the /initialize endpoint.
    model_path = os.getenv("MODEL_PATH")
    if not model_path:
        print("MODEL_PATH not set. Use /initialize endpoint to load model.")
    else:
        print(f"Initializing model from {model_path}")
        initialize_model(model_path)
    uvicorn.run(app, host="0.0.0.0", port=8000)
import html
import os
import random
import re
import time
from pathlib import Path
from threading import Thread
import numpy as np
import torch
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from PIL import Image, ImageDraw, ImageOps
from transformers import (
AutoProcessor,
Idefics3ForConditionalGeneration,
TextIteratorStreamer,
)
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
import io
import base64
# Device preference: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Replace with your actual local path to the model
local_path = "./granite-docling-258M"
processor = AutoProcessor.from_pretrained(local_path)
model = Idefics3ForConditionalGeneration.from_pretrained(
    local_path,
    device_map=device,
    torch_dtype=torch.bfloat16,
)
# Without CUDA the model is additionally moved to the chosen device.
if device.type != "cuda":
    model = model.to(device)
def add_random_padding(image: Image.Image, min_percent: float = 0.1, max_percent: float = 0.10) -> Image.Image:
    """Return an RGB copy of *image* surrounded by a border.

    The horizontal and vertical border sizes are each a fraction of the image
    width or height, drawn uniformly from ``[min_percent, max_percent]``. The
    border is filled with the colour of the top-left pixel.
    """
    rgb = image.convert("RGB")
    width, height = rgb.size
    # Two independent draws: the first for the sides, the second for top and bottom.
    horizontal = int(width * random.uniform(min_percent, max_percent))
    vertical = int(height * random.uniform(min_percent, max_percent))
    border_colour = rgb.getpixel((0, 0))
    return ImageOps.expand(
        rgb,
        border=(horizontal, vertical, horizontal, vertical),
        fill=border_colour,
    )
def draw_bounding_boxes(image: Image.Image, response_text: str, is_doctag_response: bool = False) -> Image.Image:
    """Draw bounding boxes on the image based on loc tags and return the annotated image.

    Args:
        image: Source image; an RGB copy is drawn on and returned.
        response_text: Model output containing ``<loc_N>`` tag groups.
        is_doctag_response: When True, boxes are coloured by class name and
            class-less loc groups are ignored. When False, class-tagged boxes
            use one fixed colour and class-less groups are drawn in grey.

    Returns:
        The annotated image. If anything raises, the image is returned in
        whatever state it had reached (best effort, no exception propagates).
    """
    try:
        image = image.convert("RGB")
        draw = ImageDraw.Draw(image)
        # Get image dimensions
        width, height = image.size
        # Color mapping for different classes (RGB values converted to hex)
        class_colors = {
            "caption": "#FFCC99",  # (255, 204, 153)
            "footnote": "#C8C8FF",  # (200, 200, 255)
            "formula": "#C0C0C0",  # (192, 192, 192)
            "list_item": "#9999FF",  # (153, 153, 255)
            "page_footer": "#CCFFCC",  # (204, 255, 204)
            "page_header": "#CCFFCC",  # (204, 255, 204)
            "picture": "#FFCCA4",  # (255, 204, 164)
            "chart": "#FFCCA4",  # (255, 204, 164)
            "section_header": "#FF9999",  # (255, 153, 153)
            "table": "#FFCCCC",  # (255, 204, 204)
            "text": "#FFFF99",  # (255, 255, 153)
            "title": "#FF9999",  # (255, 153, 153)
            "document_index": "#DCDCDC",  # (220, 220, 220)
            "code": "#7D7D7D",  # (125, 125, 125)
            "checkbox_selected": "#FFB6C1",  # (255, 182, 193)
            "checkbox_unselected": "#FFB6C1",  # (255, 182, 193)
            "form": "#C8FFFF",  # (200, 255, 255)
            "key_value_region": "#B7410E",  # (183, 65, 14)
            "paragraph": "#FFFF99",  # (255, 255, 153)
            "reference": "#B0E0E6",  # (176, 224, 230)
            "grading_scale": "#FFCCCC",  # (255, 204, 204)
            "handwritten_text": "#CCFFCC",  # (204, 255, 204)
            "empty_value": "#DCDCDC",  # (220, 220, 220)
        }

        def to_pixels(xmin, ymin, xmax, ymax):
            # Each loc value is divided by 500 and scaled to the image size.
            return [
                int((int(xmin) / 500) * width),
                int((int(ymin) / 500) * height),
                int((int(xmax) / 500) * width),
                int((int(ymax) / 500) * height),
            ]

        doctag_class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>[^<]*</[^>]+>"
        doctag_matches = re.findall(doctag_class_pattern, response_text)
        class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        class_matches = re.findall(class_pattern, response_text)
        # Closed-tag matches come first; duplicates by coordinates are dropped.
        seen_coords = set()
        all_class_matches = []
        for match in doctag_matches + class_matches:
            coords = match[1:5]
            if coords not in seen_coords:
                seen_coords.add(coords)
                all_class_matches.append(match)
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        loc_only_matches = re.findall(loc_only_pattern, response_text)
        for class_name, xmin, ymin, xmax, ymax in all_class_matches:
            if is_doctag_response:
                lowered = class_name.lower()
                color = class_colors.get(lowered)
                if color is None:
                    # Fall back to the first known class whose name contains,
                    # or is contained in, the tag name; grey if none matches.
                    color = next(
                        (value for key, value in class_colors.items() if lowered in key or key in lowered),
                        "#808080",
                    )
            else:
                color = "#E0115F"
            draw.rectangle(to_pixels(xmin, ymin, xmax, ymax), outline=color, width=3)
        if not is_doctag_response:
            for coords in loc_only_matches:
                # A class-tagged box also matches the loc-only pattern.
                # Skipping it keeps its colour from being painted over in grey.
                if coords in seen_coords:
                    continue
                draw.rectangle(to_pixels(*coords), outline="#808080", width=3)
        return image
    except Exception:
        return image
def clean_model_response(text: str) -> str:
    """Strip special tokens and surrounding whitespace from model output.

    A falsy *text* yields a fixed "no response" sentence. Text that is empty
    after cleaning yields a different fixed sentence.
    """
    if text:
        unwanted = (
            "<|end_of_text|>",
            "<|end|>",
            "<|assistant|>",
            "<|user|>",
            "<|system|>",
            "<pad>",
            "</s>",
            "<s>",
        )
        result = text
        # Tokens are removed sequentially, in the order listed.
        for piece in unwanted:
            result = result.replace(piece, "")
        result = result.strip()
        if result:
            return result
        return "The model generated a response, but it appears to be empty or contain only special tokens."
    return "No response generated."
def generate_with_model(question: str, image: Image.Image, apply_padding: bool = False) -> str:
    """Answer *question* about *image* with the module-level model.

    When the NO_LLM environment variable is set, the model is skipped: the
    function sleeps for two seconds and returns a canned sentence. With
    *apply_padding* the image is first passed through add_random_padding.
    Failures are not raised; they come back as an "Error processing image"
    string.
    """
    if os.environ.get("NO_LLM"):
        time.sleep(2)
        return "This is a simulated response from the Granite Docling model."
    try:
        source_image = add_random_padding(image) if apply_padding else image
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = processor.apply_chat_template(chat, add_generation_prompt=True)
        temperature = 0.0
        encoded = processor(text=prompt, images=[source_image], return_tensors="pt")
        inputs = {}
        for name, tensor in encoded.items():
            inputs[name] = tensor.to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=4096,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=processor.tokenizer.eos_token_id,
            )
        # Decode only the tokens that follow the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = processor.batch_decode(
            generated_ids[:, prompt_len:],
            skip_special_tokens=False,
        )
        return clean_model_response(decoded[0])
    except Exception as e:
        return f"Error processing image: {e!s}"
def pil_to_bytes(image: Image.Image) -> bytes:
    """Serialise *image* as PNG and return the encoded bytes."""
    with io.BytesIO() as sink:
        image.save(sink, format="PNG")
        return sink.getvalue()
# FastAPI application object; the /generate/sync and /generate/async routes are registered on it.
app = FastAPI()
@app.post("/generate/sync")
def generate_sync(
    prompt: str = Form(...),
    image: UploadFile = File(...),
    apply_padding: bool = Form(False)
):
    # Runs the model on the uploaded image and returns the raw answer. For
    # docling/OTSL or formula prompts it also returns markdown and JSON
    # exports, and when the answer contains loc tags it returns an annotated
    # PNG (base64). Every failure is reported as an {"error": ...} JSON body.
    try:
        img = Image.open(io.BytesIO(image.file.read())).convert("RGB")
        answer = generate_with_model(prompt, img, apply_padding)
        markdown_output = None
        json_output = None
        annotated_image_base64 = None
        class_loc_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        has_loc_tags = re.findall(class_loc_pattern, answer) or re.findall(loc_only_pattern, answer)
        has_doctag = "<doctag>" in answer
        lowered = prompt.lower()
        page_request = "convert this page to docling" in lowered or ("convert" in lowered and "otsl" in lowered)
        formula_request = not page_request and "convert formula to latex" in lowered
        if page_request or formula_request:
            try:
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([answer], [img])
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                if formula_request and markdown_output.count("$$") >= 2:
                    # Wrap the first $$...$$ body in an aligned environment.
                    before, body, after = markdown_output.split("$$", 2)
                    formula = body.strip()
                    wrapped = f"$$\n\\begin{{aligned}}\n{formula}\n\\end{{aligned}}\n$$"
                    markdown_output = before + wrapped + after
                # The JSON export is skipped when the document has no model_dump_json.
                try:
                    json_output = doc.model_dump_json()
                except AttributeError:
                    json_output = None
            except Exception as e:
                if page_request:
                    return JSONResponse({"error": f"Error creating markdown output: {str(e)}"})
                return JSONResponse({"error": f"Error creating LaTeX output: {str(e)}"})
        if has_loc_tags:
            # Annotation is best effort: a failure leaves the field as None.
            try:
                annotated = draw_bounding_boxes(img, answer, is_doctag_response=has_doctag)
                annotated_image_base64 = base64.b64encode(pil_to_bytes(annotated)).decode("utf-8")
            except Exception:
                pass
        return JSONResponse({
            "response": answer,
            "markdown": markdown_output,
            "json": json_output,
            "annotated_image_base64": annotated_image_base64
        })
    except Exception as e:
        return JSONResponse({"error": str(e)})
@app.post("/generate/async")
async def generate_async(
    prompt: str = Form(...),
    image: UploadFile = File(...),
    apply_padding: bool = Form(False)
):
    # Async counterpart of /generate/sync: same request fields and the same
    # response shape. The blocking model call is awaited on a worker thread,
    # so the event loop stays free for other requests during inference. Every
    # failure is reported as an {"error": ...} JSON body.
    import asyncio

    try:
        image_bytes = await image.read()
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        loop = asyncio.get_running_loop()
        answer = await loop.run_in_executor(None, generate_with_model, prompt, img, apply_padding)
        markdown_output = None
        json_output = None
        annotated_image_base64 = None
        class_loc_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        class_loc_matches = re.findall(class_loc_pattern, answer)
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        loc_only_matches = re.findall(loc_only_pattern, answer)
        has_doctag = "<doctag>" in answer
        has_loc_tags = class_loc_matches or loc_only_matches
        prompt_lower = prompt.lower()
        if "convert this page to docling" in prompt_lower or ("convert" in prompt_lower and "otsl" in prompt_lower):
            try:
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([answer], [img])
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                # The JSON export is skipped when the document has no model_dump_json.
                try:
                    json_output = doc.model_dump_json()
                except AttributeError:
                    json_output = None
            except Exception as e:
                return JSONResponse({"error": f"Error creating markdown output: {str(e)}"})
        elif "convert formula to latex" in prompt_lower:
            try:
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([answer], [img])
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                if markdown_output.count("$$") >= 2:
                    # Wrap the first $$...$$ body in an aligned environment.
                    parts = markdown_output.split("$$", 2)
                    formula = parts[1].strip()
                    wrapped = f"$$\n\\begin{{aligned}}\n{formula}\n\\end{{aligned}}\n$$"
                    markdown_output = parts[0] + wrapped + parts[2]
                try:
                    json_output = doc.model_dump_json()
                except AttributeError:
                    json_output = None
            except Exception as e:
                return JSONResponse({"error": f"Error creating LaTeX output: {str(e)}"})
        if has_loc_tags:
            # Annotation is best effort: a failure leaves the field as None.
            try:
                annotated_image = draw_bounding_boxes(img, answer, is_doctag_response=has_doctag)
                annotated_image_base64 = base64.b64encode(pil_to_bytes(annotated_image)).decode("utf-8")
            except Exception:
                pass
        return JSONResponse({
            "response": answer,
            "markdown": markdown_output,
            "json": json_output,
            "annotated_image_base64": annotated_image_base64
        })
    except Exception as e:
        return JSONResponse({"error": str(e)})
if __name__ == "__main__":
    # Serve the app on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
import torch
import time
from datetime import timedelta
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.image_utils import load_image
from pathlib import Path
# Load model and tokenizer, preferring CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"USING DEVICE: {DEVICE}")
start_time = time.time()


def _elapsed_seconds():
    # Seconds since start_time, passed through timedelta as in every timing line below.
    return timedelta(seconds=time.time() - start_time).total_seconds()


model_path = "/Users/ghart/models/ibm-granite/granite-docling-258M-text-only"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(DEVICE)
print(f"==> Done loading model: {_elapsed_seconds()}s")
# Prepare inputs
# NOTE(review): the image is downloaded but never passed to the tokenizer or model below.
image = load_image("https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png")
messages = [
    {
        "role": "user",
        "content": [
            # {"type": "image"},
            {"type": "text", "text": "Convert this page to docling."}
        ],
    },
]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
print(f"==> Done preparing inputs: {_elapsed_seconds()}s")
# Generate outputs
generated_ids = model.generate(**inputs, max_new_tokens=512, use_cache=True)
print(f"==> Done generating: {_elapsed_seconds()}s")
# Keep only the tokens produced after the prompt, then decode them.
prompt_length = inputs.input_ids.shape[1]
trimmed_generated_ids = generated_ids[:, prompt_length:]
decoded = tokenizer.batch_decode(trimmed_generated_ids, skip_special_tokens=False)
result = decoded[0].lstrip()
print(f"==> Done: {_elapsed_seconds()}s")
"""
# config.json
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 100264,
"dtype": "bfloat16",
"eos_token_id": 100257,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 576,
"initializer_range": 0.02,
"intermediate_size": 1536,
"max_position_embeddings": 8192,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 9,
"num_hidden_layers": 30,
"num_key_value_heads": 3,
"pad_token_id": 100257,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 100000.0,
"tie_word_embeddings": true,
"use_cache": false,
"vocab_size": 100352
}
"""
import torch
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers.image_utils import load_image
from pathlib import Path
# Load model and processor
MODEL_ID = "ibm-granite/granite-docling-258M"
IMAGE_URL = "https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
# Prepare inputs: a single user turn holding the page image and the instruction.
image = load_image(IMAGE_URL)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Convert this page to docling."},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)
# Generate outputs, then decode only the tokens that follow the prompt.
generated_ids = model.generate(**inputs, max_new_tokens=8192)
prompt_length = inputs.input_ids.shape[1]
trimmed_generated_ids = generated_ids[:, prompt_length:]
decoded = processor.batch_decode(trimmed_generated_ids, skip_special_tokens=False)
doctags = decoded[0].lstrip()
print(f"DocTags: \n{doctags}\n")
# Populate document
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
# create a docling document
doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
print(f"Markdown:\n{doc.export_to_markdown()}\n")
## export as any format.
# Path("out/").mkdir(parents=True, exist_ok=True)
# HTML:
# output_path_html = Path("out/") / "example.html"
# doc.save_as_html(output_path_html)
# Markdown:
# output_path_md = Path("out/") / "example.md"
# doc.save_as_markdown(output_path_md)
from IPython.display import Markdown, display
# DocTagsDocument is used below, so it is imported together with DoclingDocument.
from docling_core.types.doc.document import DocTagsDocument, DoclingDocument

# NOTE(review): `doc_tags` and `image` are not defined in this snippet;
# presumably they come from earlier notebook cells. Confirm before running it
# on its own.
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doc_tags], [image])
document = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
extracted_text_markdown = document.export_to_markdown()
# Render the exported markdown as rich output in the notebook.
display(Markdown(extracted_text_markdown))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment