@ShivnarenSrinivasan
Last active April 3, 2025 09:14
Gemini + PDF
import asyncio
import base64

import dotenv
import pymupdf
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_google_vertexai import ChatVertexAI

dotenv.load_dotenv(override=True)
async def run_llm_on_pdf_pages(
    pdf: pymupdf.Document,
    llm: ChatVertexAI,
    system_prompt: str,
    user_prompt: str,
) -> dict[int, BaseMessage]:
    """Run the LLM on every page of the PDF concurrently, keyed by page number."""
    tasks = [
        run_llm_on_pdf_page(
            pdf=pdf,
            page_number=page_number,
            llm=llm,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
        for page_number in range(len(pdf))
    ]
    _results = await asyncio.gather(*tasks)
    results = {page: r for page, r in zip(range(len(pdf)), _results)}
    return results
async def run_llm_on_pdf_page(
    pdf: pymupdf.Document,
    page_number: int,
    llm: ChatVertexAI,
    system_prompt: str,
    user_prompt: str,
) -> BaseMessage:
    """Extract a single page as a standalone PDF and send it to the LLM."""
    pdf_page = await extract_page_as_bytes(pdf, page_number)
    pdf_base64 = base64.b64encode(pdf_page).decode('utf-8')
    result = await run_llm_on_pdf_base64(pdf_base64, llm, system_prompt, user_prompt)
    return result
async def run_llm_on_pdf_base64(
    pdf_base64: str,
    llm: ChatVertexAI,
    system_prompt: str,
    user_prompt: str,
) -> BaseMessage:
    """Send a base64-encoded PDF to the LLM as an inline media content part."""
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(
            [
                # Inline PDF passed as a 'media' content block, alongside the text prompt.
                {
                    'type': 'media',
                    'data': pdf_base64,
                    'mime_type': 'application/pdf',
                },
                user_prompt,
            ]
        ),
    ]
    result = await llm.ainvoke(messages)
    return result
async def extract_page_as_bytes(pdf: pymupdf.Document, page_number: int) -> bytes:
    """Copy one page into a new single-page PDF and return its bytes."""
    new = pymupdf.open()
    new.insert_pdf(pdf, from_page=page_number, to_page=page_number)
    _bytes = new.write()
    new.close()
    return _bytes
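
A minimal usage sketch for the page-wise pipeline above. The PDF path, model name, and prompts are placeholders, and it assumes Vertex AI credentials are available through the environment loaded by dotenv:

async def _demo_pdf() -> None:
    # Placeholder inputs; adjust the file path, model, and prompts to your setup.
    pdf = pymupdf.open('report.pdf')
    llm = ChatVertexAI(model_name='gemini-1.5-pro')
    results = await run_llm_on_pdf_pages(
        pdf=pdf,
        llm=llm,
        system_prompt='You are a careful document analyst.',
        user_prompt='Summarize this page in two sentences.',
    )
    for page_number, message in sorted(results.items()):
        print(page_number, message.content)


if __name__ == '__main__':
    asyncio.run(_demo_pdf())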
import dataclasses
from collections.abc import Sequence

import jinja2
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage

_SUMMARIZER_PROMPT = """
The following are {{ n_results }} individual string results from a process:
{{ combined_input }}
Please provide a comprehensive summary that:
1. Identifies the main themes and patterns across all results
2. Highlights any significant outliers or contradictions
3. Consolidates similar findings
4. Presents the most important insights in a clear, structured way
Your summary should be concise but thorough, capturing the essence of all results.
"""

_SUMMARIZER_TEMPLATE = jinja2.Template(_SUMMARIZER_PROMPT)
@dataclasses.dataclass(frozen=True)
class Result:
    filename: str
    page_number: int
    text: str
async def summarize_large_result_set(
    results: Sequence[Result] | Sequence[str],
    llm: BaseChatModel,
    batch_size: int = 25,
    max_tokens: int = 1500,
    depth: int = 0,
    max_summaries_per_level: int = 20,
) -> str:
    """Recursively summarize a very large set of results.

    Results are summarized in batches, and the batch summaries are themselves
    summarized, level by level, until a single coherent summary is produced.

    Args:
        results: Sequence of results (strings or `Result` records) to summarize
        llm: LangChain `BaseChatModel` instance
        batch_size: Number of results to summarize in each batch at the leaf level
        max_tokens: Currently unused; only forwarded through recursive calls
        depth: Current recursion depth (used internally)
        max_summaries_per_level: Maximum number of summaries that can be processed at once

    Returns:
        A consolidated summary of all input results
    """
    # Handle empty or small result sets
    if not results:
        return 'No results to summarize.'

    # Base case: if we have few enough results to summarize directly
    if len(results) <= batch_size:
        return await summarize_results(results, llm)

    # TODO: run this loop async
    # Process in batches
    batch_summaries = []
    for i in range(0, len(results), batch_size):
        batch = results[i : i + batch_size]
        # Format identifier changes based on depth
        level_prefix = f'Level {depth} - ' if depth > 0 else ''
        batch_summary = await summarize_results(batch, llm)
        batch_summaries.append(
            f'{level_prefix}Summary of results {i + 1}-{i + len(batch)}: {batch_summary}'
        )

    # Recursive case: if we have too many summaries to process at once, recurse
    if len(batch_summaries) > max_summaries_per_level:
        return await summarize_large_result_set(
            batch_summaries,
            llm,
            batch_size=max_summaries_per_level,
            max_tokens=max_tokens,
            depth=depth + 1,
            max_summaries_per_level=max_summaries_per_level,
        )

    # Final case: we can create one summary from our current batch of summaries
    final_summary = await summarize_results(batch_summaries, llm)
    return final_summary
async def summarize_results(
    results: Sequence[Result] | Sequence[str],
    llm: BaseChatModel,
    user_template: jinja2.Template = _SUMMARIZER_TEMPLATE,
) -> str:
    """Summarize and consolidate multiple string results using an LLM.

    Returns:
        A consolidated summary of the inputs
    """
    if isinstance(results[0], str):
        combined_input = '\n\n'.join(results)
    else:
        combined_input = '\n\n'.join(
            f'Result {i + 1}: {result.filename = }, {result.page_number = }\n text: {result.text}'
            for i, result in enumerate(results)
        )

    prompt = user_template.render(
        n_results=len(results),
        combined_input=combined_input,
    )
    return await _summarize(prompt, llm)
async def _summarize(prompt: str, llm: BaseChatModel) -> str:
    messages = [
        SystemMessage(
            content='You are a helpful assistant that summarizes multiple results into a consolidated insight.'
        ),
        HumanMessage(content=prompt),
    ]
    response = await llm.ainvoke(messages)
    return response.content
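
A minimal usage sketch for the hierarchical summarizer. The `Result` records here are placeholder data (they could equally come from the per-page Gemini results in the first file), and the model choice is an assumption:

import asyncio

from langchain_google_vertexai import ChatVertexAI


async def _demo_summary() -> None:
    llm = ChatVertexAI(model_name='gemini-1.5-pro')  # any BaseChatModel works here
    # Placeholder data: 100 page-level findings for a single document.
    results = [
        Result(filename='report.pdf', page_number=i, text=f'Findings for page {i} ...')
        for i in range(100)
    ]
    # 100 results with batch_size=25 -> 4 batch summaries -> 1 final summary.
    summary = await summarize_large_result_set(results, llm, batch_size=25)
    print(summary)


if __name__ == '__main__':
    asyncio.run(_demo_summary())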