#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "datasets",
#     "numpy",
#     "requests",
#     "tqdm",
#     "pyarrow"
# ]
# ///
""" | |
label_finetome.py | |
Internal Coherence Maximization (ICM) with Self-Play Response Generation | |
Combines: | |
- N-shot prompting to generate M diverse responses per prompt | |
- Saves all generated responses to a Hugging Face Dataset before labeling. | |
- Full ICM algorithm (MCMC acceptance/rejection, mutual predictability, inconsistency repair) | |
- Adaptive selection of which pairs to label based on potential information gain | |
- Aggressive duplicate prevention at multiple levels (exact, prefix, fuzzy matching) | |
- Robust "Fail, Reduce, Retry" logic for handling API context length limits. | |
VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen3-8B-Base --port 8093 --disable-log-requests --gpu-memory-utilization 0.8 | |
NCCL_P2P_DISABLE=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen3-8B-Base --port 8093 --disable-log-requests --gpu-memory-utilization 0.8 --tensor-parallel-size 8 | |
""" | |
import json
import random
import re
import math
import numpy as np
from typing import List, Dict, Tuple, Optional, Set, Iterator, Any
from dataclasses import dataclass, field
from collections import defaultdict
import requests
from datasets import load_dataset, Dataset, Features, Value, Sequence
import logging
from tqdm import tqdm
import time
import hashlib
import signal
import sys
import csv
from functools import wraps
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import os
import atexit
import uuid
import threading

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Multi-dimensional label ontology for comprehensive preference modeling
LABEL_ONTOLOGY = [
    # AXIS 0: Safety Gates
    ["policy_compliance", "harmlessness"],
    # AXIS 1: Reality Fidelity
    ["factual_accuracy", "evidence_grounding"],
    ["logical_coherence", "causal_soundness", "internal_consistency"],
    # AXIS 2: Task Intelligence
    ["task_completion", "constraint_adherence"],
    ["latent_task_identification", "pattern_generalization"],
    # AXIS 3: Information Quality
    ["information_density", "relevance_focus", "structural_clarity"],
    # AXIS 4: Epistemic Honesty
    ["uncertainty_calibration", "assumption_transparency", "perspective_awareness"],
    # AXIS 5: Cognitive Depth
    ["insight_synthesis", "problem_decomposition", "conceptual_synthesis"],
]

# Build label validation set
VALID_LABELS = set()
LABEL_TO_AXIS = {}
for axis_id, labels in enumerate(LABEL_ONTOLOGY):
    for label in labels:
        VALID_LABELS.add(label)
        LABEL_TO_AXIS[label] = axis_id
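# Note: the axis id recorded in LABEL_TO_AXIS is the index of the sub-list above (0-7),
# which is finer-grained than the numbered "AXIS" comments (Reality Fidelity and Task
# Intelligence each span two sub-lists). Every label has a matching rubric below.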
# Dimension rubrics | |
DIMENSION_RUBRICS = { | |
"policy_compliance": "adherence to content policies without violations", | |
"harmlessness": "absence of potential for real-world harm", | |
"factual_accuracy": "correctness of factual claims", | |
"evidence_grounding": "support by evidence or sources", | |
"logical_coherence": "reasoning without fallacies or contradictions", | |
"causal_soundness": "correct cause-effect relationships", | |
"internal_consistency": "absence of self-contradictions", | |
"task_completion": "fully addressing all parts of a request", | |
"constraint_adherence": "following specified constraints and formats", | |
"latent_task_identification": "identifying underlying goals or patterns", | |
"pattern_generalization": "applying patterns to new cases", | |
"information_density": "high value per word", | |
"relevance_focus": "staying on topic without tangents", | |
"structural_clarity": "clear organization and presentation", | |
"uncertainty_calibration": "appropriate expression of confidence", | |
"assumption_transparency": "clarity about underlying assumptions", | |
"perspective_awareness": "recognition of different viewpoints", | |
"insight_synthesis": "deep, integrative understanding", | |
"problem_decomposition": "breaking down complex problems clearly", | |
"conceptual_synthesis": "combining ideas into coherent wholes", | |
} | |
UNIFIED_ONTOLOGY_PROMPT = """You are an expert evaluator assessing AI assistant responses across multiple dimensions of quality. Your evaluation should consider the following comprehensive framework: | |
## FOUNDATIONAL REQUIREMENTS (Non-Negotiable) | |
**Safety & Compliance** | |
- The response must not violate content policies (no hate speech, PII exposure, or illicit content) | |
- The response must not provide advice that could cause real-world harm | |
## CORE QUALITY DIMENSIONS | |
### 1. REALITY FIDELITY - Grounding in Truth | |
Evaluate how well the response reflects accurate knowledge and sound reasoning: | |
**Factual Accuracy**: Are all factual claims correct and free from hallucination? | |
**Evidence Grounding**: Are claims properly supported by the context or cited sources? | |
**Logical Coherence**: Is the reasoning free from fallacies and contradictions? | |
**Causal Soundness**: Are cause-effect relationships correctly represented? | |
**Internal Consistency**: Is the response free from self-contradictions? | |
### 2. TASK INTELLIGENCE - Understanding & Execution | |
Assess how well the response understands and fulfills the user's true intent: | |
**Task Completion**: Does it fully address all explicit and implicit parts of the request? | |
**Constraint Adherence**: Does it strictly follow formatting, length, and style requirements? | |
**Latent Task Identification**: Does it correctly infer the underlying goal beyond literal instructions? | |
**Pattern Generalization**: Can it apply identified patterns to new cases? | |
### 3. INFORMATION QUALITY - Clarity & Efficiency | |
Consider how effectively information is communicated: | |
**Information Density**: Is the response concise while preserving necessary detail? | |
**Relevance Focus**: Does it stay on topic without unnecessary tangents? | |
**Structural Clarity**: Is it well-organized with clear formatting and flow? | |
### 4. EPISTEMIC AWARENESS - Intellectual Honesty | |
Evaluate the response's self-awareness and calibration: | |
**Uncertainty Calibration**: Does expressed confidence match actual accuracy? | |
**Assumption Transparency**: Are underlying assumptions clearly stated? | |
**Perspective Acknowledgment**: Does it recognize when topics have multiple valid viewpoints? | |
### 5. COGNITIVE DEPTH - Sophistication of Thought | |
Assess the depth and quality of reasoning: | |
**Insight Synthesis**: Does it provide non-obvious connections or deeper understanding? | |
**Problem Decomposition**: Are complex problems broken down into clear, logical steps? | |
**Conceptual Synthesis**: Does it skillfully combine disparate ideas into coherent insights? | |
## EVALUATION PRINCIPLES | |
When comparing responses, consider: | |
- A response that is factually accurate but incomplete is generally better than one that is complete but contains errors | |
- Clear, direct communication is preferred over verbose or overly complex explanations | |
- Responses that acknowledge their limitations are preferred over those that confidently assert uncertain information | |
- The ability to identify and address the user's underlying need is more valuable than literal instruction following when they conflict | |
## HOLISTIC ASSESSMENT | |
While each dimension is important, the best response is one that: | |
1. First and foremost, avoids harm and maintains accuracy | |
2. Genuinely understands and addresses the user's needs | |
3. Communicates clearly and efficiently | |
4. Demonstrates appropriate confidence and self-awareness | |
5. Provides insights that elevate the user's understanding | |
## IMPORTANT: You must choose ONE response as better overall | |
Even if the responses seem equal in quality, you must select the one that performs marginally better across all dimensions. If you genuinely cannot distinguish between them, pick the one that is even slightly more concise or clear. You MUST output ONLY "A" or "B" - no other text, explanations, or qualifiers.""" | |
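# This unified rubric is included in every comparison prompt when ICMConfig.unified_mode
# is enabled; otherwise each per-dimension rubric above is judged separately.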
@dataclass | |
class ICMConfig: | |
"""Configuration for ICM preference learning""" | |
# Data source | |
dataset_name: str | |
api_url: str | |
# Dataset sampling | |
num_prompts: int | |
n_shot_examples: int | |
responses_per_prompt: int | |
shuffle_buffer_size: int | |
# Response generation parameters | |
response_temperature: float | |
response_top_p: float | |
response_top_k: int | |
response_frequency_penalty: float | |
response_presence_penalty: float | |
response_repeat_penalty: float # For llama.cpp | |
response_max_tokens: int | |
# ICM core algorithm | |
initial_k: int | |
alpha: float | |
initial_temp: float | |
final_temp: float | |
beta: float | |
max_iterations: int | |
max_labels_per_dimension: int | |
fix_inconsistencies_max_iterations: int | |
inconsistency_weight_multiplier: float | |
# Dimension configuration | |
dimensions: Optional[List[str]] | |
dimension_subset: str # "all", "minimal", "safety", "task", "epistemic" | |
unified_mode: bool | |
unified_dimension_name: str | |
# Model context management | |
max_context_tokens: int | |
response_reserve_tokens: int | |
token_estimation_ratio: float | |
max_context_examples: int | |
# Pair and data management | |
max_pairs_per_dimension: int | |
max_generation_failures: int | |
# Output configuration | |
output_dir: str | |
train_split: float | |
cache_file: str | |
save_interval: int | |
generated_dataset_path: str | |
# Execution configuration | |
max_workers: int | |
random_seed: Optional[int] | |
# Network and retry configuration | |
retry_max_attempts: int | |
retry_base_delay: float | |
timeout_short: float | |
timeout_medium: float | |
timeout_long: float | |
timeout_generation: float | |
# Deduplication parameters | |
duplicate_prefix_length: int | |
duplicate_fuzzy_length: int | |
duplicate_fuzzy_prefix_min_length: int | |
# Generation diversity parameters | |
diversity_attempt_threshold: int | |
diversity_strong_threshold: int | |
max_generation_attempts: int | |
generation_oversampling_factor: float | |
# Resilience parameters | |
max_consecutive_failures: int | |
skip_indecisive_pairs: bool | |
min_labels_per_dimension: int | |
save_partial_on_error: bool | |
# Misc parameters | |
log_prob_fallback: float | |
cache_report_interval: int | |
num_logprobs: int | |
streaming: bool | |
def __post_init__(self): | |
if self.random_seed is not None: | |
random.seed(self.random_seed) | |
np.random.seed(self.random_seed) | |
logger.info(f"Random seed set to: {self.random_seed}") | |
@dataclass | |
class GeneratedResponse: | |
"""A response generated for a prompt""" | |
prompt_id: str | |
response_text: str | |
response_id: str | |
generation_params: Dict[str, Any] | |
timestamp: float | |
def __post_init__(self): | |
# Generate UUID if not provided | |
if not self.response_id: | |
self.response_id = str(uuid.uuid4())[:8] | |
@dataclass | |
class PromptContext: | |
"""Context for generating responses""" | |
prompt_id: str | |
few_shot_examples: List[Dict] | |
target_conversation: List[Dict] | |
original_response: str | |
source_id: int | |
timestamp: float | |
@property | |
def prompt_text(self) -> str: | |
"""Get the formatted prompt text""" | |
return create_n_shot_prompt(self.few_shot_examples, self.target_conversation) | |
@dataclass | |
class LabeledExample: | |
"""Pairwise comparison example""" | |
prompt_id: str | |
response_a_id: str | |
response_b_id: str | |
response_a_text: str | |
response_b_text: str | |
label_dimension: str | |
label: str # "A" or "B" | |
position_swapped: bool | |
timestamp: float | |
def get_pair_key(self) -> Tuple[str, str]: | |
"""Get unique key for this response pair""" | |
return tuple(sorted([self.response_a_id, self.response_b_id])) | |
class ResponseCache: | |
"""Cache for generated responses with deduplication""" | |
def __init__(self, config: ICMConfig): | |
self.cache_file = config.cache_file | |
self.responses_by_prompt: Dict[str, List[GeneratedResponse]] = defaultdict(list) | |
self.responses_by_id: Dict[str, GeneratedResponse] = {} | |
self.prompt_contexts: Dict[str, PromptContext] = {} | |
self.response_hashes: Set[str] = set() # For deduplication | |
self._load_cache() | |
def _hash_response(self, prompt_id: str, response_text: str) -> str: | |
"""Create hash for response deduplication""" | |
normalized_text = response_text.strip().lower() | |
normalized_text = re.sub(r"\s+", " ", normalized_text) | |
content = f"{prompt_id}:{normalized_text}" | |
return hashlib.sha256(content.encode()).hexdigest()[:16] | |
def _load_cache(self): | |
"""Load existing cache from disk""" | |
try: | |
with open(self.cache_file, "r") as f: | |
for line in f: | |
entry = json.loads(line) | |
if entry["type"] == "response": | |
response = GeneratedResponse( | |
prompt_id=entry["prompt_id"], | |
response_text=entry["response_text"], | |
response_id=entry["response_id"], | |
generation_params=entry.get("generation_params", {}), | |
timestamp=entry.get("timestamp", time.time()), | |
) | |
self.responses_by_prompt[entry["prompt_id"]].append(response) | |
self.responses_by_id[response.response_id] = response | |
resp_hash = self._hash_response( | |
response.prompt_id, response.response_text | |
) | |
self.response_hashes.add(resp_hash) | |
elif entry["type"] == "context": | |
context = PromptContext( | |
prompt_id=entry["prompt_id"], | |
few_shot_examples=entry["few_shot_examples"], | |
target_conversation=entry["target_conversation"], | |
original_response=entry["original_response"], | |
source_id=entry["source_id"], | |
timestamp=entry.get("timestamp", time.time()), | |
) | |
self.prompt_contexts[entry["prompt_id"]] = context | |
logger.info( | |
f"Loaded {len(self.responses_by_prompt)} prompts with {len(self.responses_by_id)} responses from cache" | |
) | |
except FileNotFoundError: | |
logger.info("No existing response cache found, starting fresh") | |
def add_response( | |
self, prompt_id: str, response_text: str, generation_params: Dict[str, Any] | |
) -> Optional[GeneratedResponse]: | |
"""Add response with deduplication check""" | |
resp_hash = self._hash_response(prompt_id, response_text) | |
if resp_hash in self.response_hashes: | |
logger.debug(f"Skipping duplicate response for prompt {prompt_id[:8]}...") | |
return None | |
response = GeneratedResponse( | |
prompt_id=prompt_id, | |
response_text=response_text, | |
response_id=str(uuid.uuid4())[:8], | |
generation_params=generation_params, | |
timestamp=time.time(), | |
) | |
self.responses_by_prompt[prompt_id].append(response) | |
self.responses_by_id[response.response_id] = response | |
self.response_hashes.add(resp_hash) | |
entry = { | |
"type": "response", | |
"prompt_id": response.prompt_id, | |
"response_text": response.response_text, | |
"response_id": response.response_id, | |
"generation_params": response.generation_params, | |
"timestamp": response.timestamp, | |
} | |
with open(self.cache_file, "a", buffering=1) as f: | |
f.write(json.dumps(entry) + "\n") | |
return response | |
def add_context(self, context: PromptContext): | |
"""Add prompt context to cache""" | |
self.prompt_contexts[context.prompt_id] = context | |
entry = { | |
"type": "context", | |
"prompt_id": context.prompt_id, | |
"few_shot_examples": context.few_shot_examples, | |
"target_conversation": context.target_conversation, | |
"original_response": context.original_response, | |
"source_id": context.source_id, | |
"timestamp": context.timestamp, | |
} | |
with open(self.cache_file, "a", buffering=1) as f: | |
f.write(json.dumps(entry) + "\n") | |
def get_response_by_id(self, response_id: str) -> Optional[GeneratedResponse]: | |
"""Get response by ID""" | |
return self.responses_by_id.get(response_id) | |
def get_responses_for_prompt(self, prompt_id: str) -> List[GeneratedResponse]: | |
"""Get all responses for a prompt""" | |
return self.responses_by_prompt.get(prompt_id, []) | |
def get_context(self, prompt_id: str) -> Optional[PromptContext]: | |
"""Get context for a prompt""" | |
return self.prompt_contexts.get(prompt_id) | |
def save_final(self): | |
"""Save complete cache (already incrementally saved)""" | |
logger.info(f"Cache already saved incrementally to {self.cache_file}") | |
def retry_on_network_error(config: ICMConfig): | |
"""Retry decorator for network errors only""" | |
def decorator(func): | |
@wraps(func) | |
def wrapper(*args, **kwargs): | |
last_exception = None | |
for attempt in range(config.retry_max_attempts): | |
try: | |
return func(*args, **kwargs) | |
except requests.exceptions.RequestException as e: | |
# Do not retry on ContextLengthExceededError, it must be handled by caller | |
if isinstance(e, LLMInterface.ContextLengthExceededError): | |
raise | |
last_exception = e | |
if attempt < config.retry_max_attempts - 1: | |
delay = config.retry_base_delay * (2**attempt) | |
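                        # Exponential backoff: retry_base_delay, then 2x, 4x, ... per attempt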
logger.warning( | |
f"Network error on attempt {attempt + 1}: {e}. Retrying in {delay}s..." | |
) | |
time.sleep(delay) | |
                except Exception:
                    # Non-network errors (including ContextLengthExceededError) are never retried
                    raise
raise last_exception | |
return wrapper | |
return decorator | |
class LLMInterface: | |
"""Interface for LLM communication""" | |
class ContextLengthExceededError(ValueError): | |
"""Custom exception for when prompt + max_tokens exceeds model's context.""" | |
pass | |
def __init__(self, config: ICMConfig): | |
self.base_url = config.api_url | |
self.config = config | |
self.request_count = 0 | |
self._detect_api_type() | |
self._detect_context_length() | |
def _detect_api_type(self): | |
"""Detect LLM server type""" | |
base_url = self.base_url.rstrip("/") | |
retry_decorator = retry_on_network_error(self.config) | |
@retry_decorator | |
def check_llamacpp(): | |
response = requests.get( | |
f"{base_url}/props", timeout=self.config.timeout_short | |
) | |
if response.status_code == 200: | |
props = response.json() | |
if "model_path" in props: | |
return "llamacpp" | |
return None | |
@retry_decorator | |
def check_vllm(): | |
response = requests.get( | |
f"{base_url}/version", timeout=self.config.timeout_short | |
) | |
if response.status_code == 200: | |
return "vllm" | |
return None | |
@retry_decorator | |
def check_openai(): | |
response = requests.get( | |
f"{base_url}/v1/models", timeout=self.config.timeout_medium | |
) | |
            if response.status_code == 200:
                # OpenAI-compatible servers are driven through the same /v1/completions path as vLLM
                return "vllm"
return None | |
# Try to detect API type | |
try: | |
api_type = check_llamacpp() | |
if api_type: | |
self.api_type = api_type | |
logger.info(f"Detected llama.cpp server") | |
return | |
except requests.RequestException: | |
pass | |
try: | |
api_type = check_vllm() | |
if api_type: | |
self.api_type = api_type | |
logger.info("Detected vLLM server") | |
return | |
except requests.RequestException: | |
pass | |
try: | |
api_type = check_openai() | |
if api_type: | |
self.api_type = api_type | |
logger.info("Detected OpenAI-compatible API") | |
return | |
except requests.RequestException: | |
pass | |
raise ConnectionError(f"Cannot connect to LLM server at {base_url}") | |
def _detect_context_length(self): | |
"""Try to detect model's context length from API""" | |
original_context = self.config.max_context_tokens | |
try: | |
if self.api_type == "vllm": | |
response = requests.get( | |
f"{self.base_url}/v1/models", timeout=self.config.timeout_short | |
) | |
if response.status_code == 200: | |
models = response.json().get("data", []) | |
for model_info in models: | |
for field in ["max_model_len", "context_length", "max_seq_len"]: | |
if field in model_info: | |
detected = model_info[field] | |
if detected != original_context: | |
logger.info( | |
f"Detected model context length: {detected} tokens " | |
f"(was {original_context})" | |
) | |
self.config.max_context_tokens = detected | |
return | |
if "config" in model_info: | |
config = model_info["config"] | |
for field in [ | |
"max_seq_len", | |
"max_position_embeddings", | |
"n_ctx", | |
]: | |
if field in config: | |
detected = config[field] | |
if detected != original_context: | |
logger.info( | |
f"Detected model context length: {detected} tokens " | |
f"(was {original_context})" | |
) | |
self.config.max_context_tokens = detected | |
return | |
elif self.api_type == "llamacpp": | |
response = requests.get( | |
f"{self.base_url}/props", timeout=self.config.timeout_short | |
) | |
if response.status_code == 200: | |
props = response.json() | |
if "n_ctx" in props: | |
detected = props["n_ctx"] | |
if detected != original_context: | |
logger.info( | |
f"Detected model context length: {detected} tokens " | |
f"(was {original_context})" | |
) | |
self.config.max_context_tokens = detected | |
except Exception as e: | |
logger.debug(f"Could not detect context length: {e}") | |
logger.info( | |
f"Using configured context length: {self.config.max_context_tokens} tokens" | |
) | |
def _check_for_context_error(self, response: requests.Response): | |
"""Checks a response for signs of a context length error and raises.""" | |
if response.status_code == 400: | |
try: | |
error_data = response.json() | |
error_msg = str(error_data.get("error", "")).lower() | |
# Common phrases indicating context length issues from vLLM/OpenAI APIs | |
if "context length" in error_msg or "too long" in error_msg or "maximum sequence length" in error_msg: | |
raise self.ContextLengthExceededError(f"API error indicates context length exceeded: {error_msg}") | |
except (json.JSONDecodeError, AttributeError): | |
# If it's a 400 but not a clear context error, we can still infer it | |
pass | |
            # Any other HTTP 400 that was not explicitly identified above is assumed to be a context-length error as well.
            raise self.ContextLengthExceededError(
                f"Inferred context length error from HTTP 400. Response: {response.text[:200]}"
            )
def generate_response( | |
self, prompt: str, max_tokens: int, temperature: float, top_p: float | |
) -> str: | |
"""Generate a response to a prompt""" | |
self.request_count += 1 | |
retry_decorator = retry_on_network_error(self.config) | |
@retry_decorator | |
def _perform_request() -> requests.Response: | |
if self.api_type == "vllm": | |
return requests.post( | |
f"{self.base_url}/v1/completions", | |
json={ | |
"prompt": prompt, | |
"max_tokens": max_tokens, | |
"temperature": temperature, | |
"top_p": top_p, | |
"top_k": self.config.response_top_k, | |
"frequency_penalty": self.config.response_frequency_penalty, | |
"presence_penalty": self.config.response_presence_penalty, | |
"stop": ["</assistant>", "\n<user>", "\n<assistant>"], | |
"seed": random.randint(0, 1000000) if temperature > 0 else 42, | |
}, | |
timeout=self.config.timeout_generation, | |
) | |
else: # llamacpp | |
return requests.post( | |
f"{self.base_url}/completion", | |
json={ | |
"prompt": prompt, | |
"n_predict": max_tokens, | |
"temperature": temperature, | |
"top_p": top_p, | |
"top_k": self.config.response_top_k, | |
"repeat_penalty": self.config.response_repeat_penalty, | |
"frequency_penalty": self.config.response_frequency_penalty, | |
"presence_penalty": self.config.response_presence_penalty, | |
"cache_prompt": True, | |
"stop": ["</assistant>", "\n<user>", "\n<assistant>"], | |
"seed": random.randint(0, 1000000) if temperature > 0 else 42, | |
}, | |
timeout=self.config.timeout_generation, | |
) | |
try: | |
response = _perform_request() | |
# Check for context error first, as it's a special failure case | |
self._check_for_context_error(response) | |
# If no context error, check for other HTTP errors | |
response.raise_for_status() | |
if self.api_type == "vllm": | |
return response.json()["choices"][0]["text"].strip() | |
else: # llamacpp | |
return response.json()["content"].strip() | |
except self.ContextLengthExceededError: | |
# Re-raise our custom exception so the caller can handle it specifically | |
raise | |
except requests.exceptions.RequestException as e: | |
logger.error(f"Request failed after all retries: {e}") | |
raise | |
def generate_label( | |
self, prompt: str, dimension: str, max_tokens: int, temperature: float | |
) -> Optional[str]: | |
"""Generate A or B label, raising ContextLengthExceededError on failure.""" | |
self.request_count += 1 | |
retry_decorator = retry_on_network_error(self.config) | |
@retry_decorator | |
def _generate(): | |
try: | |
if self.api_type == "vllm": | |
response = requests.post( | |
f"{self.base_url}/v1/completions", | |
json={ | |
"prompt": prompt, | |
"max_tokens": max_tokens, | |
"temperature": temperature, | |
"logprobs": self.config.num_logprobs, | |
}, | |
timeout=self.config.timeout_long, | |
) | |
else: # llamacpp | |
response = requests.post( | |
f"{self.base_url}/completion", | |
json={ | |
"prompt": prompt, | |
"n_predict": max_tokens, | |
"temperature": temperature, | |
"cache_prompt": True, | |
"n_probs": self.config.num_logprobs, | |
}, | |
timeout=self.config.timeout_long, | |
) | |
# IMPORTANT: Check for context error first and raise if found | |
self._check_for_context_error(response) | |
# If no context error, check for other HTTP errors | |
response.raise_for_status() | |
response_json = response.json() | |
if self.api_type == "vllm": | |
result = response_json["choices"][0]["text"].strip() | |
logprobs_data = response_json["choices"][0].get("logprobs", {}) | |
if logprobs_data and "top_logprobs" in logprobs_data and logprobs_data["top_logprobs"]: | |
first_token_probs = logprobs_data["top_logprobs"][0] | |
else: | |
first_token_probs = {} | |
else: # llamacpp | |
result = response_json["content"].strip() | |
completion_probs = response_json.get("completion_probabilities", []) | |
if completion_probs: | |
top_logprobs_list = completion_probs[0].get("top_logprobs", []) | |
first_token_probs = {item['token']: item['logprob'] for item in top_logprobs_list} | |
else: | |
first_token_probs = {} | |
# --- Unified Logprob and Text Parsing Logic --- | |
result_upper = result.upper() | |
# Check logprobs first | |
if first_token_probs: | |
best_token, best_logprob = None, float("-inf") | |
a_total_prob, b_total_prob = 0.0, 0.0 | |
for token, logprob in first_token_probs.items(): | |
cleaned = re.sub(r"[^a-zA-Z0-9]", "", str(token)).upper() | |
if cleaned == "A": | |
a_total_prob += math.exp(logprob) | |
if logprob > best_logprob: best_logprob, best_token = logprob, "A" | |
elif cleaned == "B": | |
b_total_prob += math.exp(logprob) | |
if logprob > best_logprob: best_logprob, best_token = logprob, "B" | |
elif "A" in cleaned and "B" not in cleaned and len(cleaned) < 10: a_total_prob += math.exp(logprob) * 0.5 | |
elif "B" in cleaned and "A" not in cleaned and len(cleaned) < 10: b_total_prob += math.exp(logprob) * 0.5 | |
if best_token: return best_token | |
if a_total_prob > b_total_prob * 1.2: return "A" | |
if b_total_prob > a_total_prob * 1.2: return "B" | |
# Fallback to generated text | |
if "A" in result_upper and "B" not in result_upper: return "A" | |
if "B" in result_upper and "A" not in result_upper: return "B" | |
# Check for indecisiveness in logprobs or text | |
indecisive_terms = ["BOTH", "TIE", "EQUAL", "NEITHER", "SAME", "DRAW", "NONE", "EQUIVALENT"] | |
top_tokens_str = [str(t).upper() for t in list(first_token_probs.keys())[:5]] | |
if any(any(term in token for term in indecisive_terms) for token in top_tokens_str): | |
logger.debug(f"Model indecisive via logprobs, skipping pair.") | |
return None # Signal to skip | |
# If truly ambiguous, return None to signal a skip | |
logger.debug(f"Could not determine A/B preference from response '{result[:20]}...' or logprobs. Skipping.") | |
return None | |
except self.ContextLengthExceededError: | |
# Propagate the specific error for the retry loop | |
raise | |
except requests.exceptions.RequestException as e: | |
logger.error(f"Network error during label generation: {e}") | |
raise # Let the retry handler deal with it | |
except Exception as e: | |
logger.error(f"Unexpected error in generate_label, cannot determine label: {e}") | |
return None # Return None for other unexpected errors | |
return _generate() | |
def get_log_prob(self, prompt: str, completion: str) -> float: | |
"""Get log probability of completion given prompt""" | |
self.request_count += 1 | |
if self.api_type == "vllm": | |
return self._get_log_prob_vllm(prompt, completion) | |
else: | |
return self._get_log_prob_llamacpp(prompt, completion) | |
def _get_log_prob_vllm(self, prompt: str, completion: str) -> float: | |
"""Get log probability for vLLM""" | |
if len(completion.strip()) > 1: | |
raise NotImplementedError( | |
"This optimized vLLM logprob function only supports single-token completions." | |
) | |
retry_decorator = retry_on_network_error(self.config) | |
@retry_decorator | |
def _get(): | |
try: | |
response = requests.post( | |
f"{self.base_url}/v1/completions", | |
json={ | |
"prompt": prompt, | |
"max_tokens": 1, | |
"temperature": 0.0, | |
"logprobs": self.config.num_logprobs, | |
}, | |
timeout=self.config.timeout_long, | |
) | |
self._check_for_context_error(response) | |
response.raise_for_status() | |
result = response.json() | |
if "choices" not in result or not result["choices"]: raise ValueError("No choices in vLLM response") | |
choice = result["choices"][0] | |
logprobs_data = choice.get("logprobs") | |
if not logprobs_data or not logprobs_data.get("top_logprobs"): raise ValueError("No logprobs in vLLM response") | |
first_token_probs = logprobs_data["top_logprobs"][0] | |
cleaned_tokens = [] | |
for token, logprob in first_token_probs.items(): | |
cleaned_token = re.sub(r"[^a-zA-Z0-9]", "", token) | |
cleaned_tokens.append(cleaned_token) | |
if cleaned_token == completion: | |
return logprob | |
logger.warning( f"Completion '{completion}' not in top {len(first_token_probs)} logprobs. " f"Returning low probability. (logprobs: {cleaned_tokens})") | |
return self.config.log_prob_fallback | |
except self.ContextLengthExceededError: | |
raise # Let caller handle this | |
except Exception as e: | |
logger.error(f"Error in get_log_prob: {e}") | |
return self.config.log_prob_fallback | |
return _get() | |
def _get_log_prob_llamacpp(self, prompt: str, completion: str) -> float: | |
"""Get log probability for llama.cpp""" | |
if len(completion) == 1: | |
retry_decorator = retry_on_network_error(self.config) | |
@retry_decorator | |
def _get(): | |
try: | |
response = requests.post( | |
f"{self.base_url}/completion", | |
json={ | |
"prompt": prompt, | |
"n_predict": 2, | |
"n_probs": self.config.num_logprobs, | |
"temperature": 0.0, | |
"cache_prompt": True, | |
}, | |
timeout=self.config.timeout_long, | |
) | |
self._check_for_context_error(response) | |
response.raise_for_status() | |
result = response.json() | |
completion_probs = result.get("completion_probabilities", []) | |
if not completion_probs: raise ValueError("No completion_probabilities in response") | |
first_pos = completion_probs[0] | |
top_candidates = first_pos.get("top_logprobs", []) | |
for candidate in top_candidates: | |
token = candidate.get("token", "") | |
cleaned_token = re.sub(r"[^a-zA-Z0-9]", "", token) | |
if cleaned_token == completion: | |
return candidate["logprob"] | |
raise ValueError(f"Completion '{completion}' not found in top candidates") | |
except self.ContextLengthExceededError: | |
raise # Let caller handle this | |
except Exception as e: | |
logger.error(f"Error in get_log_prob: {e}") | |
return self.config.log_prob_fallback | |
return _get() | |
raise NotImplementedError("Multi-token completion logprobs not implemented") | |
def normalize_conversation_format(conversation: List[Dict]) -> List[Dict]: | |
"""Normalize conversation to role/content format""" | |
normalized = [] | |
for turn in conversation: | |
if "from" in turn and "value" in turn: | |
role_map = {"human": "user", "gpt": "assistant", "system": "system"} | |
role = role_map.get(turn["from"], turn["from"]) | |
content = turn["value"] | |
elif "role" in turn and "content" in turn: | |
role = turn["role"] | |
content = turn["content"] | |
else: | |
logger.warning(f"Unknown conversation format: {turn}") | |
continue | |
normalized.append({"role": role, "content": content}) | |
return normalized | |
def format_conversation(conversation: List[Dict]) -> str: | |
"""Format a conversation into a readable string""" | |
formatted = [] | |
normalized = normalize_conversation_format(conversation) | |
for turn in normalized: | |
role = turn["role"] | |
content = turn["content"] | |
formatted.append(f"<{role}>\n{content}\n</{role}>") | |
return "\n".join(formatted) | |
def create_n_shot_prompt( | |
few_shot_examples: List[Dict], target_conversation: List[Dict] | |
) -> str: | |
"""Create an N-shot prompt from examples and target conversation""" | |
prompt_parts = [] | |
for i, example in enumerate(few_shot_examples): | |
prompt_parts.append(f"Example {i + 1}:") | |
prompt_parts.append(format_conversation(example)) | |
prompt_parts.append("") | |
prompt_parts.append(f"Example {len(few_shot_examples) + 1}:") | |
prompt_parts.append(format_conversation(target_conversation)) | |
prompt_parts.append("<assistant>") | |
return "\n".join(prompt_parts) | |
class ResponseGenerator: | |
"""Generate multiple responses for prompts""" | |
def __init__( | |
self, | |
llm: LLMInterface, | |
config: ICMConfig, | |
cache: ResponseCache, | |
executor: ThreadPoolExecutor, | |
): | |
self.llm = llm | |
self.config = config | |
self.cache = cache | |
self.executor = executor | |
def shutdown(self): | |
"""Shutdown is handled by main executor""" | |
pass | |
def prepare_prompt_contexts( | |
self, dataset_iter: Iterator, num_prompts: int | |
) -> List[PromptContext]: | |
"""Prepare prompt contexts from dataset""" | |
contexts = [] | |
examples_buffer = [] | |
# Use a larger buffer to ensure we can find n-shot examples for each target | |
pbar = tqdm(total=num_prompts, desc="Preparing prompt contexts") | |
for idx, item in enumerate(dataset_iter): | |
if len(contexts) >= num_prompts: | |
break | |
conversation = self._get_conversation(item) | |
if not conversation: continue | |
normalized = normalize_conversation_format(conversation) | |
if not any(t["role"] == "user" for t in normalized) or not any(t["role"] == "assistant" for t in normalized): continue | |
# Keep a sliding window of examples | |
examples_buffer.append((normalized, idx)) | |
if len(examples_buffer) > self.config.n_shot_examples: | |
target_conv_with_response, source_id = examples_buffer.pop(0) | |
few_shot_examples = [ex[0] for ex in examples_buffer] | |
last_assistant_idx = -1 | |
for i, turn in enumerate(target_conv_with_response): | |
if turn["role"] == "assistant": | |
last_assistant_idx = i | |
if last_assistant_idx != -1: | |
target_conversation = target_conv_with_response[:last_assistant_idx] | |
original_response = target_conv_with_response[last_assistant_idx]["content"] | |
prompt_text = format_conversation(target_conversation) | |
prompt_id = hashlib.sha256(prompt_text.encode()).hexdigest()[:16] | |
context = PromptContext( | |
prompt_id=prompt_id, | |
few_shot_examples=few_shot_examples, | |
target_conversation=target_conversation, | |
original_response=original_response, | |
source_id=source_id, | |
timestamp=time.time(), | |
) | |
contexts.append(context) | |
pbar.update(1) | |
pbar.close() | |
logger.info(f"Prepared {len(contexts)} prompt contexts") | |
return contexts | |
def _get_conversation(self, item: Dict) -> Optional[List[Dict]]: | |
"""Get conversation from dataset item""" | |
if "messages" in item: | |
return item["messages"] | |
elif "conversations" in item: | |
return item["conversations"] | |
else: | |
for key in ["conversation", "dialog", "dialogue", "chat"]: | |
if key in item: | |
return item[key] | |
return None | |
    def estimate_prompt_tokens(self, prompt: str) -> int:
        """Rough token-count estimate at ~3.5 characters per token (deliberately conservative)"""
        return int(len(prompt) / 3.5)
def generate_responses_for_prompt(self, context: PromptContext) -> List[GeneratedResponse]: | |
"""Generate multiple responses for a single prompt with robust context handling.""" | |
cached_responses = self.cache.get_responses_for_prompt(context.prompt_id) | |
if len(cached_responses) >= self.config.responses_per_prompt: | |
return cached_responses[:self.config.responses_per_prompt] | |
generated_texts_exact = {r.response_text.strip().lower() for r in cached_responses} | |
generated_texts_prefix = {r.response_text.strip().lower()[:self.config.duplicate_prefix_length] for r in cached_responses} | |
generated_texts_fuzzy = {re.sub(r"[^a-z0-9]", "", r.response_text.strip().lower())[:self.config.duplicate_fuzzy_length] for r in cached_responses} | |
successful_responses = [] | |
generation_lock = threading.Lock() | |
def generate_single(): | |
# Outermost loop: Reduce n-shot examples if context is too long | |
num_examples_to_try = self.config.n_shot_examples | |
while num_examples_to_try >= 0: | |
                # Construct the prompt with the current number of few-shot examples.
                # Note: a plain [-0:] slice would return ALL examples, so handle 0 explicitly.
                current_few_shot = (
                    context.few_shot_examples[-num_examples_to_try:]
                    if num_examples_to_try > 0
                    else []
                )
prompt_text = create_n_shot_prompt(current_few_shot, context.target_conversation) | |
# Proactive check: If prompt itself is too big, don't even try. | |
estimated_prompt_tokens = self.estimate_prompt_tokens(prompt_text) | |
if estimated_prompt_tokens + self.config.response_max_tokens > self.config.max_context_tokens: | |
logger.debug(f"Proactively skipping {num_examples_to_try} examples, prompt too long for {context.prompt_id[:8]}.") | |
num_examples_to_try -= 1 | |
continue | |
# Middle loop: Try different generation parameters for diversity | |
for attempt in range(self.config.max_generation_attempts): | |
try: | |
temperature = self.config.response_temperature * (1.1**attempt) | |
response_text = self.llm.generate_response( | |
prompt_text, | |
max_tokens=self.config.response_max_tokens, | |
temperature=temperature, | |
top_p=min(self.config.response_top_p + (attempt * 0.01), 1.0), | |
) | |
if not response_text or not response_text.strip(): continue | |
# Deduplication logic | |
normalized_exact = response_text.strip().lower() | |
normalized_prefix = normalized_exact[:self.config.duplicate_prefix_length] | |
normalized_fuzzy = re.sub(r"[^a-z0-9]", "", normalized_exact)[:self.config.duplicate_fuzzy_length] | |
with generation_lock: | |
is_duplicate = False | |
if normalized_exact in generated_texts_exact: is_duplicate = True | |
elif len(normalized_prefix) > self.config.duplicate_fuzzy_prefix_min_length and normalized_prefix in generated_texts_prefix: is_duplicate = True | |
elif len(normalized_fuzzy) > 100 and normalized_fuzzy in generated_texts_fuzzy: is_duplicate = True | |
if is_duplicate: continue | |
response_obj = self.cache.add_response( | |
context.prompt_id, response_text, | |
{"temperature": temperature, "top_p": min(self.config.response_top_p + (attempt * 0.01), 1.0), "attempts": attempt + 1, "n_shot": num_examples_to_try} | |
) | |
if not response_obj: continue | |
# Add to sets if successfully cached | |
generated_texts_exact.add(normalized_exact) | |
generated_texts_prefix.add(normalized_prefix) | |
generated_texts_fuzzy.add(normalized_fuzzy) | |
return response_obj # Success! | |
except self.llm.ContextLengthExceededError: | |
logger.warning(f"Response generation failed for {context.prompt_id[:8]} with {num_examples_to_try} examples due to context length. Reducing.") | |
# Break the diversification loop to trigger reduction of n-shot examples | |
break | |
except Exception as e: | |
logger.error(f"Failed to generate response on attempt {attempt + 1}: {e}") | |
if attempt == self.config.max_generation_attempts - 1: | |
break # Break inner loop, but outer loop will continue | |
# This part is reached if the diversification loop breaks (e.g., from context error) | |
num_examples_to_try -= 1 | |
logger.error(f"Failed to generate a unique response for {context.prompt_id[:8]} even with 0 examples.") | |
return None | |
# Threaded execution to generate responses | |
with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor: | |
needed = self.config.responses_per_prompt - len(cached_responses) | |
if needed <= 0: return cached_responses | |
# Oversample to account for failures and duplicates | |
num_to_generate = math.ceil(needed * self.config.generation_oversampling_factor) | |
futures = [executor.submit(generate_single) for _ in range(num_to_generate)] | |
pbar = tqdm(as_completed(futures), total=len(futures), desc=f"Gen for {context.prompt_id[:8]}", leave=False) | |
for future in pbar: | |
response = future.result() | |
if response: | |
successful_responses.append(response) | |
if len(successful_responses) >= needed: | |
# Cancel remaining futures once we have enough | |
for f in futures: | |
if not f.done(): | |
f.cancel() | |
break | |
all_responses = cached_responses + successful_responses | |
if not self.cache.get_context(context.prompt_id): | |
self.cache.add_context(context) | |
if len(all_responses) < self.config.responses_per_prompt: | |
logger.warning(f"Only generated {len(all_responses)}/{self.config.responses_per_prompt} for prompt {context.prompt_id[:8]}") | |
return all_responses[:self.config.responses_per_prompt] | |
def generate_all_responses( | |
self, contexts: List[PromptContext] | |
) -> Dict[str, List[GeneratedResponse]]: | |
"""Generate responses for all prompt contexts""" | |
all_responses = {} | |
pbar = tqdm(contexts, desc="Generating responses", position=0) | |
for context in pbar: | |
# Use cached responses if they exist, otherwise generate new ones | |
cached_responses = self.cache.get_responses_for_prompt(context.prompt_id) | |
if len(cached_responses) >= self.config.responses_per_prompt: | |
pbar.set_postfix_str(f"Prompt {context.prompt_id[:8]}: {len(cached_responses)} cached") | |
all_responses[context.prompt_id] = cached_responses[:self.config.responses_per_prompt] | |
else: | |
responses = self.generate_responses_for_prompt(context) | |
all_responses[context.prompt_id] = responses | |
pbar.set_postfix_str(f"Prompt {context.prompt_id[:8]}: {len(responses)} generated") | |
pbar.close() | |
total_responses = sum(len(r) for r in all_responses.values()) | |
logger.info(f"Finished response generation. Total responses: {total_responses} for {len(contexts)} prompts") | |
return all_responses | |
def save_responses_to_hf_dataset( | |
all_responses: Dict[str, List[GeneratedResponse]], | |
contexts: List[PromptContext], | |
output_dir: str | |
): | |
"""Saves the generated responses to a Hugging Face Dataset on disk.""" | |
total_responses = sum(len(r) for r in all_responses.values()) | |
if total_responses == 0: | |
logger.warning("No responses were generated, skipping dataset creation.") | |
return | |
logger.info(f"Preparing to save {total_responses} responses to Hugging Face dataset at '{output_dir}'") | |
# Create a lookup for contexts by prompt_id for efficiency | |
contexts_by_id = {c.prompt_id: c for c in contexts} | |
data_to_save = defaultdict(list) | |
for prompt_id, responses in tqdm(all_responses.items(), desc="Formatting for dataset"): | |
prompt_context = contexts_by_id.get(prompt_id) | |
if not prompt_context: | |
logger.warning(f"Could not find context for prompt_id {prompt_id}. Skipping its responses.") | |
continue | |
for response in responses: | |
data_to_save["prompt"].append(prompt_context.target_conversation) | |
data_to_save["completion"].append(response.response_text) | |
data_to_save["prompt_id"].append(response.prompt_id) | |
data_to_save["response_id"].append(response.response_id) | |
data_to_save["source_id"].append(prompt_context.source_id) | |
data_to_save["generation_params"].append(json.dumps(response.generation_params)) | |
if not data_to_save["prompt"]: | |
logger.error("No data to save after formatting. Aborting dataset creation.") | |
return | |
# Define the features for type safety and clarity | |
features = Features({ | |
'prompt': Sequence({'role': Value('string'), 'content': Value('string')}), | |
'completion': Value('string'), | |
'prompt_id': Value('string'), | |
'response_id': Value('string'), | |
'source_id': Value('int64'), | |
'generation_params': Value('string'), # Storing as a JSON string | |
}) | |
try: | |
hf_dataset = Dataset.from_dict(dict(data_to_save), features=features) | |
logger.info(f"Created dataset with {len(hf_dataset)} rows. Saving to disk...") | |
hf_dataset.save_to_disk(output_dir) | |
logger.info(f"Successfully saved generated responses dataset to '{output_dir}'") | |
except Exception as e: | |
logger.exception(f"Failed to create or save Hugging Face dataset: {e}") | |
def analyze_response_diversity( | |
all_responses: Dict[str, List[GeneratedResponse]], config: ICMConfig | |
): | |
"""Analyze diversity of generated responses""" | |
print("\n📊 Response Diversity Analysis:") | |
print("=" * 60) | |
total_duplicates = 0 | |
total_near_duplicates = 0 | |
total_fuzzy_duplicates = 0 | |
for prompt_id, responses in all_responses.items(): | |
if len(responses) < 2: | |
continue | |
texts = [r.response_text.strip() for r in responses] | |
unique_texts = set(texts) | |
duplicates = len(texts) - len(unique_texts) | |
normalized_texts = [t.lower() for t in texts] | |
unique_normalized = set(normalized_texts) | |
near_duplicates = len(normalized_texts) - len(unique_normalized) | |
fuzzy_texts = [ | |
re.sub(r"[^a-z0-9]", "", t.lower())[: config.duplicate_fuzzy_length] | |
for t in texts | |
] | |
unique_fuzzy = set(fuzzy_texts) | |
fuzzy_duplicates = len(fuzzy_texts) - len(unique_fuzzy) | |
if duplicates > 0 or near_duplicates > 0 or fuzzy_duplicates > 0: | |
print(f"\nPrompt {prompt_id[:8]}...") | |
print(f" Total responses: {len(responses)}") | |
print(f" Exact duplicates: {duplicates}") | |
print(f" Near duplicates (case): {near_duplicates}") | |
print(f" Fuzzy duplicates: {fuzzy_duplicates}") | |
if duplicates > 0: | |
print(" Example duplicate:") | |
for i, text in enumerate(texts): | |
if texts.count(text) > 1: | |
print(f" Response {i+1}: {text[:50]}...") | |
break | |
total_duplicates += duplicates | |
total_near_duplicates += near_duplicates | |
total_fuzzy_duplicates += fuzzy_duplicates | |
print(f"\n📈 Summary:") | |
print(f" Total prompts: {len(all_responses)}") | |
print(f" Total responses: {sum(len(r) for r in all_responses.values())}") | |
print(f" Total exact duplicates: {total_duplicates}") | |
print(f" Total near duplicates (case): {total_near_duplicates}") | |
print(f" Total fuzzy duplicates: {total_fuzzy_duplicates}") | |
if ( | |
total_duplicates == 0 | |
and total_near_duplicates == 0 | |
and total_fuzzy_duplicates == 0 | |
): | |
print("\n✅ Excellent! No duplicates detected at any level.") | |
else: | |
print("\n⚠️ Some duplicates were found despite prevention measures.") | |
print("=" * 60) | |
class ICMLabeler: | |
"""Internal Coherence Maximization for labeling response pairs""" | |
def __init__( | |
self, | |
llm: LLMInterface, | |
config: ICMConfig, | |
cache: ResponseCache, | |
executor: ThreadPoolExecutor, | |
): | |
self.llm = llm | |
self.config = config | |
self.cache = cache | |
self.executor = executor | |
self.labeled_data: Dict[str, List[LabeledExample]] = defaultdict(list) | |
self.generation_failures: Dict[str, int] = defaultdict(int) | |
self.pair_tracker: Dict[str, Set[Tuple[str, str]]] = defaultdict(set) | |
self.inconsistency_degree: Dict[str, Dict[Tuple[str, str], int]] = defaultdict(dict) | |
self.log_prob_cache: Dict[str, Dict[int, float]] = defaultdict(dict) | |
self.dirty_indices: Dict[str, Set[int]] = defaultdict(set) | |
self.total_log_probs: Dict[str, float] = defaultdict(float) | |
self.skipped_pairs: Dict[str, int] = defaultdict(int) | |
self.consecutive_failures: Dict[str, int] = defaultdict(int) | |
        # Adaptive state for the "Fail, Reduce, Retry" strategy: number of in-context
        # examples to use per dimension, starting from max_context_examples.
        self.adaptive_example_count: Dict[str, int] = defaultdict(lambda: self.config.max_context_examples)
def shutdown(self): | |
"""Shutdown is handled by main executor""" | |
pass | |
def _mark_all_dirty(self, dimension: str): | |
"""Mark all examples as needing recomputation""" | |
n_examples = len(self.labeled_data[dimension]) | |
self.dirty_indices[dimension] = set(range(n_examples)) | |
def _mark_dirty_except(self, dimension: str, skip_idx: int): | |
"""Mark all examples except one as needing recomputation""" | |
n_examples = len(self.labeled_data[dimension]) | |
self.dirty_indices[dimension] = set(range(n_examples)) - {skip_idx} | |
def create_comparison_prompt( | |
self, | |
prompt_context: PromptContext, | |
response_a: str, | |
response_b: str, | |
dimension: str, | |
position_swapped: bool, | |
) -> str: | |
"""Create prompt for comparing two responses""" | |
dimension_text = dimension.replace("_", " ") | |
rubric = DIMENSION_RUBRICS.get(dimension, "") | |
if position_swapped: | |
first_response, second_response = response_b, response_a | |
first_label, second_label = "B", "A" | |
else: | |
first_response, second_response = response_a, response_b | |
first_label, second_label = "A", "B" | |
conversation_text = format_conversation(prompt_context.target_conversation) | |
prompt = f"""You are an impartial judge evaluating two AI assistant responses. | |
Here is the conversation that led to these responses: | |
{conversation_text} | |
Response {first_label}: | |
{first_response} | |
Response {second_label}: | |
{second_response} | |
Which response demonstrates better {dimension_text} ({rubric})? | |
IMPORTANT: You MUST choose either "{first_label}" or "{second_label}". If they seem very similar, pick the one that is even slightly better. Do NOT say "tie", "both", "equal", or anything else. | |
Reply with ONLY "{first_label}" or "{second_label}" - nothing else. | |
Answer: """ | |
return prompt | |
def create_context_prompt( | |
self, | |
examples: List[LabeledExample], | |
new_example: LabeledExample, | |
prompt_context: PromptContext, | |
position_swapped: bool, | |
) -> str: | |
"""Create prompt with examples in context (no token calculation).""" | |
dimension_text = new_example.label_dimension.replace("_", " ") | |
rubric = DIMENSION_RUBRICS.get(new_example.label_dimension, "") | |
prompt = "<golden_dataset>\n" | |
prompt += f"<task_description>\nYou are an impartial judge.\n" | |
prompt += f"Compare AI assistant responses on {dimension_text} ({rubric}).\n" | |
prompt += f"IMPORTANT: Always choose A or B. Never say tie/both/equal.\n" | |
prompt += f"</task_description>\n\n" | |
prompt += f"<best_comparisons>\n" | |
prompt_parts = [] | |
for i, ex in enumerate(examples): | |
ex_context = self.cache.get_context(ex.prompt_id) | |
if not ex_context: continue | |
if ex.position_swapped: | |
first_resp, second_resp = ex.response_b_text, ex.response_a_text | |
first_label, second_label = "B", "A" | |
display_label = "B" if ex.label == "A" else "A" | |
else: | |
first_resp, second_resp = ex.response_a_text, ex.response_b_text | |
first_label, second_label = "A", "B" | |
display_label = ex.label | |
conversation_text = format_conversation(ex_context.target_conversation) | |
example_text = f"<comparison_{i + 1}>\n" | |
example_text += f"<conversation>\n{conversation_text}\n</conversation>\n" | |
example_text += f"<response_{first_label.lower()}>\n{first_resp}\n</response_{first_label.lower()}>\n" | |
example_text += f"<response_{second_label.lower()}>\n{second_resp}\n</response_{second_label.lower()}>\n" | |
example_text += f"<dimension>\n{dimension_text}\n</dimension>\n" | |
example_text += f"<rubric>\n{rubric}\n</rubric>\n" | |
example_text += f"<verdict>\n <better>{display_label}</better>\n</verdict>\n" | |
example_text += f"</comparison_{i + 1}>\n" | |
prompt_parts.append(example_text) | |
prompt += "".join(prompt_parts) | |
if position_swapped: | |
first_resp, second_resp = new_example.response_b_text, new_example.response_a_text | |
first_label, second_label = "B", "A" | |
else: | |
first_resp, second_resp = new_example.response_a_text, new_example.response_b_text | |
first_label, second_label = "A", "B" | |
conversation_text = format_conversation(prompt_context.target_conversation) | |
prompt += f"<comparison_{len(examples) + 1}>\n" | |
prompt += f"<conversation>\n{conversation_text}\n</conversation>\n" | |
prompt += f"<response_{first_label.lower()}>\n{first_resp}\n</response_{first_label.lower()}>\n" | |
prompt += f"<response_{second_label.lower()}>\n{second_resp}\n</response_{second_label.lower()}>\n" | |
prompt += f"<dimension>\n{dimension_text}\n</dimension>\n" | |
prompt += f"<rubric>\n{rubric}\n</rubric>\n" | |
prompt += f"<verdict>\n <better>" | |
return prompt | |
def create_unified_comparison_prompt( | |
self, | |
prompt_context: PromptContext, | |
response_a: str, | |
response_b: str, | |
position_swapped: bool, | |
) -> str: | |
"""Create a unified prompt that evaluates all dimensions holistically""" | |
if position_swapped: | |
first_response, second_response = response_b, response_a | |
first_label, second_label = "B", "A" | |
else: | |
first_response, second_response = response_a, response_b | |
first_label, second_label = "A", "B" | |
conversation_text = format_conversation(prompt_context.target_conversation) | |
prompt = f"""{UNIFIED_ONTOLOGY_PROMPT} | |
Now, evaluate these two responses to the following conversation: | |
Conversation: | |
{conversation_text} | |
Response {first_label}: | |
{first_response} | |
Response {second_label}: | |
{second_response} | |
Considering all dimensions of the evaluation framework above, which response is better overall for the given conversation? | |
CRITICAL: You MUST output ONLY "{first_label}" or "{second_label}" - no other text, explanations, or qualifiers. Even if they seem similar, pick the one that's marginally better. | |
Answer: """ | |
return prompt | |
def create_unified_context_prompt( | |
self, | |
examples: List[LabeledExample], | |
new_example: LabeledExample, | |
prompt_context: PromptContext, | |
position_swapped: bool, | |
) -> str: | |
"""Create unified prompt with examples in context (no token calculation).""" | |
prompt = "<golden_dataset>\n" | |
prompt += "<task_description>\n" | |
prompt += "You are an expert evaluator using a comprehensive quality framework.\n" | |
prompt += "Evaluate responses holistically across all dimensions: safety, accuracy, task fulfillment, clarity, self-awareness, and depth.\n" | |
prompt += "ALWAYS choose A or B - never say tie/both/equal.\n" | |
prompt += "</task_description>\n\n" | |
prompt += "<evaluation_framework>\n" | |
prompt += UNIFIED_ONTOLOGY_PROMPT | |
prompt += "\n</evaluation_framework>\n\n" | |
prompt += "<best_comparisons>\n" | |
prompt_parts = [] | |
for i, ex in enumerate(examples): | |
ex_context = self.cache.get_context(ex.prompt_id) | |
if not ex_context: continue | |
if ex.position_swapped: | |
ex_first_resp, ex_second_resp = ex.response_b_text, ex.response_a_text | |
ex_first_label, ex_second_label = "B", "A" | |
display_label = "B" if ex.label == "A" else "A" | |
else: | |
ex_first_resp, ex_second_resp = ex.response_a_text, ex.response_b_text | |
ex_first_label, ex_second_label = "A", "B" | |
display_label = ex.label | |
conversation_text = format_conversation(ex_context.target_conversation) | |
example_text = f"<comparison_{i + 1}>\n" | |
example_text += f"<conversation>\n{conversation_text}\n</conversation>\n" | |
example_text += f"<response_{ex_first_label.lower()}>\n{ex_first_resp}\n</response_{ex_first_label.lower()}>\n" | |
example_text += f"<response_{ex_second_label.lower()}>\n{ex_second_resp}\n</response_{ex_second_label.lower()}>\n" | |
example_text += f"<verdict>\n <better>{display_label}</better>\n</verdict>\n" | |
example_text += f"</comparison_{i + 1}>\n" | |
prompt_parts.append(example_text) | |
prompt += "".join(prompt_parts) | |
if position_swapped: | |
first_resp, second_resp = new_example.response_b_text, new_example.response_a_text | |
first_label, second_label = "B", "A" | |
else: | |
first_resp, second_resp = new_example.response_a_text, new_example.response_b_text | |
first_label, second_label = "A", "B" | |
conversation_text = format_conversation(prompt_context.target_conversation) | |
prompt += f"<comparison_{len(examples) + 1}>\n" | |
prompt += f"<conversation>\n{conversation_text}\n</conversation>\n" | |
prompt += f"<response_{first_label.lower()}>\n{first_resp}\n</response_{first_label.lower()}>\n" | |
prompt += f"<response_{second_label.lower()}>\n{second_resp}\n</response_{second_label.lower()}>\n" | |
prompt += f"<verdict>\n <better>" | |
return prompt | |
def compute_mutual_predictability(self, dimension: str, pbar) -> float: | |
"""Compute mutual predictability score P_θ(D) on full dataset""" | |
examples = self.labeled_data[dimension] | |
if len(examples) < 2: return 0.0 | |
if dimension not in self.total_log_probs or len(self.log_prob_cache[dimension]) != len(examples): | |
self.dirty_indices[dimension] = set(range(len(examples))) | |
self.log_prob_cache[dimension].clear() | |
self.total_log_probs[dimension] = 0.0 | |
indices_to_compute = list(self.dirty_indices[dimension]) | |
if not indices_to_compute: return self.total_log_probs[dimension] | |
logger.debug(f"Recomputing {len(indices_to_compute)}/{len(examples)} examples for {dimension}") | |
if pbar: pbar.set_postfix_str(f"{len(self.labeled_data[dimension]):3d} | Skip: {self.skipped_pairs[dimension]:3d} | Computing P_θ(D)... (0/{len(indices_to_compute)})") | |
prompts_to_process = [] | |
for idx in indices_to_compute: | |
if idx >= len(examples): continue | |
num_examples_to_try = self.adaptive_example_count[dimension] | |
log_prob = self.config.log_prob_fallback | |
while num_examples_to_try >= 0: | |
try: | |
context_examples = [ex for i, ex in enumerate(examples) if i != idx] | |
context_examples = context_examples[-num_examples_to_try:] | |
example = examples[idx] | |
prompt_context = self.cache.get_context(example.prompt_id) | |
if not prompt_context: break | |
if self.config.unified_mode: | |
prompt = self.create_unified_context_prompt(context_examples, example, prompt_context, position_swapped=example.position_swapped) | |
else: | |
prompt = self.create_context_prompt(context_examples, example, prompt_context, position_swapped=example.position_swapped) | |
if example.position_swapped: | |
expected_label = "B" if example.label == "A" else "A" | |
else: | |
expected_label = example.label | |
log_prob = self.llm.get_log_prob(prompt, expected_label) | |
# Do not adapt example count during log_prob computation, only during labeling | |
break # Success | |
except self.llm.ContextLengthExceededError: | |
num_examples_to_try -= 1 | |
except Exception as e: | |
logger.error(f"Error during P_theta computation: {e}") | |
break # Don't retry on other errors | |
prompts_to_process.append({"idx": idx, "log_prob": log_prob}) | |
for item in prompts_to_process: | |
idx, new_log_prob = item['idx'], item['log_prob'] | |
old_log_prob = self.log_prob_cache[dimension].get(idx, 0.0) | |
self.log_prob_cache[dimension][idx] = new_log_prob | |
self.total_log_probs[dimension] += new_log_prob - old_log_prob | |
self.dirty_indices[dimension].clear() | |
return self.total_log_probs[dimension] | |
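# Sketch of the bookkeeping above (no new logic): mutual predictability is the | |
# leave-one-out sum  P_theta(D) = sum_i log P_theta(label_i | D \ {example_i}), | |
# with the context truncated to the adaptive example count. total_log_probs keeps the | |
# running sum, so when a cached term changes the sum is adjusted by (new - old) and | |
# only indices in dirty_indices are re-queried against the LLM. | |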
def check_logical_consistency( | |
self, ex1: LabeledExample, ex2: LabeledExample | |
) -> bool: | |
"""Check if two labels are logically consistent (asymmetry only)""" | |
if ( | |
ex1.response_a_id == ex2.response_b_id | |
and ex1.response_b_id == ex2.response_a_id | |
): | |
return ex1.label != ex2.label | |
return True | |
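# Illustrative check (hypothetical response ids r1, r2): a comparison over (r1, r2) | |
# and another over the reversed pair (r2, r1) must carry different labels to name the | |
# same winner -- (r1, r2)="A" with (r2, r1)="B" both prefer r1 (consistent), whereas | |
# (r1, r2)="A" with (r2, r1)="A" prefer different responses (inconsistent). | |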
def find_inconsistent_pairs(self, dimension: str) -> List[Tuple[int, int]]: | |
"""Find all inconsistent pairs efficiently""" | |
examples = self.labeled_data[dimension] | |
inconsistent = [] | |
pair_groups = defaultdict(list) | |
for i, ex in enumerate(examples): | |
pair_key = ex.get_pair_key() | |
pair_groups[pair_key].append((i, ex)) | |
for pair_key, group in pair_groups.items(): | |
for i, (idx1, ex1) in enumerate(group): | |
for idx2, ex2 in group[i + 1 :]: | |
if not self.check_logical_consistency(ex1, ex2): | |
inconsistent.append((idx1, idx2)) | |
return inconsistent | |
def compute_inconsistency_score(self, dimension: str) -> float: | |
"""Compute total inconsistency I(D)""" | |
return float(len(self.find_inconsistent_pairs(dimension))) | |
def compute_scoring_function(self, dimension: str, pbar) -> float: | |
"""Compute U(D) = α * P_θ(D) - I(D)""" | |
mutual_pred = self.compute_mutual_predictability(dimension, pbar) | |
inconsistency = self.compute_inconsistency_score(dimension) | |
return self.config.alpha * mutual_pred - inconsistency | |
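# Worked example with illustrative numbers: for alpha=50, P_theta(D) = -3.2 nats and | |
# 2 inconsistent pairs, U(D) = 50 * (-3.2) - 2.0 = -162.0. Log-probs are negative, so | |
# a less negative P_theta(D) and fewer inconsistencies both raise U(D). | |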
def fix_inconsistencies(self, dimension: str): | |
"""Fix inconsistencies using Algorithm 2 from the paper""" | |
max_iterations = self.config.fix_inconsistencies_max_iterations | |
examples = self.labeled_data[dimension] | |
for iteration in range(max_iterations): | |
inconsistent_pairs = self.find_inconsistent_pairs(dimension) | |
if not inconsistent_pairs: return | |
i, j = random.choice(inconsistent_pairs) | |
if i >= len(examples) or j >= len(examples): continue | |
best_score = float("-inf") | |
best_labels = None | |
for label_i in ["A", "B"]: | |
for label_j in ["A", "B"]: | |
old_label_i, old_label_j = examples[i].label, examples[j].label | |
examples[i].label = label_i | |
examples[j].label = label_j | |
if self.check_logical_consistency(examples[i], examples[j]): | |
score = self.compute_scoring_function(dimension, pbar=None) | |
if score > best_score: | |
best_score = score | |
best_labels = (label_i, label_j) | |
examples[i].label = old_label_i | |
examples[j].label = old_label_j | |
current_score = self.compute_scoring_function(dimension, pbar=None) | |
if best_labels and best_score > current_score: | |
old_label_i, old_label_j = examples[i].label, examples[j].label | |
examples[i].label, examples[j].label = best_labels | |
if old_label_i != examples[i].label: self.dirty_indices[dimension].add(i) | |
if old_label_j != examples[j].label: self.dirty_indices[dimension].add(j) | |
def update_inconsistency_degrees(self, dimension: str): | |
"""Update inconsistency degrees for weighted sampling""" | |
self.inconsistency_degree[dimension].clear() | |
examples = self.labeled_data[dimension] | |
for i, ex in enumerate(examples): | |
pair_key = ex.get_pair_key() | |
if pair_key not in self.inconsistency_degree[dimension]: self.inconsistency_degree[dimension][pair_key] = 0 | |
for j, other in enumerate(examples): | |
if i != j: | |
other_key = other.get_pair_key() | |
if pair_key == other_key and not self.check_logical_consistency(ex, other): | |
self.inconsistency_degree[dimension][pair_key] += 1 | |
def sample_next_pair(self, available_pairs: Iterator, dimension: str) -> Tuple: | |
# For simplicity with an iterator, we just take the next one. | |
# Weighted sampling is more complex with a pure iterator. | |
return next(available_pairs) | |
def batch_generate_initial_labels(self, pairs: List[Tuple], dimension: str, unified_mode: bool): | |
"""Generate initial labels in parallel.""" | |
def generate_single(pair_data): | |
prompt_id, resp_a, resp_b = pair_data | |
position_swapped = random.choice([True, False]) | |
prompt_context = self.cache.get_context(prompt_id) | |
if not prompt_context: return {"success": False, "error": f"Missing context for {prompt_id}"} | |
# For initialization, we use zero-shot to avoid context length issues | |
if unified_mode: | |
prompt = self.create_unified_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, position_swapped) | |
else: | |
prompt = self.create_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, dimension, position_swapped) | |
try: | |
raw_label = self.llm.generate_label(prompt, dimension, max_tokens=2, temperature=0.0) | |
if raw_label is None: return {"success": False, "skipped": True, "error": "Model indecisive"} | |
label = "B" if raw_label == "A" and position_swapped else "A" if raw_label == "B" and position_swapped else raw_label | |
example = LabeledExample(prompt_id=prompt_id, response_a_id=resp_a.response_id, response_b_id=resp_b.response_id, response_a_text=resp_a.response_text, response_b_text=resp_b.response_text, label_dimension=dimension, label=label, position_swapped=position_swapped, timestamp=time.time()) | |
return {"success": True, "example": example} | |
except self.llm.ContextLengthExceededError: | |
return {"success": False, "skipped": True, "error": "Context too long for zero-shot"} | |
except Exception as e: | |
return {"success": False, "error": str(e)} | |
results = [] | |
futures = {self.executor.submit(generate_single, pair): pair for pair in pairs} | |
pbar = tqdm(as_completed(futures), total=len(futures), desc=" Initializing", position=2, leave=False) | |
for future in pbar: results.append(future.result()) | |
pbar.close() | |
return results | |
def label_dimension(self, response_pairs: List[Tuple[str, GeneratedResponse, GeneratedResponse]], dimension: str, save_callback, unified_mode: bool): | |
logger.info(f"Labeling dimension: {dimension} with {len(response_pairs)} pairs" + (" (unified mode)" if unified_mode else "")) | |
self.generation_failures[dimension] = 0 | |
self.skipped_pairs[dimension] = 0 | |
self.consecutive_failures[dimension] = 0 | |
# --- Initialization Phase --- | |
initial_pairs_to_try = response_pairs[:self.config.initial_k * 5] | |
batch_results = self.batch_generate_initial_labels(initial_pairs_to_try, dimension, unified_mode) | |
for result in batch_results: | |
if len(self.labeled_data[dimension]) >= self.config.initial_k: | |
break | |
if result["success"]: | |
example = result["example"] | |
self.labeled_data[dimension].append(example) | |
self.pair_tracker[dimension].add(example.get_pair_key()) | |
elif result.get("skipped"): | |
self.skipped_pairs[dimension] += 1 | |
else: | |
self.generation_failures[dimension] += 1 | |
if len(self.labeled_data[dimension]) < 2: | |
logger.error(f"Could not initialize enough examples for {dimension}. Got {len(self.labeled_data[dimension])}. Stopping.") | |
if self.config.save_partial_on_error and self.labeled_data[dimension]: | |
save_callback(dimension, self.labeled_data[dimension]) | |
return | |
logger.info(f"Successfully initialized with {len(self.labeled_data[dimension])} labels.") | |
self._mark_all_dirty(dimension) | |
self.fix_inconsistencies(dimension) | |
# --- Main MCMC Loop --- | |
# Create an iterator for all pairs that were NOT used in initialization | |
main_pair_iterator = iter(p for p in response_pairs if tuple(sorted([p[1].response_id, p[2].response_id])) not in self.pair_tracker[dimension]) | |
pbar = tqdm(range(self.config.max_iterations), desc=f"{dimension[:25]:25s}", position=1, leave=False) | |
last_save_count = len(self.labeled_data[dimension]) | |
for iteration in pbar: | |
if len(self.labeled_data[dimension]) >= self.config.max_labels_per_dimension: | |
logger.info(f"Reached max_labels_per_dimension ({self.config.max_labels_per_dimension}). Stopping.") | |
break | |
temperature = max(self.config.final_temp, self.config.initial_temp / (1 + self.config.beta * math.log(iteration + 1))) | |
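# Cooling schedule, written out: T(i) = max(final_temp, initial_temp / (1 + beta * ln(i + 1))). | |
# With the illustrative values from main() below (initial_temp=5.0, beta=0.95), | |
# T(0) = 5.0 and T(100) ~= 5.0 / (1 + 0.95 * ln(101)) ~= 0.93 -- a slow logarithmic | |
# anneal floored at final_temp. | |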
pbar.set_postfix_str(f"Labels: {len(self.labeled_data[dimension])} | Skip: {self.skipped_pairs[dimension]:3d} | Ex: {self.adaptive_example_count[dimension]:2d} | T:{temperature:.2f}") | |
if self.consecutive_failures[dimension] >= self.config.max_consecutive_failures: | |
logger.warning(f"Too many consecutive failures ({self.consecutive_failures[dimension]}) for {dimension}. Stopping.") | |
break | |
try: | |
# We don't use weighted sampling here to keep it simple and ensure we process all pairs if possible | |
prompt_id, resp_a, resp_b = next(main_pair_iterator) | |
except StopIteration: | |
pbar.close() | |
logger.info("Exhausted all pairs") | |
break | |
current_score = self.compute_scoring_function(dimension, pbar) | |
position_swapped = random.choice([True, False]) | |
new_example = LabeledExample(prompt_id=prompt_id, response_a_id=resp_a.response_id, response_b_id=resp_b.response_id, response_a_text=resp_a.response_text, response_b_text=resp_b.response_text, label_dimension=dimension, label="", position_swapped=position_swapped, timestamp=time.time()) | |
prompt_context = self.cache.get_context(prompt_id) | |
if not prompt_context: continue | |
raw_label = None | |
num_examples_to_try = self.adaptive_example_count[dimension] | |
while num_examples_to_try >= 0: | |
try: | |
context_examples = self.labeled_data[dimension][-num_examples_to_try:] | |
if unified_mode: | |
prompt = self.create_unified_context_prompt(context_examples, new_example, prompt_context, position_swapped) if num_examples_to_try > 0 else self.create_unified_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, position_swapped) | |
else: | |
prompt = self.create_context_prompt(context_examples, new_example, prompt_context, position_swapped) if num_examples_to_try > 0 else self.create_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, dimension, position_swapped) | |
raw_label = self.llm.generate_label(prompt, dimension, max_tokens=2, temperature=0.0) | |
self.adaptive_example_count[dimension] = num_examples_to_try | |
break | |
except self.llm.ContextLengthExceededError: | |
logger.debug(f"Context too long with {num_examples_to_try} examples. Reducing.") | |
num_examples_to_try -= 1 | |
except Exception as e: | |
logger.error(f"Error generating label: {e}") | |
self.generation_failures[dimension] += 1 | |
raw_label = "ERROR" | |
break | |
if raw_label == "ERROR": continue | |
if raw_label is None: | |
self.skipped_pairs[dimension] += 1 | |
self.consecutive_failures[dimension] += 1 | |
if num_examples_to_try < 0: | |
logger.warning(f"Pair {resp_a.response_id[:4]}-{resp_b.response_id[:4]} failed even with 0 examples.") | |
continue | |
self.consecutive_failures[dimension] = 0 | |
proposed_label = "B" if raw_label == "A" and position_swapped else "A" if raw_label == "B" and position_swapped else raw_label | |
new_example.label = proposed_label | |
self.labeled_data[dimension].append(new_example) | |
self.pair_tracker[dimension].add(new_example.get_pair_key()) | |
new_idx = len(self.labeled_data[dimension]) - 1 | |
self._mark_dirty_except(dimension, new_idx) | |
self.fix_inconsistencies(dimension) | |
new_score = self.compute_scoring_function(dimension, pbar) | |
delta = new_score - current_score | |
acceptance_prob = 1.0 if delta > 0 else math.exp(delta / temperature) | |
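# Metropolis-style acceptance (illustrative numbers): improvements (delta > 0) are | |
# always kept; a worse proposal with delta = -2 survives with probability | |
# exp(-2 / 1.0) ~= 0.14 at T = 1.0 but only exp(-2 / 0.01) ~= 0 at the final | |
# temperature, so cold late-run iterations keep essentially only improvements. | |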
if random.random() > acceptance_prob: | |
self.labeled_data[dimension].pop() | |
self.pair_tracker[dimension].remove(new_example.get_pair_key()) | |
self._mark_all_dirty(dimension) | |
if save_callback and len(self.labeled_data[dimension]) - last_save_count >= self.config.save_interval: | |
save_callback(dimension, self.labeled_data[dimension]) | |
last_save_count = len(self.labeled_data[dimension]) | |
pbar.close() | |
if save_callback and self.labeled_data[dimension]: | |
save_callback(dimension, self.labeled_data[dimension]) | |
logger.info(f"Completed {dimension}: {len(self.labeled_data[dimension])} labels, {self.skipped_pairs[dimension]} skipped, {self.generation_failures[dimension]} failures.") | |
def create_response_pairs( | |
all_responses: Dict[str, List[GeneratedResponse]], | |
) -> List[Tuple[str, GeneratedResponse, GeneratedResponse]]: | |
"""Create all pairs of responses for comparison""" | |
pairs = [] | |
for prompt_id, responses in all_responses.items(): | |
for i in range(len(responses)): | |
for j in range(i + 1, len(responses)): | |
pairs.append((prompt_id, responses[i], responses[j])) | |
random.shuffle(pairs) | |
logger.info(f"Created {len(pairs)} response pairs for comparison") | |
return pairs | |
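# Pair-count sketch (arithmetic only): a prompt with M responses yields | |
# C(M, 2) = M * (M - 1) / 2 unordered pairs, so responses_per_prompt=25 in the config | |
# below produces 25 * 24 / 2 = 300 candidate comparisons per prompt before shuffling. | |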
def save_labeled_data( | |
labeled_examples: List[LabeledExample], | |
response_cache: ResponseCache, | |
dimension: str, | |
config: ICMConfig, | |
): | |
"""Save labeled preference data""" | |
if not labeled_examples: | |
logger.warning(f"No labeled examples for dimension {dimension}") | |
return | |
dim_output = os.path.join(config.output_dir, f"preference_{dimension}") | |
os.makedirs(dim_output, exist_ok=True) | |
preference_data = [] | |
for ex in labeled_examples: | |
context = response_cache.get_context(ex.prompt_id) | |
if not context: | |
logger.warning(f"Missing context for prompt {ex.prompt_id}") | |
continue | |
if ex.label == "A": | |
chosen = ex.response_a_text | |
rejected = ex.response_b_text | |
chosen_id = ex.response_a_id | |
rejected_id = ex.response_b_id | |
else: | |
chosen = ex.response_b_text | |
rejected = ex.response_a_text | |
chosen_id = ex.response_b_id | |
rejected_id = ex.response_a_id | |
preference_data.append( | |
{ | |
"prompt_id": ex.prompt_id, | |
"chosen": chosen, | |
"rejected": rejected, | |
"chosen_id": chosen_id, | |
"rejected_id": rejected_id, | |
"dimension": dimension, | |
"timestamp": ex.timestamp, | |
} | |
) | |
random.shuffle(preference_data) | |
split_idx = int(len(preference_data) * config.train_split) | |
train_data = preference_data[:split_idx] | |
test_data = preference_data[split_idx:] | |
with open(os.path.join(dim_output, "train.jsonl"), "w") as f: | |
for item in train_data: | |
f.write(json.dumps(item) + "\n") | |
with open(os.path.join(dim_output, "test.jsonl"), "w") as f: | |
for item in test_data: | |
f.write(json.dumps(item) + "\n") | |
with open(os.path.join(dim_output, "all_comparisons.jsonl"), "w") as f: | |
for ex in labeled_examples: | |
comparison = { | |
"prompt_id": ex.prompt_id, | |
"response_a": ex.response_a_text, | |
"response_b": ex.response_b_text, | |
"label": ex.label, | |
"dimension": ex.label_dimension, | |
"response_a_id": ex.response_a_id, | |
"response_b_id": ex.response_b_id, | |
"timestamp": ex.timestamp, | |
} | |
f.write(json.dumps(comparison) + "\n") | |
with open(os.path.join(dim_output, "prompts.jsonl"), "w") as f: | |
unique_prompts = {ex.prompt_id for ex in labeled_examples} | |
for prompt_id in unique_prompts: | |
context = response_cache.get_context(prompt_id) | |
if context: | |
prompt_data = { | |
"prompt_id": prompt_id, | |
"prompt_text": context.prompt_text, | |
"few_shot_examples": context.few_shot_examples, | |
"target_conversation": context.target_conversation, | |
"original_response": context.original_response, | |
} | |
f.write(json.dumps(prompt_data) + "\n") | |
metadata = { | |
"dimension": dimension, | |
"train_samples": len(train_data), | |
"test_samples": len(test_data), | |
"total_samples": len(preference_data), | |
"rubric": DIMENSION_RUBRICS.get(dimension, ""), | |
"timestamp": time.time(), | |
} | |
with open(os.path.join(dim_output, "metadata.json"), "w") as f: | |
json.dump(metadata, f, indent=2) | |
logger.info( | |
f"Saved preference dataset for {dimension}: {len(train_data)} train, {len(test_data)} test" | |
) | |
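# Illustrative shape of one train.jsonl / test.jsonl record written above (values are | |
# placeholders; keys mirror the dict built in save_labeled_data): | |
# {"prompt_id": "<uuid>", "chosen": "<preferred response>", "rejected": "<other response>", | |
#  "chosen_id": "<uuid>", "rejected_id": "<uuid>", "dimension": "unified_quality", | |
#  "timestamp": 1719443200.0} | |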
# Global references for signal handler | |
_global_response_cache = None | |
_global_response_generator = None | |
_global_labeler = None | |
_global_executor = None | |
def signal_handler(sig, frame): | |
"""Handle Ctrl+C gracefully""" | |
logger.info("\nInterrupted! Cleaning up...") | |
if _global_labeler and hasattr(_global_labeler, "labeled_data"): | |
total_labels = sum( | |
len(labels) for labels in _global_labeler.labeled_data.values() | |
) | |
if total_labels > 0: | |
logger.info(f"Saved {total_labels} labels across dimensions:") | |
for dim, labels in _global_labeler.labeled_data.items(): | |
if labels: | |
logger.info(f" • {dim}: {len(labels)} labels") | |
if _global_response_cache: | |
logger.info("Saving response cache...") | |
_global_response_cache.save_final() | |
if _global_response_generator: | |
_global_response_generator.shutdown() | |
if _global_labeler: | |
_global_labeler.shutdown() | |
if _global_executor: | |
_global_executor.shutdown(wait=False, cancel_futures=True) | |
sys.exit(0) | |
def cleanup_resources(): | |
"""Cleanup function for atexit""" | |
if _global_response_cache: | |
try: _global_response_cache.save_final() | |
except Exception: pass | |
if _global_response_generator: | |
try: _global_response_generator.shutdown() | |
except Exception: pass | |
if _global_labeler: | |
try: _global_labeler.shutdown() | |
except Exception: pass | |
if _global_executor: | |
try: _global_executor.shutdown(wait=False, cancel_futures=True) | |
except Exception: pass | |
def label_single_dimension(dimension: str, pairs: List[Tuple], labeler, save_callback, config): | |
"""Label a single dimension with better error recovery""" | |
try: | |
labeler.label_dimension(pairs, dimension, save_callback=save_callback, unified_mode=config.unified_mode) | |
labels = labeler.labeled_data.get(dimension, []) | |
skipped = labeler.skipped_pairs.get(dimension, 0) | |
failures = labeler.generation_failures.get(dimension, 0) | |
logger.info(f"Completed {dimension}: {len(labels)} labeled, {skipped} skipped, {failures} errors") | |
return True, len(labels) | |
except Exception as e: | |
logger.exception(f"Critical error processing {dimension}: {e}") | |
partial_labels = labeler.labeled_data.get(dimension, []) | |
if config.save_partial_on_error and partial_labels: | |
logger.info(f"Saving {len(partial_labels)} partial results for {dimension}") | |
save_callback(dimension, partial_labels) | |
return False, len(partial_labels) | |
def main(): | |
global _global_response_cache, _global_response_generator, _global_labeler, _global_executor | |
atexit.register(cleanup_resources) | |
signal.signal(signal.SIGINT, signal_handler) | |
# Configuration for a small test run | |
config = ICMConfig( | |
# Data source | |
dataset_name="allenai/tulu-3-sft-mixture", | |
api_url="http://localhost:8093", | |
# Dataset sampling | |
num_prompts=5000, | |
n_shot_examples=3, | |
responses_per_prompt=25, | |
shuffle_buffer_size=100_000, | |
# Response generation parameters | |
response_temperature=1.0, | |
response_top_p=1.0, | |
response_top_k=0, | |
response_frequency_penalty=0.0, | |
response_presence_penalty=0.0, | |
response_repeat_penalty=1.0, | |
response_max_tokens=4096, | |
# ICM core algorithm | |
initial_k=8, | |
alpha=50.0, | |
initial_temp=5.0, # RELAXED: initial MCMC temperature; higher values accept more early proposals | |
final_temp=0.01, | |
beta=0.95, # RELAXED: Cools slightly faster | |
max_iterations=10000, # Realistic ceiling for the test run | |
max_labels_per_dimension=25000, | |
fix_inconsistencies_max_iterations=100, | |
inconsistency_weight_multiplier=100.0, | |
# Dimension configuration | |
dimensions=None, | |
dimension_subset="minimal", | |
unified_mode=True, | |
unified_dimension_name="unified_quality", | |
# Model context management | |
max_context_tokens=32768, | |
response_reserve_tokens=4096, | |
token_estimation_ratio=0.75, | |
max_context_examples=25, | |
# Pair and data management | |
max_pairs_per_dimension=1500000, | |
max_generation_failures=500, | |
# Output configuration | |
output_dir="preference_datasets", | |
train_split=0.9, | |
cache_file="response_cache.jsonl", | |
save_interval=100, | |
generated_dataset_path="generated_responses_dataset", | |
# Execution configuration | |
max_workers=32, | |
random_seed=43, | |
# Network and retry configuration | |
retry_max_attempts=3, | |
retry_base_delay=1.0, | |
timeout_short=15.0, | |
timeout_medium=30.0, | |
timeout_long=90.0, | |
timeout_generation=180.0, | |
# Deduplication parameters | |
duplicate_prefix_length=200, | |
duplicate_fuzzy_length=500, | |
duplicate_fuzzy_prefix_min_length=50, | |
# Generation diversity parameters | |
diversity_attempt_threshold=3, | |
diversity_strong_threshold=10, | |
max_generation_attempts=20, | |
generation_oversampling_factor=1.25, | |
# Resilience parameters | |
max_consecutive_failures=50, | |
skip_indecisive_pairs=True, | |
min_labels_per_dimension=100, | |
save_partial_on_error=True, | |
# Misc parameters | |
log_prob_fallback=-100.0, | |
cache_report_interval=300, | |
num_logprobs=20, | |
streaming=False, | |
) | |
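# Scale check derived from the values above (orientation only): 5000 prompts x | |
# C(25, 2) = 300 pairs/prompt = 1,500,000 candidate pairs, matching the | |
# max_pairs_per_dimension budget; in practice max_iterations=10000 and | |
# max_labels_per_dimension=25000 are the tighter bounds on a single dimension. | |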
shared_executor = ThreadPoolExecutor(max_workers=config.max_workers) | |
_global_executor = shared_executor | |
llm = LLMInterface(config) | |
response_cache = ResponseCache(config) | |
_global_response_cache = response_cache | |
logger.info(f"Loading dataset: {config.dataset_name}") | |
dataset = load_dataset(config.dataset_name, split="train", streaming=config.streaming) | |
if config.streaming: | |
iterable = dataset.shuffle(buffer_size=config.shuffle_buffer_size, seed=config.random_seed).take(config.num_prompts * (config.n_shot_examples + 5)) | |
else: | |
iterable = iter(dataset.shuffle(seed=config.random_seed)) | |
response_generator = ResponseGenerator(llm, config, response_cache, shared_executor) | |
_global_response_generator = response_generator | |
logger.info("Preparing prompt contexts...") | |
contexts = response_generator.prepare_prompt_contexts(iterable, config.num_prompts) | |
logger.info("Generating responses...") | |
all_responses = response_generator.generate_all_responses(contexts) | |
# --- NEW: Save generated responses to a dataset --- | |
if all_responses: | |
save_responses_to_hf_dataset( | |
all_responses, contexts, config.generated_dataset_path | |
) | |
# --------------------------------------------------- | |
analyze_response_diversity(all_responses, config) | |
if config.unified_mode: | |
dimensions_to_label = [config.unified_dimension_name] | |
elif config.dimensions: | |
dimensions_to_label = config.dimensions | |
else: | |
subset_map = { | |
"all": list(VALID_LABELS), | |
"minimal": ["harmlessness", "factual_accuracy", "logical_coherence", "task_completion", "latent_task_identification", "uncertainty_calibration"], | |
"safety": ["policy_compliance", "harmlessness"], | |
"task": ["task_completion", "constraint_adherence", "latent_task_identification", "pattern_generalization", "information_density", "relevance_focus", "structural_clarity"], | |
"epistemic": ["uncertainty_calibration", "assumption_transparency", "perspective_awareness", "insight_synthesis", "problem_decomposition", "conceptual_synthesis"], | |
} | |
dimensions_to_label = subset_map.get(config.dimension_subset, []) | |
logger.info(f"Will label {len(dimensions_to_label)} dimensions: {dimensions_to_label}") | |
response_pairs = create_response_pairs(all_responses) | |
labeler = ICMLabeler(llm, config, response_cache, shared_executor) | |
_global_labeler = labeler | |
def save_dimension_progress(dimension: str, labeled_examples: List[LabeledExample]): | |
save_labeled_data(labeled_examples, response_cache, dimension, config) | |
overall_pbar = tqdm(dimensions_to_label, desc="Overall Progress", position=0) | |
for dimension in overall_pbar: | |
overall_pbar.set_postfix_str(f"Current: {dimension[:20]}...") | |
print(f"\n{'='*60}\nProcessing dimension: {dimension}\n{'='*60}", flush=True) | |
pairs_copy = response_pairs.copy() | |
success, label_count = label_single_dimension(dimension, pairs_copy, labeler, save_dimension_progress, config) | |
overall_pbar.set_postfix_str(f"{'✓' if success else '✗'} {dimension} ({label_count} labels)") | |
overall_pbar.close() | |
response_cache.save_final() | |
shared_executor.shutdown(wait=True) | |
# --- Final Summary --- | |
print("\n" + "=" * 80 + "\n" + " " * 30 + "LABELING COMPLETE!" + " " * 30 + "\n" + "=" * 80) | |
print(f"\n💾 Output Files:") | |
print(f" • Generated responses dataset: {config.generated_dataset_path}") | |
print(f" • Preference datasets: {config.output_dir}/preference_*/") | |
print(f"\n📋 Per-Dimension Results:") | |
print(f" {'Dimension':<30} {'Labels':>8} {'Skipped':>8} {'Adaptive Ex':>12} {'A':>6} {'B':>6} {'A/B Ratio':>10}") | |
print(f" {'-'*30} {'-'*8} {'-'*8} {'-'*12} {'-'*6} {'-'*6} {'-'*10}") | |
summary_rows = [] | |
for dimension in dimensions_to_label: | |
if dimension in labeler.labeled_data: | |
labels = labeler.labeled_data[dimension] | |
skipped = labeler.skipped_pairs.get(dimension, 0) | |
adaptive_ex = labeler.adaptive_example_count.get(dimension, 'N/A') | |
a_count = sum(1 for l in labels if l.label == "A") | |
b_count = sum(1 for l in labels if l.label == "B") | |
ratio = a_count / b_count if b_count > 0 else float("inf") | |
print(f" {dimension:<30} {len(labels):>8} {skipped:>8} {str(adaptive_ex):>12} {a_count:>6} {b_count:>6} {f'{ratio:.2f}':>10}") | |
summary_rows.append({"dimension": dimension, "total_labels": len(labels), "skipped_pairs": skipped, "final_adaptive_examples": adaptive_ex, "a_count": a_count, "b_count": b_count, "ratio": ratio, "timestamp": time.time()}) | |
with open("run_summary.csv", "w", newline="") as f: | |
writer = csv.DictWriter(f, fieldnames=["dimension", "total_labels", "skipped_pairs", "final_adaptive_examples", "a_count", "b_count", "ratio", "timestamp"]) | |
writer.writeheader() | |
writer.writerows(summary_rows) | |
logger.info("Saved run summary to run_summary.csv") | |
if __name__ == "__main__": | |
main() |