Skip to content

Instantly share code, notes, and snippets.

@andresriancho
Created March 7, 2025 12:38
Show Gist options
  • Save andresriancho/862abac95aee8309452d8a138c2044c6 to your computer and use it in GitHub Desktop.
Save andresriancho/862abac95aee8309452d8a138c2044c6 to your computer and use it in GitHub Desktop.
Compress structured prompts using LLMLingua2
import logging
import os
import re
from typing import Optional, List

from llmlingua import PromptCompressor
# Extra settings handed to PromptCompressor through its `model_config` argument (empty here).
MODEL_CONFIG = {}
# Model identifier passed as the first positional argument to PromptCompressor.
MODEL = "microsoft/llmlingua-2-xlm-roberta-large-meetingbank"
# Default `global_rate` of extract_segments(): the rate given to compressed segments
# that do not declare a `rate` attribute of their own.
RATE = 0.33
class LLMLinguaSegment:
    """A single ``<llmlingua>`` segment: its text plus resolved compression settings.

    Args:
        content: Text found between the opening and closing tag.
        rate: Rate requested for this segment. When ``None`` it falls back to
            ``global_rate`` if the segment is compressed, otherwise to ``1.0``.
        compress: Whether the segment should be compressed. ``None`` means
            "not specified" and is treated as ``True``.
        global_rate: Fallback rate for compressed segments without their own rate.

    Raises:
        ValueError: If the resolved rate is not within ``[0.0, 1.0]``.
    """

    def __init__(self, content: str, rate: Optional[float], compress: Optional[bool], global_rate: float = 1.0):
        self.content = content
        # An unspecified compress flag means "compress".
        self.compress = True if compress is None else compress

        # An explicit rate always wins; otherwise the default depends on the compress flag.
        if rate is not None:
            self.rate = rate
        elif self.compress:
            self.rate = global_rate
        else:
            self.rate = 1.0

        # Enforce both bounds promised by the error message. The negated chained
        # comparison also rejects NaN, which a plain `> 1.0` test would let through.
        if not 0.0 <= self.rate <= 1.0:
            raise ValueError(f"Invalid 'rate' value: {self.rate}. It must be between 0.0 and 1.0.")

    def __repr__(self):
        return f"LLMLinguaSegment(content={self.content!r}, rate={self.rate}, compress={self.compress})"
def extract_segments(prompt_text: str, global_rate: float = RATE) -> List[LLMLinguaSegment]:
    """Parse every ``<llmlingua ...>...</llmlingua>`` element found in ``prompt_text``.

    Each tag may carry comma-separated ``rate=<number>`` and ``compress=True|False``
    attributes in either order, e.g. ``<llmlingua, rate=0.4, compress=False>``.
    Text outside of ``<llmlingua>`` elements is not returned.

    Args:
        prompt_text: Prompt containing zero or more ``<llmlingua>`` elements.
        global_rate: Rate given to compressed segments that declare no ``rate``.

    Returns:
        One ``LLMLinguaSegment`` per matched element, in document order.

    Raises:
        ValueError: If a ``rate`` attribute is not a valid float (e.g. ``1.2.3``)
            or the resulting segment is rejected by ``LLMLinguaSegment``.
    """
    pattern = (
        r"<llmlingua"
        r"(?:\s*,?\s*rate\s*=\s*([\d\.]+))?\s*"
        r"(?:,\s*compress\s*=\s*(True|False))?\s*"
        r"(?:,\s*rate\s*=\s*([\d\.]+))?\s*"
        r"(?:,\s*compress\s*=\s*(True|False))?\s*"
        # The content may contain '<' (code, HTML, comparisons) as long as it does
        # not start another llmlingua tag; a plain `[^<]+` would silently drop
        # such segments from the result.
        r">((?:(?!</?llmlingua)[\s\S])+)</llmlingua>"
    )

    segments = []
    for first_rate, first_compress, second_rate, second_compress, content in re.findall(pattern, prompt_text):
        # Each attribute can appear in one of two positions; unmatched groups are
        # empty strings, and the first position takes precedence.
        raw_rate = first_rate or second_rate
        raw_compress = first_compress or second_compress

        rate = float(raw_rate) if raw_rate else None
        compress = (raw_compress == "True") if raw_compress else None

        segments.append(
            LLMLinguaSegment(content=content, rate=rate, compress=compress, global_rate=global_rate)
        )
    return segments
def compress_prompt(prompt: str, device: str) -> str:
    """Compress the ``<llmlingua>`` segments of ``prompt`` and join the results.

    The prompt is split with ``extract_segments``. Segments flagged
    ``compress=False`` are copied verbatim; every other segment is passed to
    ``PromptCompressor.compress_prompt`` with the segment's own rate. Only the
    text inside ``<llmlingua>`` elements ends up in the output.

    Args:
        prompt: Prompt text containing ``<llmlingua>`` elements.
        device: Value forwarded to ``PromptCompressor`` as ``device_map``.

    Returns:
        The processed segments joined with newlines (an empty string when the
        prompt contains no segments).
    """
    # The module defines no logger of its own, so bind one here.
    logger = logging.getLogger(__name__)

    # Disable parallelism to avoid some warnings in the tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # NOTE(review): a new PromptCompressor is built on every call; callers that
    # compress many prompts may want to reuse a single instance.
    llm_lingua = PromptCompressor(
        MODEL,
        device_map=device,
        use_llmlingua2=True,
        model_config=MODEL_CONFIG,
    )

    logger.debug("Compressing prompt with length: %s", len(prompt))

    compressed_prompt: List[str] = []

    for idx, llmlingua_segment in enumerate(extract_segments(prompt)):
        if not llmlingua_segment.compress:
            logger.debug("Not compressing prompt segment at index %s", idx)
            compressed_prompt.append(llmlingua_segment.content)
            continue

        logger.debug(
            "Compressing prompt segment %s with length: %s",
            idx,
            len(llmlingua_segment.content),
        )

        # Compress this segment on its own, using the segment's rate.
        result = llm_lingua.compress_prompt(
            llmlingua_segment.content,
            rate=llmlingua_segment.rate,
            force_tokens=['\n', '?'],
        )

        idx_compressed_prompt = result.pop("compressed_prompt")
        # Drop the per-chunk list so only the remaining metrics are logged; tolerate its absence.
        result.pop("compressed_prompt_list", None)

        logger.debug("Compressed prompt size at %s: %s", idx, len(idx_compressed_prompt))
        logger.debug("Compression metrics for %s: %s", idx, result)

        compressed_prompt.append(idx_compressed_prompt)

    return '\n'.join(compressed_prompt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment