Last active
April 3, 2025 09:14
-
-
Save ShivnarenSrinivasan/01dc2987dc8dda6f2ffcfabdb3dc68ea to your computer and use it in GitHub Desktop.
Gemini + PDF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Per-page PDF analysis with Gemini on Vertex AI: each page is extracted as a
# standalone one-page PDF, base64-encoded, and sent to the chat model together
# with a system/user prompt pair.
import asyncio
import base64
import dotenv
import pymupdf
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_google_vertexai import ChatVertexAI
# Load environment variables (presumably Google Cloud credentials — confirm in
# the .env) before any model use; override=True lets the .env file win over
# variables already present in the process environment.
dotenv.load_dotenv(override=True)
async def run_llm_on_pdf_pages(
    pdf: pymupdf.Document,
    llm: ChatVertexAI,
    system_prompt: str,
    user_prompt: str,
) -> dict[int, BaseMessage]:
    """Run the LLM over every page of *pdf* concurrently.

    Returns:
        Mapping of 0-based page number to the model's response for that page.
    """
    responses = await asyncio.gather(
        *(
            run_llm_on_pdf_page(
                pdf=pdf,
                page_number=page,
                llm=llm,
                system_prompt=system_prompt,
                user_prompt=user_prompt,
            )
            for page in range(len(pdf))
        )
    )
    # gather preserves argument order, so responses[i] is the reply for page i.
    return dict(enumerate(responses))
async def run_llm_on_pdf_page(
    pdf: pymupdf.Document,
    page_number: int,
    llm: ChatVertexAI,
    system_prompt: str,
    user_prompt: str,
) -> BaseMessage:
    """Extract one page of *pdf*, base64-encode it, and ask the LLM about it."""
    raw_page = await extract_page_as_bytes(pdf, page_number)
    encoded = base64.b64encode(raw_page).decode('utf-8')
    return await run_llm_on_pdf_base64(encoded, llm, system_prompt, user_prompt)
async def run_llm_on_pdf_base64(
    pdf_base64: str,
    llm: ChatVertexAI,
    system_prompt: str,
    user_prompt: str,
) -> BaseMessage:
    """Send a base64-encoded PDF plus the prompts to *llm* and return its reply."""
    # Multimodal part carrying the PDF; the user prompt rides in the same
    # HumanMessage after the media part.
    pdf_part = {
        'type': 'media',
        'data': pdf_base64,
        'mime_type': 'application/pdf',
    }
    human = HumanMessage([pdf_part, user_prompt])
    system = SystemMessage(content=system_prompt)
    return await llm.ainvoke([system, human])
async def extract_page_as_bytes(pdf: pymupdf.Document, page_number: int) -> bytes:
    """Serialize a single page of *pdf* as a standalone one-page PDF.

    Args:
        pdf: Source document; read from, not modified.
        page_number: 0-based index of the page to extract.

    Returns:
        The bytes of a new PDF document containing only that page.

    NOTE(review): despite being async this does no awaiting — the PyMuPDF work
    is blocking. Fine for small pages; consider ``asyncio.to_thread`` if pages
    turn out to be large.
    """
    single_page = pymupdf.open()
    try:
        single_page.insert_pdf(pdf, from_page=page_number, to_page=page_number)
        return single_page.write()
    finally:
        # Always release the temporary document, even if insert_pdf/write
        # raises (the original leaked it on error).
        single_page.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import dataclasses
from collections.abc import Sequence

import jinja2
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
# Prompt used by `summarize_results`: the template receives `n_results` (how
# many results are being summarized) and `combined_input` (the pre-joined
# result text) and asks the model for a structured, consolidated summary.
_SUMMARIZER_PROMPT = """
The following are {{ n_results }} individual string results from a process:
{{ combined_input }}
Please provide a comprehensive summary that:
1. Identifies the main themes and patterns across all results
2. Highlights any significant outliers or contradictions
3. Consolidates similar findings
4. Presents the most important insights in a clear, structured way
Your summary should be concise but thorough, capturing the essence of all results.
"""
# Compiled once at import time so repeated summarization calls reuse it.
_SUMMARIZER_TEMPLATE = jinja2.Template(_SUMMARIZER_PROMPT)
@dataclasses.dataclass(frozen=True)
class Result:
    """One per-page result, tagged with its source location for provenance."""

    # Name of the source file the text came from.
    filename: str
    # Page index within that file — presumably 0-based; confirm with producer.
    page_number: int
    # The result text produced for that page.
    text: str
async def summarize_large_result_set(
    results: Sequence[Result] | Sequence[str],
    llm: BaseChatModel,
    batch_size: int = 25,
    max_tokens: int = 1500,
    depth: int = 0,
    max_summaries_per_level: int = 20,
) -> str:
    """
    Recursively summarize very large sets of results by processing in batches and then
    summarizing the summaries until a single coherent summary is produced.

    Args:
        results: Results (strings or ``Result`` records) to summarize.
        llm: LangChain BaseChatModel instance.
        batch_size: Number of results to process in each batch at the leaf level.
        max_tokens: Forwarded through recursive calls only; not currently
            applied to the LLM invocation itself.
        depth: Current recursion depth (used internally for summary labels).
        max_summaries_per_level: Maximum number of summaries that can be
            processed at once before another level of recursion is added.

    Returns:
        A consolidated summary of all input results.
    """
    # Handle empty result sets up front.
    if not results:
        return 'No results to summarize.'
    # Base case: few enough results to summarize directly.
    if len(results) <= batch_size:
        return await summarize_results(results, llm)

    # Split into batches and summarize them concurrently — the per-batch LLM
    # calls are independent (this resolves the old "run this loop async" TODO).
    batches = [results[i : i + batch_size] for i in range(0, len(results), batch_size)]
    summaries = await asyncio.gather(
        *(summarize_results(batch, llm) for batch in batches)
    )

    # Label format depends only on depth, so compute the prefix once.
    level_prefix = f'Level {depth} - ' if depth > 0 else ''
    batch_summaries = [
        f'{level_prefix}Summary of results '
        f'{idx * batch_size + 1}-{idx * batch_size + len(batch)}: {summary}'
        for idx, (batch, summary) in enumerate(zip(batches, summaries))
    ]

    # Recursive case: too many summaries to merge in one shot — recurse.
    if len(batch_summaries) > max_summaries_per_level:
        return await summarize_large_result_set(
            batch_summaries,
            llm,
            batch_size=max_summaries_per_level,
            max_tokens=max_tokens,
            depth=depth + 1,
            max_summaries_per_level=max_summaries_per_level,
        )

    # Final case: one summary over the current batch of summaries.
    return await summarize_results(batch_summaries, llm)
async def summarize_results(
    results: Sequence[Result] | Sequence[str],
    llm: BaseChatModel,
    user_template: jinja2.Template = _SUMMARIZER_TEMPLATE,
) -> str:
    """Summarize and consolidate multiple results using an LLM.

    Args:
        results: Plain strings, or ``Result`` records which are rendered with
            their filename/page provenance.
        llm: Chat model used to produce the summary.
        user_template: Jinja2 template receiving ``n_results`` and
            ``combined_input``.

    Returns:
        A consolidated summary of the inputs.
    """
    if not results:
        # Guard: the original indexed results[0] and raised IndexError on [].
        # Message matches the caller's empty-set convention.
        return 'No results to summarize.'
    if isinstance(results[0], str):
        combined_input = '\n\n'.join(results)
    else:
        combined_input = '\n\n'.join(
            f'Result {i + 1}: {result.filename = }, {result.page_number = }\n text: {result.text}'
            for i, result in enumerate(results)
        )
    prompt = user_template.render(
        n_results=len(results),
        combined_input=combined_input,
    )
    return await _summarize(prompt, llm)
async def _summarize(prompt: str, llm: BaseChatModel) -> str:
    """Send *prompt* to *llm* under a fixed summarizer system message and return the text."""
    system = SystemMessage(
        content='You are a helpful assistant that summarizes multiple results into a consolidated insight.'
    )
    reply = await llm.ainvoke([system, HumanMessage(content=prompt)])
    return reply.content
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment