
@ericflo
Last active June 26, 2025 23:27
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "datasets",
# "numpy",
# "requests",
# "tqdm",
# "pyarrow"
# ]
# ///
"""
label_finetome.py
Internal Coherence Maximization (ICM) with Self-Play Response Generation
Combines:
- N-shot prompting to generate M diverse responses per prompt
- Saves all generated responses to a Hugging Face Dataset before labeling.
- Full ICM algorithm (MCMC acceptance/rejection, mutual predictability, inconsistency repair)
- Adaptive selection of which pairs to label based on potential information gain
- Aggressive duplicate prevention at multiple levels (exact, prefix, fuzzy matching)
- Robust "Fail, Reduce, Retry" logic for handling API context length limits.
VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen3-8B-Base --port 8093 --disable-log-requests --gpu-memory-utilization 0.8
NCCL_P2P_DISABLE=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen3-8B-Base --port 8093 --disable-log-requests --gpu-memory-utilization 0.8 --tensor-parallel-size 8
"""
import json
import random
import re
import math
import numpy as np
from typing import List, Dict, Tuple, Optional, Set, Iterator, Any
from dataclasses import dataclass, field
from collections import defaultdict
import requests
from datasets import load_dataset, Dataset, Features, Value, Sequence
import logging
from tqdm import tqdm
import time
import hashlib
import signal
import sys
import csv
from functools import wraps
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import os
import atexit
import uuid
import threading
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Multi-dimensional label ontology for comprehensive preference modeling
LABEL_ONTOLOGY = [
# AXIS 0: Safety Gates
["policy_compliance", "harmlessness"],
# AXIS 1: Reality Fidelity
["factual_accuracy", "evidence_grounding"],
["logical_coherence", "causal_soundness", "internal_consistency"],
# AXIS 2: Task Intelligence
["task_completion", "constraint_adherence"],
["latent_task_identification", "pattern_generalization"],
# AXIS 3: Information Quality
["information_density", "relevance_focus", "structural_clarity"],
# AXIS 4: Epistemic Honesty
["uncertainty_calibration", "assumption_transparency", "perspective_awareness"],
# AXIS 5: Cognitive Depth
["insight_synthesis", "problem_decomposition", "conceptual_synthesis"],
]
# Build label validation set
VALID_LABELS = set()
LABEL_TO_AXIS = {}
for axis_id, labels in enumerate(LABEL_ONTOLOGY):
for label in labels:
VALID_LABELS.add(label)
LABEL_TO_AXIS[label] = axis_id
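# Note: the axis id stored in LABEL_TO_AXIS is the index of the sub-list within LABEL_ONTOLOGY
# (0-7), not the "AXIS 0-5" numbers in the comments above, which group several sub-lists under
# one heading (e.g. "logical_coherence" maps to axis 2 and "task_completion" to axis 3).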
# Dimension rubrics
DIMENSION_RUBRICS = {
"policy_compliance": "adherence to content policies without violations",
"harmlessness": "absence of potential for real-world harm",
"factual_accuracy": "correctness of factual claims",
"evidence_grounding": "support by evidence or sources",
"logical_coherence": "reasoning without fallacies or contradictions",
"causal_soundness": "correct cause-effect relationships",
"internal_consistency": "absence of self-contradictions",
"task_completion": "fully addressing all parts of a request",
"constraint_adherence": "following specified constraints and formats",
"latent_task_identification": "identifying underlying goals or patterns",
"pattern_generalization": "applying patterns to new cases",
"information_density": "high value per word",
"relevance_focus": "staying on topic without tangents",
"structural_clarity": "clear organization and presentation",
"uncertainty_calibration": "appropriate expression of confidence",
"assumption_transparency": "clarity about underlying assumptions",
"perspective_awareness": "recognition of different viewpoints",
"insight_synthesis": "deep, integrative understanding",
"problem_decomposition": "breaking down complex problems clearly",
"conceptual_synthesis": "combining ideas into coherent wholes",
}
UNIFIED_ONTOLOGY_PROMPT = """You are an expert evaluator assessing AI assistant responses across multiple dimensions of quality. Your evaluation should consider the following comprehensive framework:
## FOUNDATIONAL REQUIREMENTS (Non-Negotiable)
**Safety & Compliance**
- The response must not violate content policies (no hate speech, PII exposure, or illicit content)
- The response must not provide advice that could cause real-world harm
## CORE QUALITY DIMENSIONS
### 1. REALITY FIDELITY - Grounding in Truth
Evaluate how well the response reflects accurate knowledge and sound reasoning:
**Factual Accuracy**: Are all factual claims correct and free from hallucination?
**Evidence Grounding**: Are claims properly supported by the context or cited sources?
**Logical Coherence**: Is the reasoning free from fallacies and contradictions?
**Causal Soundness**: Are cause-effect relationships correctly represented?
**Internal Consistency**: Is the response free from self-contradictions?
### 2. TASK INTELLIGENCE - Understanding & Execution
Assess how well the response understands and fulfills the user's true intent:
**Task Completion**: Does it fully address all explicit and implicit parts of the request?
**Constraint Adherence**: Does it strictly follow formatting, length, and style requirements?
**Latent Task Identification**: Does it correctly infer the underlying goal beyond literal instructions?
**Pattern Generalization**: Can it apply identified patterns to new cases?
### 3. INFORMATION QUALITY - Clarity & Efficiency
Consider how effectively information is communicated:
**Information Density**: Is the response concise while preserving necessary detail?
**Relevance Focus**: Does it stay on topic without unnecessary tangents?
**Structural Clarity**: Is it well-organized with clear formatting and flow?
### 4. EPISTEMIC AWARENESS - Intellectual Honesty
Evaluate the response's self-awareness and calibration:
**Uncertainty Calibration**: Does expressed confidence match actual accuracy?
**Assumption Transparency**: Are underlying assumptions clearly stated?
**Perspective Acknowledgment**: Does it recognize when topics have multiple valid viewpoints?
### 5. COGNITIVE DEPTH - Sophistication of Thought
Assess the depth and quality of reasoning:
**Insight Synthesis**: Does it provide non-obvious connections or deeper understanding?
**Problem Decomposition**: Are complex problems broken down into clear, logical steps?
**Conceptual Synthesis**: Does it skillfully combine disparate ideas into coherent insights?
## EVALUATION PRINCIPLES
When comparing responses, consider:
- A response that is factually accurate but incomplete is generally better than one that is complete but contains errors
- Clear, direct communication is preferred over verbose or overly complex explanations
- Responses that acknowledge their limitations are preferred over those that confidently assert uncertain information
- The ability to identify and address the user's underlying need is more valuable than literal instruction following when they conflict
## HOLISTIC ASSESSMENT
While each dimension is important, the best response is one that:
1. First and foremost, avoids harm and maintains accuracy
2. Genuinely understands and addresses the user's needs
3. Communicates clearly and efficiently
4. Demonstrates appropriate confidence and self-awareness
5. Provides insights that elevate the user's understanding
## IMPORTANT: You must choose ONE response as better overall
Even if the responses seem equal in quality, you must select the one that performs marginally better across all dimensions. If you genuinely cannot distinguish between them, pick the one that is even slightly more concise or clear. You MUST output ONLY "A" or "B" - no other text, explanations, or qualifiers."""
@dataclass
class ICMConfig:
"""Configuration for ICM preference learning"""
# Data source
dataset_name: str
api_url: str
# Dataset sampling
num_prompts: int
n_shot_examples: int
responses_per_prompt: int
shuffle_buffer_size: int
# Response generation parameters
response_temperature: float
response_top_p: float
response_top_k: int
response_frequency_penalty: float
response_presence_penalty: float
response_repeat_penalty: float # For llama.cpp
response_max_tokens: int
# ICM core algorithm
initial_k: int
alpha: float
initial_temp: float
final_temp: float
beta: float
max_iterations: int
max_labels_per_dimension: int
fix_inconsistencies_max_iterations: int
inconsistency_weight_multiplier: float
# Dimension configuration
dimensions: Optional[List[str]]
dimension_subset: str # "all", "minimal", "safety", "task", "epistemic"
unified_mode: bool
unified_dimension_name: str
# Model context management
max_context_tokens: int
response_reserve_tokens: int
token_estimation_ratio: float
max_context_examples: int
# Pair and data management
max_pairs_per_dimension: int
max_generation_failures: int
# Output configuration
output_dir: str
train_split: float
cache_file: str
save_interval: int
generated_dataset_path: str
# Execution configuration
max_workers: int
random_seed: Optional[int]
# Network and retry configuration
retry_max_attempts: int
retry_base_delay: float
timeout_short: float
timeout_medium: float
timeout_long: float
timeout_generation: float
# Deduplication parameters
duplicate_prefix_length: int
duplicate_fuzzy_length: int
duplicate_fuzzy_prefix_min_length: int
# Generation diversity parameters
diversity_attempt_threshold: int
diversity_strong_threshold: int
max_generation_attempts: int
generation_oversampling_factor: float
# Resilience parameters
max_consecutive_failures: int
skip_indecisive_pairs: bool
min_labels_per_dimension: int
save_partial_on_error: bool
# Misc parameters
log_prob_fallback: float
cache_report_interval: int
num_logprobs: int
streaming: bool
def __post_init__(self):
if self.random_seed is not None:
random.seed(self.random_seed)
np.random.seed(self.random_seed)
logger.info(f"Random seed set to: {self.random_seed}")
@dataclass
class GeneratedResponse:
"""A response generated for a prompt"""
prompt_id: str
response_text: str
response_id: str
generation_params: Dict[str, Any]
timestamp: float
def __post_init__(self):
# Generate UUID if not provided
if not self.response_id:
self.response_id = str(uuid.uuid4())[:8]
@dataclass
class PromptContext:
"""Context for generating responses"""
prompt_id: str
few_shot_examples: List[Dict]
target_conversation: List[Dict]
original_response: str
source_id: int
timestamp: float
@property
def prompt_text(self) -> str:
"""Get the formatted prompt text"""
return create_n_shot_prompt(self.few_shot_examples, self.target_conversation)
@dataclass
class LabeledExample:
"""Pairwise comparison example"""
prompt_id: str
response_a_id: str
response_b_id: str
response_a_text: str
response_b_text: str
label_dimension: str
label: str # "A" or "B"
position_swapped: bool
timestamp: float
def get_pair_key(self) -> Tuple[str, str]:
"""Get unique key for this response pair"""
return tuple(sorted([self.response_a_id, self.response_b_id]))
class ResponseCache:
"""Cache for generated responses with deduplication"""
def __init__(self, config: ICMConfig):
self.cache_file = config.cache_file
self.responses_by_prompt: Dict[str, List[GeneratedResponse]] = defaultdict(list)
self.responses_by_id: Dict[str, GeneratedResponse] = {}
self.prompt_contexts: Dict[str, PromptContext] = {}
self.response_hashes: Set[str] = set() # For deduplication
self._load_cache()
def _hash_response(self, prompt_id: str, response_text: str) -> str:
"""Create hash for response deduplication"""
normalized_text = response_text.strip().lower()
normalized_text = re.sub(r"\s+", " ", normalized_text)
content = f"{prompt_id}:{normalized_text}"
return hashlib.sha256(content.encode()).hexdigest()[:16]
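# e.g. "  Hello,   World! " and "hello, world!" normalize to the same string and so share
# the same 16-hex-character hash for a given prompt_id.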
def _load_cache(self):
"""Load existing cache from disk"""
try:
with open(self.cache_file, "r") as f:
for line in f:
entry = json.loads(line)
if entry["type"] == "response":
response = GeneratedResponse(
prompt_id=entry["prompt_id"],
response_text=entry["response_text"],
response_id=entry["response_id"],
generation_params=entry.get("generation_params", {}),
timestamp=entry.get("timestamp", time.time()),
)
self.responses_by_prompt[entry["prompt_id"]].append(response)
self.responses_by_id[response.response_id] = response
resp_hash = self._hash_response(
response.prompt_id, response.response_text
)
self.response_hashes.add(resp_hash)
elif entry["type"] == "context":
context = PromptContext(
prompt_id=entry["prompt_id"],
few_shot_examples=entry["few_shot_examples"],
target_conversation=entry["target_conversation"],
original_response=entry["original_response"],
source_id=entry["source_id"],
timestamp=entry.get("timestamp", time.time()),
)
self.prompt_contexts[entry["prompt_id"]] = context
logger.info(
f"Loaded {len(self.responses_by_prompt)} prompts with {len(self.responses_by_id)} responses from cache"
)
except FileNotFoundError:
logger.info("No existing response cache found, starting fresh")
def add_response(
self, prompt_id: str, response_text: str, generation_params: Dict[str, Any]
) -> Optional[GeneratedResponse]:
"""Add response with deduplication check"""
resp_hash = self._hash_response(prompt_id, response_text)
if resp_hash in self.response_hashes:
logger.debug(f"Skipping duplicate response for prompt {prompt_id[:8]}...")
return None
response = GeneratedResponse(
prompt_id=prompt_id,
response_text=response_text,
response_id=str(uuid.uuid4())[:8],
generation_params=generation_params,
timestamp=time.time(),
)
self.responses_by_prompt[prompt_id].append(response)
self.responses_by_id[response.response_id] = response
self.response_hashes.add(resp_hash)
entry = {
"type": "response",
"prompt_id": response.prompt_id,
"response_text": response.response_text,
"response_id": response.response_id,
"generation_params": response.generation_params,
"timestamp": response.timestamp,
}
with open(self.cache_file, "a", buffering=1) as f:
f.write(json.dumps(entry) + "\n")
return response
def add_context(self, context: PromptContext):
"""Add prompt context to cache"""
self.prompt_contexts[context.prompt_id] = context
entry = {
"type": "context",
"prompt_id": context.prompt_id,
"few_shot_examples": context.few_shot_examples,
"target_conversation": context.target_conversation,
"original_response": context.original_response,
"source_id": context.source_id,
"timestamp": context.timestamp,
}
with open(self.cache_file, "a", buffering=1) as f:
f.write(json.dumps(entry) + "\n")
def get_response_by_id(self, response_id: str) -> Optional[GeneratedResponse]:
"""Get response by ID"""
return self.responses_by_id.get(response_id)
def get_responses_for_prompt(self, prompt_id: str) -> List[GeneratedResponse]:
"""Get all responses for a prompt"""
return self.responses_by_prompt.get(prompt_id, [])
def get_context(self, prompt_id: str) -> Optional[PromptContext]:
"""Get context for a prompt"""
return self.prompt_contexts.get(prompt_id)
def save_final(self):
"""Save complete cache (already incrementally saved)"""
logger.info(f"Cache already saved incrementally to {self.cache_file}")
def retry_on_network_error(config: ICMConfig):
"""Retry decorator for network errors only"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(config.retry_max_attempts):
try:
return func(*args, **kwargs)
except requests.exceptions.RequestException as e:
# Do not retry on ContextLengthExceededError, it must be handled by caller
if isinstance(e, LLMInterface.ContextLengthExceededError):
raise
last_exception = e
if attempt < config.retry_max_attempts - 1:
delay = config.retry_base_delay * (2**attempt)
logger.warning(
f"Network error on attempt {attempt + 1}: {e}. Retrying in {delay}s..."
)
time.sleep(delay)
except Exception as e:
# Do not retry on ContextLengthExceededError
if isinstance(e, LLMInterface.ContextLengthExceededError):
raise
raise
raise last_exception
return wrapper
return decorator
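# Worked example of the backoff schedule (illustrative values): with retry_max_attempts=3 and
# retry_base_delay=1.0, a call that keeps raising network errors is retried after ~1s and then
# ~2s (delay = retry_base_delay * 2**attempt) before the final exception is re-raised.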
class LLMInterface:
"""Interface for LLM communication"""
class ContextLengthExceededError(ValueError):
"""Custom exception for when prompt + max_tokens exceeds model's context."""
pass
def __init__(self, config: ICMConfig):
self.base_url = config.api_url
self.config = config
self.request_count = 0
self._detect_api_type()
self._detect_context_length()
def _detect_api_type(self):
"""Detect LLM server type"""
base_url = self.base_url.rstrip("/")
retry_decorator = retry_on_network_error(self.config)
@retry_decorator
def check_llamacpp():
response = requests.get(
f"{base_url}/props", timeout=self.config.timeout_short
)
if response.status_code == 200:
props = response.json()
if "model_path" in props:
return "llamacpp"
return None
@retry_decorator
def check_vllm():
response = requests.get(
f"{base_url}/version", timeout=self.config.timeout_short
)
if response.status_code == 200:
return "vllm"
return None
@retry_decorator
def check_openai():
response = requests.get(
f"{base_url}/v1/models", timeout=self.config.timeout_medium
)
if response.status_code == 200:
return "vllm"
return None
# Try to detect API type
try:
api_type = check_llamacpp()
if api_type:
self.api_type = api_type
logger.info(f"Detected llama.cpp server")
return
except requests.RequestException:
pass
try:
api_type = check_vllm()
if api_type:
self.api_type = api_type
logger.info("Detected vLLM server")
return
except requests.RequestException:
pass
try:
api_type = check_openai()
if api_type:
self.api_type = api_type
logger.info("Detected OpenAI-compatible API")
return
except requests.RequestException:
pass
raise ConnectionError(f"Cannot connect to LLM server at {base_url}")
def _detect_context_length(self):
"""Try to detect model's context length from API"""
original_context = self.config.max_context_tokens
try:
if self.api_type == "vllm":
response = requests.get(
f"{self.base_url}/v1/models", timeout=self.config.timeout_short
)
if response.status_code == 200:
models = response.json().get("data", [])
for model_info in models:
for field in ["max_model_len", "context_length", "max_seq_len"]:
if field in model_info:
detected = model_info[field]
if detected != original_context:
logger.info(
f"Detected model context length: {detected} tokens "
f"(was {original_context})"
)
self.config.max_context_tokens = detected
return
if "config" in model_info:
config = model_info["config"]
for field in [
"max_seq_len",
"max_position_embeddings",
"n_ctx",
]:
if field in config:
detected = config[field]
if detected != original_context:
logger.info(
f"Detected model context length: {detected} tokens "
f"(was {original_context})"
)
self.config.max_context_tokens = detected
return
elif self.api_type == "llamacpp":
response = requests.get(
f"{self.base_url}/props", timeout=self.config.timeout_short
)
if response.status_code == 200:
props = response.json()
if "n_ctx" in props:
detected = props["n_ctx"]
if detected != original_context:
logger.info(
f"Detected model context length: {detected} tokens "
f"(was {original_context})"
)
self.config.max_context_tokens = detected
except Exception as e:
logger.debug(f"Could not detect context length: {e}")
logger.info(
f"Using configured context length: {self.config.max_context_tokens} tokens"
)
def _check_for_context_error(self, response: requests.Response):
"""Checks a response for signs of a context length error and raises."""
if response.status_code == 400:
try:
error_data = response.json()
error_msg = str(error_data.get("error", "")).lower()
# Common phrases indicating context length issues from vLLM/OpenAI APIs
if "context length" in error_msg or "too long" in error_msg or "maximum sequence length" in error_msg:
raise self.ContextLengthExceededError(f"API error indicates context length exceeded: {error_msg}")
except (json.JSONDecodeError, AttributeError):
# If it's a 400 but not a clear context error, we can still infer it
pass
# If we got a 400 and it wasn't a clear network issue, it's likely a context error.
raise self.ContextLengthExceededError(f"Inferred context length error from HTTP 400. Response: {response.text[:200]}")
def generate_response(
self, prompt: str, max_tokens: int, temperature: float, top_p: float
) -> str:
"""Generate a response to a prompt"""
self.request_count += 1
retry_decorator = retry_on_network_error(self.config)
@retry_decorator
def _perform_request() -> requests.Response:
if self.api_type == "vllm":
return requests.post(
f"{self.base_url}/v1/completions",
json={
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": self.config.response_top_k,
"frequency_penalty": self.config.response_frequency_penalty,
"presence_penalty": self.config.response_presence_penalty,
"stop": ["</assistant>", "\n<user>", "\n<assistant>"],
"seed": random.randint(0, 1000000) if temperature > 0 else 42,
},
timeout=self.config.timeout_generation,
)
else: # llamacpp
return requests.post(
f"{self.base_url}/completion",
json={
"prompt": prompt,
"n_predict": max_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": self.config.response_top_k,
"repeat_penalty": self.config.response_repeat_penalty,
"frequency_penalty": self.config.response_frequency_penalty,
"presence_penalty": self.config.response_presence_penalty,
"cache_prompt": True,
"stop": ["</assistant>", "\n<user>", "\n<assistant>"],
"seed": random.randint(0, 1000000) if temperature > 0 else 42,
},
timeout=self.config.timeout_generation,
)
try:
response = _perform_request()
# Check for context error first, as it's a special failure case
self._check_for_context_error(response)
# If no context error, check for other HTTP errors
response.raise_for_status()
if self.api_type == "vllm":
return response.json()["choices"][0]["text"].strip()
else: # llamacpp
return response.json()["content"].strip()
except self.ContextLengthExceededError:
# Re-raise our custom exception so the caller can handle it specifically
raise
except requests.exceptions.RequestException as e:
logger.error(f"Request failed after all retries: {e}")
raise
def generate_label(
self, prompt: str, dimension: str, max_tokens: int, temperature: float
) -> Optional[str]:
"""Generate A or B label, raising ContextLengthExceededError on failure."""
self.request_count += 1
retry_decorator = retry_on_network_error(self.config)
@retry_decorator
def _generate():
try:
if self.api_type == "vllm":
response = requests.post(
f"{self.base_url}/v1/completions",
json={
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"logprobs": self.config.num_logprobs,
},
timeout=self.config.timeout_long,
)
else: # llamacpp
response = requests.post(
f"{self.base_url}/completion",
json={
"prompt": prompt,
"n_predict": max_tokens,
"temperature": temperature,
"cache_prompt": True,
"n_probs": self.config.num_logprobs,
},
timeout=self.config.timeout_long,
)
# IMPORTANT: Check for context error first and raise if found
self._check_for_context_error(response)
# If no context error, check for other HTTP errors
response.raise_for_status()
response_json = response.json()
if self.api_type == "vllm":
result = response_json["choices"][0]["text"].strip()
logprobs_data = response_json["choices"][0].get("logprobs", {})
if logprobs_data and "top_logprobs" in logprobs_data and logprobs_data["top_logprobs"]:
first_token_probs = logprobs_data["top_logprobs"][0]
else:
first_token_probs = {}
else: # llamacpp
result = response_json["content"].strip()
completion_probs = response_json.get("completion_probabilities", [])
if completion_probs:
top_logprobs_list = completion_probs[0].get("top_logprobs", [])
first_token_probs = {item['token']: item['logprob'] for item in top_logprobs_list}
else:
first_token_probs = {}
# --- Unified Logprob and Text Parsing Logic ---
result_upper = result.upper()
# Check logprobs first
if first_token_probs:
best_token, best_logprob = None, float("-inf")
a_total_prob, b_total_prob = 0.0, 0.0
for token, logprob in first_token_probs.items():
cleaned = re.sub(r"[^a-zA-Z0-9]", "", str(token)).upper()
if cleaned == "A":
a_total_prob += math.exp(logprob)
if logprob > best_logprob: best_logprob, best_token = logprob, "A"
elif cleaned == "B":
b_total_prob += math.exp(logprob)
if logprob > best_logprob: best_logprob, best_token = logprob, "B"
elif "A" in cleaned and "B" not in cleaned and len(cleaned) < 10: a_total_prob += math.exp(logprob) * 0.5
elif "B" in cleaned and "A" not in cleaned and len(cleaned) < 10: b_total_prob += math.exp(logprob) * 0.5
if best_token: return best_token
if a_total_prob > b_total_prob * 1.2: return "A"
if b_total_prob > a_total_prob * 1.2: return "B"
# Fallback to generated text
if "A" in result_upper and "B" not in result_upper: return "A"
if "B" in result_upper and "A" not in result_upper: return "B"
# Check for indecisiveness in logprobs or text
indecisive_terms = ["BOTH", "TIE", "EQUAL", "NEITHER", "SAME", "DRAW", "NONE", "EQUIVALENT"]
top_tokens_str = [str(t).upper() for t in list(first_token_probs.keys())[:5]]
if any(any(term in token for term in indecisive_terms) for token in top_tokens_str):
logger.debug(f"Model indecisive via logprobs, skipping pair.")
return None # Signal to skip
# If truly ambiguous, return None to signal a skip
logger.debug(f"Could not determine A/B preference from response '{result[:20]}...' or logprobs. Skipping.")
return None
except self.ContextLengthExceededError:
# Propagate the specific error for the retry loop
raise
except requests.exceptions.RequestException as e:
logger.error(f"Network error during label generation: {e}")
raise # Let the retry handler deal with it
except Exception as e:
logger.error(f"Unexpected error in generate_label, cannot determine label: {e}")
return None # Return None for other unexpected errors
return _generate()
def get_log_prob(self, prompt: str, completion: str) -> float:
"""Get log probability of completion given prompt"""
self.request_count += 1
if self.api_type == "vllm":
return self._get_log_prob_vllm(prompt, completion)
else:
return self._get_log_prob_llamacpp(prompt, completion)
def _get_log_prob_vllm(self, prompt: str, completion: str) -> float:
"""Get log probability for vLLM"""
if len(completion.strip()) > 1:
raise NotImplementedError(
"This optimized vLLM logprob function only supports single-token completions."
)
retry_decorator = retry_on_network_error(self.config)
@retry_decorator
def _get():
try:
response = requests.post(
f"{self.base_url}/v1/completions",
json={
"prompt": prompt,
"max_tokens": 1,
"temperature": 0.0,
"logprobs": self.config.num_logprobs,
},
timeout=self.config.timeout_long,
)
self._check_for_context_error(response)
response.raise_for_status()
result = response.json()
if "choices" not in result or not result["choices"]: raise ValueError("No choices in vLLM response")
choice = result["choices"][0]
logprobs_data = choice.get("logprobs")
if not logprobs_data or not logprobs_data.get("top_logprobs"): raise ValueError("No logprobs in vLLM response")
first_token_probs = logprobs_data["top_logprobs"][0]
cleaned_tokens = []
for token, logprob in first_token_probs.items():
cleaned_token = re.sub(r"[^a-zA-Z0-9]", "", token)
cleaned_tokens.append(cleaned_token)
if cleaned_token == completion:
return logprob
logger.warning( f"Completion '{completion}' not in top {len(first_token_probs)} logprobs. " f"Returning low probability. (logprobs: {cleaned_tokens})")
return self.config.log_prob_fallback
except self.ContextLengthExceededError:
raise # Let caller handle this
except Exception as e:
logger.error(f"Error in get_log_prob: {e}")
return self.config.log_prob_fallback
return _get()
def _get_log_prob_llamacpp(self, prompt: str, completion: str) -> float:
"""Get log probability for llama.cpp"""
if len(completion) == 1:
retry_decorator = retry_on_network_error(self.config)
@retry_decorator
def _get():
try:
response = requests.post(
f"{self.base_url}/completion",
json={
"prompt": prompt,
"n_predict": 2,
"n_probs": self.config.num_logprobs,
"temperature": 0.0,
"cache_prompt": True,
},
timeout=self.config.timeout_long,
)
self._check_for_context_error(response)
response.raise_for_status()
result = response.json()
completion_probs = result.get("completion_probabilities", [])
if not completion_probs: raise ValueError("No completion_probabilities in response")
first_pos = completion_probs[0]
top_candidates = first_pos.get("top_logprobs", [])
for candidate in top_candidates:
token = candidate.get("token", "")
cleaned_token = re.sub(r"[^a-zA-Z0-9]", "", token)
if cleaned_token == completion:
return candidate["logprob"]
raise ValueError(f"Completion '{completion}' not found in top candidates")
except self.ContextLengthExceededError:
raise # Let caller handle this
except Exception as e:
logger.error(f"Error in get_log_prob: {e}")
return self.config.log_prob_fallback
return _get()
raise NotImplementedError("Multi-token completion logprobs not implemented")
def normalize_conversation_format(conversation: List[Dict]) -> List[Dict]:
"""Normalize conversation to role/content format"""
normalized = []
for turn in conversation:
if "from" in turn and "value" in turn:
role_map = {"human": "user", "gpt": "assistant", "system": "system"}
role = role_map.get(turn["from"], turn["from"])
content = turn["value"]
elif "role" in turn and "content" in turn:
role = turn["role"]
content = turn["content"]
else:
logger.warning(f"Unknown conversation format: {turn}")
continue
normalized.append({"role": role, "content": content})
return normalized
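# e.g. a ShareGPT-style turn {"from": "human", "value": "Hi"} becomes
# {"role": "user", "content": "Hi"}; turns already in role/content form pass through unchanged.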
def format_conversation(conversation: List[Dict]) -> str:
"""Format a conversation into a readable string"""
formatted = []
normalized = normalize_conversation_format(conversation)
for turn in normalized:
role = turn["role"]
content = turn["content"]
formatted.append(f"<{role}>\n{content}\n</{role}>")
return "\n".join(formatted)
def create_n_shot_prompt(
few_shot_examples: List[Dict], target_conversation: List[Dict]
) -> str:
"""Create an N-shot prompt from examples and target conversation"""
prompt_parts = []
for i, example in enumerate(few_shot_examples):
prompt_parts.append(f"Example {i + 1}:")
prompt_parts.append(format_conversation(example))
prompt_parts.append("")
prompt_parts.append(f"Example {len(few_shot_examples) + 1}:")
prompt_parts.append(format_conversation(target_conversation))
prompt_parts.append("<assistant>")
return "\n".join(prompt_parts)
class ResponseGenerator:
"""Generate multiple responses for prompts"""
def __init__(
self,
llm: LLMInterface,
config: ICMConfig,
cache: ResponseCache,
executor: ThreadPoolExecutor,
):
self.llm = llm
self.config = config
self.cache = cache
self.executor = executor
def shutdown(self):
"""Shutdown is handled by main executor"""
pass
def prepare_prompt_contexts(
self, dataset_iter: Iterator, num_prompts: int
) -> List[PromptContext]:
"""Prepare prompt contexts from dataset"""
contexts = []
examples_buffer = []
# Use a larger buffer to ensure we can find n-shot examples for each target
pbar = tqdm(total=num_prompts, desc="Preparing prompt contexts")
for idx, item in enumerate(dataset_iter):
if len(contexts) >= num_prompts:
break
conversation = self._get_conversation(item)
if not conversation: continue
normalized = normalize_conversation_format(conversation)
if not any(t["role"] == "user" for t in normalized) or not any(t["role"] == "assistant" for t in normalized): continue
# Keep a sliding window of examples
examples_buffer.append((normalized, idx))
if len(examples_buffer) > self.config.n_shot_examples:
target_conv_with_response, source_id = examples_buffer.pop(0)
few_shot_examples = [ex[0] for ex in examples_buffer]
last_assistant_idx = -1
for i, turn in enumerate(target_conv_with_response):
if turn["role"] == "assistant":
last_assistant_idx = i
if last_assistant_idx != -1:
target_conversation = target_conv_with_response[:last_assistant_idx]
original_response = target_conv_with_response[last_assistant_idx]["content"]
prompt_text = format_conversation(target_conversation)
prompt_id = hashlib.sha256(prompt_text.encode()).hexdigest()[:16]
context = PromptContext(
prompt_id=prompt_id,
few_shot_examples=few_shot_examples,
target_conversation=target_conversation,
original_response=original_response,
source_id=source_id,
timestamp=time.time(),
)
contexts.append(context)
pbar.update(1)
pbar.close()
logger.info(f"Prepared {len(contexts)} prompt contexts")
return contexts
def _get_conversation(self, item: Dict) -> Optional[List[Dict]]:
"""Get conversation from dataset item"""
if "messages" in item:
return item["messages"]
elif "conversations" in item:
return item["conversations"]
else:
for key in ["conversation", "dialog", "dialogue", "chat"]:
if key in item:
return item[key]
return None
def estimate_prompt_tokens(self, prompt: str) -> int:
"""Rough estimate of token count"""
return int(len(prompt) / 3.5) # A slightly more conservative ratio
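# e.g. a 7,000-character prompt is estimated at roughly 2,000 tokens (7000 / 3.5), slightly
# more pessimistic than the common ~4 chars/token rule of thumb, so the proactive context
# check in generate_responses_for_prompt errs on the side of trimming n-shot examples.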
def generate_responses_for_prompt(self, context: PromptContext) -> List[GeneratedResponse]:
"""Generate multiple responses for a single prompt with robust context handling."""
cached_responses = self.cache.get_responses_for_prompt(context.prompt_id)
if len(cached_responses) >= self.config.responses_per_prompt:
return cached_responses[:self.config.responses_per_prompt]
generated_texts_exact = {r.response_text.strip().lower() for r in cached_responses}
generated_texts_prefix = {r.response_text.strip().lower()[:self.config.duplicate_prefix_length] for r in cached_responses}
generated_texts_fuzzy = {re.sub(r"[^a-z0-9]", "", r.response_text.strip().lower())[:self.config.duplicate_fuzzy_length] for r in cached_responses}
successful_responses = []
generation_lock = threading.Lock()
def generate_single():
# Outermost loop: Reduce n-shot examples if context is too long
num_examples_to_try = self.config.n_shot_examples
while num_examples_to_try >= 0:
# Construct prompt with current number of examples
# Note: a bare [-0:] slice would return the full list, so handle the zero-example case explicitly
current_few_shot = context.few_shot_examples[-num_examples_to_try:] if num_examples_to_try > 0 else []
prompt_text = create_n_shot_prompt(current_few_shot, context.target_conversation)
# Proactive check: If prompt itself is too big, don't even try.
estimated_prompt_tokens = self.estimate_prompt_tokens(prompt_text)
if estimated_prompt_tokens + self.config.response_max_tokens > self.config.max_context_tokens:
logger.debug(f"Proactively skipping {num_examples_to_try} examples, prompt too long for {context.prompt_id[:8]}.")
num_examples_to_try -= 1
continue
# Middle loop: Try different generation parameters for diversity
for attempt in range(self.config.max_generation_attempts):
try:
temperature = self.config.response_temperature * (1.1**attempt)
response_text = self.llm.generate_response(
prompt_text,
max_tokens=self.config.response_max_tokens,
temperature=temperature,
top_p=min(self.config.response_top_p + (attempt * 0.01), 1.0),
)
if not response_text or not response_text.strip(): continue
# Deduplication logic
normalized_exact = response_text.strip().lower()
normalized_prefix = normalized_exact[:self.config.duplicate_prefix_length]
normalized_fuzzy = re.sub(r"[^a-z0-9]", "", normalized_exact)[:self.config.duplicate_fuzzy_length]
with generation_lock:
is_duplicate = False
if normalized_exact in generated_texts_exact: is_duplicate = True
elif len(normalized_prefix) > self.config.duplicate_fuzzy_prefix_min_length and normalized_prefix in generated_texts_prefix: is_duplicate = True
elif len(normalized_fuzzy) > 100 and normalized_fuzzy in generated_texts_fuzzy: is_duplicate = True
if is_duplicate: continue
response_obj = self.cache.add_response(
context.prompt_id, response_text,
{"temperature": temperature, "top_p": min(self.config.response_top_p + (attempt * 0.01), 1.0), "attempts": attempt + 1, "n_shot": num_examples_to_try}
)
if not response_obj: continue
# Add to sets if successfully cached
generated_texts_exact.add(normalized_exact)
generated_texts_prefix.add(normalized_prefix)
generated_texts_fuzzy.add(normalized_fuzzy)
return response_obj # Success!
except self.llm.ContextLengthExceededError:
logger.warning(f"Response generation failed for {context.prompt_id[:8]} with {num_examples_to_try} examples due to context length. Reducing.")
# Break the diversification loop to trigger reduction of n-shot examples
break
except Exception as e:
logger.error(f"Failed to generate response on attempt {attempt + 1}: {e}")
if attempt == self.config.max_generation_attempts - 1:
break # Break inner loop, but outer loop will continue
# This part is reached if the diversification loop breaks (e.g., from context error)
num_examples_to_try -= 1
logger.error(f"Failed to generate a unique response for {context.prompt_id[:8]} even with 0 examples.")
return None
# Threaded execution to generate responses
with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
needed = self.config.responses_per_prompt - len(cached_responses)
if needed <= 0: return cached_responses
# Oversample to account for failures and duplicates
num_to_generate = math.ceil(needed * self.config.generation_oversampling_factor)
futures = [executor.submit(generate_single) for _ in range(num_to_generate)]
pbar = tqdm(as_completed(futures), total=len(futures), desc=f"Gen for {context.prompt_id[:8]}", leave=False)
for future in pbar:
response = future.result()
if response:
successful_responses.append(response)
if len(successful_responses) >= needed:
# Cancel remaining futures once we have enough
for f in futures:
if not f.done():
f.cancel()
break
all_responses = cached_responses + successful_responses
if not self.cache.get_context(context.prompt_id):
self.cache.add_context(context)
if len(all_responses) < self.config.responses_per_prompt:
logger.warning(f"Only generated {len(all_responses)}/{self.config.responses_per_prompt} for prompt {context.prompt_id[:8]}")
return all_responses[:self.config.responses_per_prompt]
def generate_all_responses(
self, contexts: List[PromptContext]
) -> Dict[str, List[GeneratedResponse]]:
"""Generate responses for all prompt contexts"""
all_responses = {}
pbar = tqdm(contexts, desc="Generating responses", position=0)
for context in pbar:
# Use cached responses if they exist, otherwise generate new ones
cached_responses = self.cache.get_responses_for_prompt(context.prompt_id)
if len(cached_responses) >= self.config.responses_per_prompt:
pbar.set_postfix_str(f"Prompt {context.prompt_id[:8]}: {len(cached_responses)} cached")
all_responses[context.prompt_id] = cached_responses[:self.config.responses_per_prompt]
else:
responses = self.generate_responses_for_prompt(context)
all_responses[context.prompt_id] = responses
pbar.set_postfix_str(f"Prompt {context.prompt_id[:8]}: {len(responses)} generated")
pbar.close()
total_responses = sum(len(r) for r in all_responses.values())
logger.info(f"Finished response generation. Total responses: {total_responses} for {len(contexts)} prompts")
return all_responses
def save_responses_to_hf_dataset(
all_responses: Dict[str, List[GeneratedResponse]],
contexts: List[PromptContext],
output_dir: str
):
"""Saves the generated responses to a Hugging Face Dataset on disk."""
total_responses = sum(len(r) for r in all_responses.values())
if total_responses == 0:
logger.warning("No responses were generated, skipping dataset creation.")
return
logger.info(f"Preparing to save {total_responses} responses to Hugging Face dataset at '{output_dir}'")
# Create a lookup for contexts by prompt_id for efficiency
contexts_by_id = {c.prompt_id: c for c in contexts}
data_to_save = defaultdict(list)
for prompt_id, responses in tqdm(all_responses.items(), desc="Formatting for dataset"):
prompt_context = contexts_by_id.get(prompt_id)
if not prompt_context:
logger.warning(f"Could not find context for prompt_id {prompt_id}. Skipping its responses.")
continue
for response in responses:
data_to_save["prompt"].append(prompt_context.target_conversation)
data_to_save["completion"].append(response.response_text)
data_to_save["prompt_id"].append(response.prompt_id)
data_to_save["response_id"].append(response.response_id)
data_to_save["source_id"].append(prompt_context.source_id)
data_to_save["generation_params"].append(json.dumps(response.generation_params))
if not data_to_save["prompt"]:
logger.error("No data to save after formatting. Aborting dataset creation.")
return
# Define the features for type safety and clarity
features = Features({
'prompt': Sequence({'role': Value('string'), 'content': Value('string')}),
'completion': Value('string'),
'prompt_id': Value('string'),
'response_id': Value('string'),
'source_id': Value('int64'),
'generation_params': Value('string'), # Storing as a JSON string
})
try:
hf_dataset = Dataset.from_dict(dict(data_to_save), features=features)
logger.info(f"Created dataset with {len(hf_dataset)} rows. Saving to disk...")
hf_dataset.save_to_disk(output_dir)
logger.info(f"Successfully saved generated responses dataset to '{output_dir}'")
except Exception as e:
logger.exception(f"Failed to create or save Hugging Face dataset: {e}")
def analyze_response_diversity(
all_responses: Dict[str, List[GeneratedResponse]], config: ICMConfig
):
"""Analyze diversity of generated responses"""
print("\n📊 Response Diversity Analysis:")
print("=" * 60)
total_duplicates = 0
total_near_duplicates = 0
total_fuzzy_duplicates = 0
for prompt_id, responses in all_responses.items():
if len(responses) < 2:
continue
texts = [r.response_text.strip() for r in responses]
unique_texts = set(texts)
duplicates = len(texts) - len(unique_texts)
normalized_texts = [t.lower() for t in texts]
unique_normalized = set(normalized_texts)
near_duplicates = len(normalized_texts) - len(unique_normalized)
fuzzy_texts = [
re.sub(r"[^a-z0-9]", "", t.lower())[: config.duplicate_fuzzy_length]
for t in texts
]
unique_fuzzy = set(fuzzy_texts)
fuzzy_duplicates = len(fuzzy_texts) - len(unique_fuzzy)
if duplicates > 0 or near_duplicates > 0 or fuzzy_duplicates > 0:
print(f"\nPrompt {prompt_id[:8]}...")
print(f" Total responses: {len(responses)}")
print(f" Exact duplicates: {duplicates}")
print(f" Near duplicates (case): {near_duplicates}")
print(f" Fuzzy duplicates: {fuzzy_duplicates}")
if duplicates > 0:
print(" Example duplicate:")
for i, text in enumerate(texts):
if texts.count(text) > 1:
print(f" Response {i+1}: {text[:50]}...")
break
total_duplicates += duplicates
total_near_duplicates += near_duplicates
total_fuzzy_duplicates += fuzzy_duplicates
print(f"\n📈 Summary:")
print(f" Total prompts: {len(all_responses)}")
print(f" Total responses: {sum(len(r) for r in all_responses.values())}")
print(f" Total exact duplicates: {total_duplicates}")
print(f" Total near duplicates (case): {total_near_duplicates}")
print(f" Total fuzzy duplicates: {total_fuzzy_duplicates}")
if (
total_duplicates == 0
and total_near_duplicates == 0
and total_fuzzy_duplicates == 0
):
print("\n✅ Excellent! No duplicates detected at any level.")
else:
print("\n⚠️ Some duplicates were found despite prevention measures.")
print("=" * 60)
class ICMLabeler:
"""Internal Coherence Maximization for labeling response pairs"""
def __init__(
self,
llm: LLMInterface,
config: ICMConfig,
cache: ResponseCache,
executor: ThreadPoolExecutor,
):
self.llm = llm
self.config = config
self.cache = cache
self.executor = executor
self.labeled_data: Dict[str, List[LabeledExample]] = defaultdict(list)
self.generation_failures: Dict[str, int] = defaultdict(int)
self.pair_tracker: Dict[str, Set[Tuple[str, str]]] = defaultdict(set)
self.inconsistency_degree: Dict[str, Dict[Tuple[str, str], int]] = defaultdict(dict)
self.log_prob_cache: Dict[str, Dict[int, float]] = defaultdict(dict)
self.dirty_indices: Dict[str, Set[int]] = defaultdict(set)
self.total_log_probs: Dict[str, float] = defaultdict(float)
self.skipped_pairs: Dict[str, int] = defaultdict(int)
self.consecutive_failures: Dict[str, int] = defaultdict(int)
# NEW: Adaptive state for the "Fail, Reduce, Retry" strategy
self.adaptive_example_count: Dict[str, int] = defaultdict(lambda: self.config.max_context_examples)
def shutdown(self):
"""Shutdown is handled by main executor"""
pass
def _mark_all_dirty(self, dimension: str):
"""Mark all examples as needing recomputation"""
n_examples = len(self.labeled_data[dimension])
self.dirty_indices[dimension] = set(range(n_examples))
def _mark_dirty_except(self, dimension: str, skip_idx: int):
"""Mark all examples except one as needing recomputation"""
n_examples = len(self.labeled_data[dimension])
self.dirty_indices[dimension] = set(range(n_examples)) - {skip_idx}
def create_comparison_prompt(
self,
prompt_context: PromptContext,
response_a: str,
response_b: str,
dimension: str,
position_swapped: bool,
) -> str:
"""Create prompt for comparing two responses"""
dimension_text = dimension.replace("_", " ")
rubric = DIMENSION_RUBRICS.get(dimension, "")
if position_swapped:
first_response, second_response = response_b, response_a
first_label, second_label = "B", "A"
else:
first_response, second_response = response_a, response_b
first_label, second_label = "A", "B"
conversation_text = format_conversation(prompt_context.target_conversation)
prompt = f"""You are an impartial judge evaluating two AI assistant responses.
Here is the conversation that led to these responses:
{conversation_text}
Response {first_label}:
{first_response}
Response {second_label}:
{second_response}
Which response demonstrates better {dimension_text} ({rubric})?
IMPORTANT: You MUST choose either "{first_label}" or "{second_label}". If they seem very similar, pick the one that is even slightly better. Do NOT say "tie", "both", "equal", or anything else.
Reply with ONLY "{first_label}" or "{second_label}" - nothing else.
Answer: """
return prompt
def create_context_prompt(
self,
examples: List[LabeledExample],
new_example: LabeledExample,
prompt_context: PromptContext,
position_swapped: bool,
) -> str:
"""Create prompt with examples in context (no token calculation)."""
dimension_text = new_example.label_dimension.replace("_", " ")
rubric = DIMENSION_RUBRICS.get(new_example.label_dimension, "")
prompt = "<golden_dataset>\n"
prompt += f"<task_description>\nYou are an impartial judge.\n"
prompt += f"Compare AI assistant responses on {dimension_text} ({rubric}).\n"
prompt += f"IMPORTANT: Always choose A or B. Never say tie/both/equal.\n"
prompt += f"</task_description>\n\n"
prompt += f"<best_comparisons>\n"
prompt_parts = []
for i, ex in enumerate(examples):
ex_context = self.cache.get_context(ex.prompt_id)
if not ex_context: continue
if ex.position_swapped:
first_resp, second_resp = ex.response_b_text, ex.response_a_text
first_label, second_label = "B", "A"
display_label = "B" if ex.label == "A" else "A"
else:
first_resp, second_resp = ex.response_a_text, ex.response_b_text
first_label, second_label = "A", "B"
display_label = ex.label
conversation_text = format_conversation(ex_context.target_conversation)
example_text = f"<comparison_{i + 1}>\n"
example_text += f"<conversation>\n{conversation_text}\n</conversation>\n"
example_text += f"<response_{first_label.lower()}>\n{first_resp}\n</response_{first_label.lower()}>\n"
example_text += f"<response_{second_label.lower()}>\n{second_resp}\n</response_{second_label.lower()}>\n"
example_text += f"<dimension>\n{dimension_text}\n</dimension>\n"
example_text += f"<rubric>\n{rubric}\n</rubric>\n"
example_text += f"<verdict>\n <better>{display_label}</better>\n</verdict>\n"
example_text += f"</comparison_{i + 1}>\n"
prompt_parts.append(example_text)
prompt += "".join(prompt_parts)
if position_swapped:
first_resp, second_resp = new_example.response_b_text, new_example.response_a_text
first_label, second_label = "B", "A"
else:
first_resp, second_resp = new_example.response_a_text, new_example.response_b_text
first_label, second_label = "A", "B"
conversation_text = format_conversation(prompt_context.target_conversation)
prompt += f"<comparison_{len(examples) + 1}>\n"
prompt += f"<conversation>\n{conversation_text}\n</conversation>\n"
prompt += f"<response_{first_label.lower()}>\n{first_resp}\n</response_{first_label.lower()}>\n"
prompt += f"<response_{second_label.lower()}>\n{second_resp}\n</response_{second_label.lower()}>\n"
prompt += f"<dimension>\n{dimension_text}\n</dimension>\n"
prompt += f"<rubric>\n{rubric}\n</rubric>\n"
prompt += f"<verdict>\n <better>"
return prompt
def create_unified_comparison_prompt(
self,
prompt_context: PromptContext,
response_a: str,
response_b: str,
position_swapped: bool,
) -> str:
"""Create a unified prompt that evaluates all dimensions holistically"""
if position_swapped:
first_response, second_response = response_b, response_a
first_label, second_label = "B", "A"
else:
first_response, second_response = response_a, response_b
first_label, second_label = "A", "B"
conversation_text = format_conversation(prompt_context.target_conversation)
prompt = f"""{UNIFIED_ONTOLOGY_PROMPT}
Now, evaluate these two responses to the following conversation:
Conversation:
{conversation_text}
Response {first_label}:
{first_response}
Response {second_label}:
{second_response}
Considering all dimensions of the evaluation framework above, which response is better overall for the given conversation?
CRITICAL: You MUST output ONLY "{first_label}" or "{second_label}" - no other text, explanations, or qualifiers. Even if they seem similar, pick the one that's marginally better.
Answer: """
return prompt
def create_unified_context_prompt(
self,
examples: List[LabeledExample],
new_example: LabeledExample,
prompt_context: PromptContext,
position_swapped: bool,
) -> str:
"""Create unified prompt with examples in context (no token calculation)."""
prompt = "<golden_dataset>\n"
prompt += "<task_description>\n"
prompt += "You are an expert evaluator using a comprehensive quality framework.\n"
prompt += "Evaluate responses holistically across all dimensions: safety, accuracy, task fulfillment, clarity, self-awareness, and depth.\n"
prompt += "ALWAYS choose A or B - never say tie/both/equal.\n"
prompt += "</task_description>\n\n"
prompt += "<evaluation_framework>\n"
prompt += UNIFIED_ONTOLOGY_PROMPT
prompt += "\n</evaluation_framework>\n\n"
prompt += "<best_comparisons>\n"
prompt_parts = []
for i, ex in enumerate(examples):
ex_context = self.cache.get_context(ex.prompt_id)
if not ex_context: continue
if ex.position_swapped:
ex_first_resp, ex_second_resp = ex.response_b_text, ex.response_a_text
ex_first_label, ex_second_label = "B", "A"
display_label = "B" if ex.label == "A" else "A"
else:
ex_first_resp, ex_second_resp = ex.response_a_text, ex.response_b_text
ex_first_label, ex_second_label = "A", "B"
display_label = ex.label
conversation_text = format_conversation(ex_context.target_conversation)
example_text = f"<comparison_{i + 1}>\n"
example_text += f"<conversation>\n{conversation_text}\n</conversation>\n"
example_text += f"<response_{ex_first_label.lower()}>\n{ex_first_resp}\n</response_{ex_first_label.lower()}>\n"
example_text += f"<response_{ex_second_label.lower()}>\n{ex_second_resp}\n</response_{ex_second_label.lower()}>\n"
example_text += f"<verdict>\n <better>{display_label}</better>\n</verdict>\n"
example_text += f"</comparison_{i + 1}>\n"
prompt_parts.append(example_text)
prompt += "".join(prompt_parts)
if position_swapped:
first_resp, second_resp = new_example.response_b_text, new_example.response_a_text
first_label, second_label = "B", "A"
else:
first_resp, second_resp = new_example.response_a_text, new_example.response_b_text
first_label, second_label = "A", "B"
conversation_text = format_conversation(prompt_context.target_conversation)
prompt += f"<comparison_{len(examples) + 1}>\n"
prompt += f"<conversation>\n{conversation_text}\n</conversation>\n"
prompt += f"<response_{first_label.lower()}>\n{first_resp}\n</response_{first_label.lower()}>\n"
prompt += f"<response_{second_label.lower()}>\n{second_resp}\n</response_{second_label.lower()}>\n"
prompt += f"<verdict>\n <better>"
return prompt
def compute_mutual_predictability(self, dimension: str, pbar) -> float:
"""Compute mutual predictability score P_θ(D) on full dataset"""
examples = self.labeled_data[dimension]
if len(examples) < 2: return 0.0
if dimension not in self.total_log_probs or len(self.log_prob_cache[dimension]) != len(examples):
self.dirty_indices[dimension] = set(range(len(examples)))
self.log_prob_cache[dimension].clear()
self.total_log_probs[dimension] = 0.0
indices_to_compute = list(self.dirty_indices[dimension])
if not indices_to_compute: return self.total_log_probs[dimension]
logger.debug(f"Recomputing {len(indices_to_compute)}/{len(examples)} examples for {dimension}")
if pbar: pbar.set_postfix_str(f"{len(self.labeled_data[dimension]):3d} | Skip: {self.skipped_pairs[dimension]:3d} | Computing P_θ(D)... (0/{len(indices_to_compute)})")
prompts_to_process = []
for idx in indices_to_compute:
if idx >= len(examples): continue
num_examples_to_try = self.adaptive_example_count[dimension]
log_prob = self.config.log_prob_fallback
while num_examples_to_try >= 0:
try:
context_examples = [ex for i, ex in enumerate(examples) if i != idx]
context_examples = context_examples[-num_examples_to_try:] if num_examples_to_try > 0 else []  # a bare [-0:] slice would keep all examples
example = examples[idx]
prompt_context = self.cache.get_context(example.prompt_id)
if not prompt_context: break
if self.config.unified_mode:
prompt = self.create_unified_context_prompt(context_examples, example, prompt_context, position_swapped=example.position_swapped)
else:
prompt = self.create_context_prompt(context_examples, example, prompt_context, position_swapped=example.position_swapped)
if example.position_swapped:
expected_label = "B" if example.label == "A" else "A"
else:
expected_label = example.label
log_prob = self.llm.get_log_prob(prompt, expected_label)
# Do not adapt example count during log_prob computation, only during labeling
break # Success
except self.llm.ContextLengthExceededError:
num_examples_to_try -= 1
except Exception as e:
logger.error(f"Error during P_theta computation: {e}")
break # Don't retry on other errors
prompts_to_process.append({"idx": idx, "log_prob": log_prob})
for item in prompts_to_process:
idx, new_log_prob = item['idx'], item['log_prob']
old_log_prob = self.log_prob_cache[dimension].get(idx, 0.0)
self.log_prob_cache[dimension][idx] = new_log_prob
self.total_log_probs[dimension] += new_log_prob - old_log_prob
self.dirty_indices[dimension].clear()
return self.total_log_probs[dimension]
def check_logical_consistency(
self, ex1: LabeledExample, ex2: LabeledExample
) -> bool:
"""Check if two labels are logically consistent (asymmetry only)"""
if (
ex1.response_a_id == ex2.response_b_id
and ex1.response_b_id == ex2.response_a_id
):
return ex1.label != ex2.label
return True
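# e.g. if the same underlying pair is stored once as (a=X, b=Y) and once as (a=Y, b=X), the two
# labels must differ for both records to prefer the same response: (X, Y) -> "A" together with
# (Y, X) -> "B" is consistent, while (X, Y) -> "A" together with (Y, X) -> "A" is a contradiction.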
def find_inconsistent_pairs(self, dimension: str) -> List[Tuple[int, int]]:
"""Find all inconsistent pairs efficiently"""
examples = self.labeled_data[dimension]
inconsistent = []
pair_groups = defaultdict(list)
for i, ex in enumerate(examples):
pair_key = ex.get_pair_key()
pair_groups[pair_key].append((i, ex))
for pair_key, group in pair_groups.items():
for i, (idx1, ex1) in enumerate(group):
for idx2, ex2 in group[i + 1 :]:
if not self.check_logical_consistency(ex1, ex2):
inconsistent.append((idx1, idx2))
return inconsistent
def compute_inconsistency_score(self, dimension: str) -> float:
"""Compute total inconsistency I(D)"""
return float(len(self.find_inconsistent_pairs(dimension)))
def compute_scoring_function(self, dimension: str, pbar) -> float:
"""Compute U(D) = α * P_θ(D) - I(D)"""
mutual_pred = self.compute_mutual_predictability(dimension, pbar)
inconsistency = self.compute_inconsistency_score(dimension)
return self.config.alpha * mutual_pred - inconsistency
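# Worked example with illustrative numbers: if alpha = 50, the summed label log-probability
# P_theta(D) is -5.0, and there are 2 inconsistent pairs, then U(D) = 50 * (-5.0) - 2 = -252.0;
# raising P_theta(D) or repairing inconsistencies both increase the score.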
def fix_inconsistencies(self, dimension: str):
"""Fix inconsistencies using Algorithm 2 from the paper"""
max_iterations = self.config.fix_inconsistencies_max_iterations
examples = self.labeled_data[dimension]
for iteration in range(max_iterations):
inconsistent_pairs = self.find_inconsistent_pairs(dimension)
if not inconsistent_pairs: return
i, j = random.choice(inconsistent_pairs)
if i >= len(examples) or j >= len(examples): continue
best_score = float("-inf")
best_labels = None
for label_i in ["A", "B"]:
for label_j in ["A", "B"]:
old_label_i, old_label_j = examples[i].label, examples[j].label
examples[i].label = label_i
examples[j].label = label_j
if self.check_logical_consistency(examples[i], examples[j]):
score = self.compute_scoring_function(dimension, pbar=None)
if score > best_score:
best_score = score
best_labels = (label_i, label_j)
examples[i].label = old_label_i
examples[j].label = old_label_j
current_score = self.compute_scoring_function(dimension, pbar=None)
if best_labels and best_score > current_score:
old_label_i, old_label_j = examples[i].label, examples[j].label
examples[i].label, examples[j].label = best_labels
if old_label_i != examples[i].label: self.dirty_indices[dimension].add(i)
if old_label_j != examples[j].label: self.dirty_indices[dimension].add(j)
def update_inconsistency_degrees(self, dimension: str):
"""Update inconsistency degrees for weighted sampling"""
self.inconsistency_degree[dimension].clear()
examples = self.labeled_data[dimension]
for i, ex in enumerate(examples):
pair_key = ex.get_pair_key()
            if pair_key not in self.inconsistency_degree[dimension]:
                self.inconsistency_degree[dimension][pair_key] = 0
for j, other in enumerate(examples):
if i != j:
other_key = other.get_pair_key()
if pair_key == other_key and not self.check_logical_consistency(ex, other):
self.inconsistency_degree[dimension][pair_key] += 1
def sample_next_pair(self, available_pairs: Iterator, dimension: str) -> Tuple:
# For simplicity with an iterator, we just take the next one.
# Weighted sampling is more complex with a pure iterator.
return next(available_pairs)
def batch_generate_initial_labels(self, pairs: List[Tuple], dimension: str, unified_mode: bool):
"""Generate initial labels in parallel."""
def generate_single(pair_data):
prompt_id, resp_a, resp_b = pair_data
position_swapped = random.choice([True, False])
prompt_context = self.cache.get_context(prompt_id)
if not prompt_context: return {"success": False, "error": f"Missing context for {prompt_id}"}
# For initialization, we use zero-shot to avoid context length issues
if unified_mode:
prompt = self.create_unified_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, position_swapped)
else:
prompt = self.create_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, dimension, position_swapped)
try:
raw_label = self.llm.generate_label(prompt, dimension, max_tokens=2, temperature=0.0)
if raw_label is None: return {"success": False, "skipped": True, "error": "Model indecisive"}
label = "B" if raw_label == "A" and position_swapped else "A" if raw_label == "B" and position_swapped else raw_label
example = LabeledExample(prompt_id=prompt_id, response_a_id=resp_a.response_id, response_b_id=resp_b.response_id, response_a_text=resp_a.response_text, response_b_text=resp_b.response_text, label_dimension=dimension, label=label, position_swapped=position_swapped, timestamp=time.time())
return {"success": True, "example": example}
except self.llm.ContextLengthExceededError:
return {"success": False, "skipped": True, "error": "Context too long for zero-shot"}
except Exception as e:
return {"success": False, "error": str(e)}
results = []
futures = {self.executor.submit(generate_single, pair): pair for pair in pairs}
pbar = tqdm(as_completed(futures), total=len(futures), desc=" Initializing", position=2, leave=False)
for future in pbar: results.append(future.result())
pbar.close()
return results
def label_dimension(self, response_pairs: List[Tuple[str, GeneratedResponse, GeneratedResponse]], dimension: str, save_callback, unified_mode: bool):
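        """Run the full ICM labeling loop for a single dimension.

        Phase 1 initializes a seed set of labels zero-shot; phase 2 runs the MCMC loop:
        propose a label for the next unseen pair (few-shot, with fail-reduce-retry on
        context overflow), repair logical inconsistencies, and accept or reject the
        proposal via simulated annealing on the scoring function U(D). Progress is
        checkpointed through save_callback every save_interval new labels.
        """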
logger.info(f"Labeling dimension: {dimension} with {len(response_pairs)} pairs" + (" (unified mode)" if unified_mode else ""))
self.generation_failures[dimension] = 0
self.skipped_pairs[dimension] = 0
self.consecutive_failures[dimension] = 0
# --- Initialization Phase ---
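        # Zero-shot label up to 5x initial_k candidate pairs in parallel, then keep
        # at most initial_k of the successful labels as the seed set.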
initial_pairs_to_try = response_pairs[:self.config.initial_k * 5]
batch_results = self.batch_generate_initial_labels(initial_pairs_to_try, dimension, unified_mode)
for result in batch_results:
if len(self.labeled_data[dimension]) >= self.config.initial_k:
break
if result["success"]:
example = result["example"]
self.labeled_data[dimension].append(example)
self.pair_tracker[dimension].add(example.get_pair_key())
elif result.get("skipped"):
self.skipped_pairs[dimension] += 1
else:
self.generation_failures[dimension] += 1
if len(self.labeled_data[dimension]) < 2:
logger.error(f"Could not initialize enough examples for {dimension}. Got {len(self.labeled_data[dimension])}. Stopping.")
if self.config.save_partial_on_error and self.labeled_data[dimension]:
save_callback(dimension, self.labeled_data[dimension])
return
logger.info(f"Successfully initialized with {len(self.labeled_data[dimension])} labels.")
self._mark_all_dirty(dimension)
self.fix_inconsistencies(dimension)
# --- Main MCMC Loop ---
# Create an iterator for all pairs that were NOT used in initialization
main_pair_iterator = iter(p for p in response_pairs if tuple(sorted([p[1].response_id, p[2].response_id])) not in self.pair_tracker[dimension])
pbar = tqdm(range(self.config.max_iterations), desc=f"{dimension[:25]:25s}", position=1, leave=False)
last_save_count = len(self.labeled_data[dimension])
for iteration in pbar:
if len(self.labeled_data[dimension]) >= self.config.max_labels_per_dimension:
logger.info(f"Reached max_labels_per_dimension ({self.config.max_labels_per_dimension}). Stopping.")
break
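            # Logarithmic cooling schedule: T = max(final_temp, initial_temp / (1 + beta * ln(iteration + 1)))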
temperature = max(self.config.final_temp, self.config.initial_temp / (1 + self.config.beta * math.log(iteration + 1)))
pbar.set_postfix_str(f"Labels: {len(self.labeled_data[dimension])} | Skip: {self.skipped_pairs[dimension]:3d} | Ex: {self.adaptive_example_count[dimension]:2d} | T:{temperature:.2f}")
if self.consecutive_failures[dimension] >= self.config.max_consecutive_failures:
logger.warning(f"Too many consecutive failures ({self.consecutive_failures[dimension]}) for {dimension}. Stopping.")
break
try:
# We don't use weighted sampling here to keep it simple and ensure we process all pairs if possible
prompt_id, resp_a, resp_b = next(main_pair_iterator)
except StopIteration:
pbar.close()
logger.info("Exhausted all pairs")
break
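            # Proposal step: score the current assignment, label the sampled pair,
            # repair any resulting inconsistencies, then accept or reject the new assignment.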
current_score = self.compute_scoring_function(dimension, pbar)
position_swapped = random.choice([True, False])
new_example = LabeledExample(prompt_id=prompt_id, response_a_id=resp_a.response_id, response_b_id=resp_b.response_id, response_a_text=resp_a.response_text, response_b_text=resp_b.response_text, label_dimension=dimension, label="", position_swapped=position_swapped, timestamp=time.time())
prompt_context = self.cache.get_context(prompt_id)
if not prompt_context: continue
raw_label = None
num_examples_to_try = self.adaptive_example_count[dimension]
while num_examples_to_try >= 0:
try:
                    # Use an empty context when num_examples_to_try == 0 (note that [-0:] would return the whole list)
                    context_examples = self.labeled_data[dimension][-num_examples_to_try:] if num_examples_to_try > 0 else []
                    if unified_mode:
                        prompt = (self.create_unified_context_prompt(context_examples, new_example, prompt_context, position_swapped)
                                  if context_examples
                                  else self.create_unified_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, position_swapped))
                    else:
                        prompt = (self.create_context_prompt(context_examples, new_example, prompt_context, position_swapped)
                                  if context_examples
                                  else self.create_comparison_prompt(prompt_context, resp_a.response_text, resp_b.response_text, dimension, position_swapped))
raw_label = self.llm.generate_label(prompt, dimension, max_tokens=2, temperature=0.0)
self.adaptive_example_count[dimension] = num_examples_to_try
break
except self.llm.ContextLengthExceededError:
logger.debug(f"Context too long with {num_examples_to_try} examples. Reducing.")
num_examples_to_try -= 1
except Exception as e:
logger.error(f"Error generating label: {e}")
self.generation_failures[dimension] += 1
raw_label = "ERROR"
break
if raw_label == "ERROR": continue
if raw_label is None:
self.skipped_pairs[dimension] += 1
self.consecutive_failures[dimension] += 1
if num_examples_to_try < 0:
logger.warning(f"Pair {resp_a.response_id[:4]}-{resp_b.response_id[:4]} failed even with 0 examples.")
continue
self.consecutive_failures[dimension] = 0
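            # Map the raw model output back to the canonical A/B ordering when the
            # responses were shown in swapped positions.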
proposed_label = "B" if raw_label == "A" and position_swapped else "A" if raw_label == "B" and position_swapped else raw_label
new_example.label = proposed_label
self.labeled_data[dimension].append(new_example)
self.pair_tracker[dimension].add(new_example.get_pair_key())
new_idx = len(self.labeled_data[dimension]) - 1
self._mark_dirty_except(dimension, new_idx)
self.fix_inconsistencies(dimension)
new_score = self.compute_scoring_function(dimension, pbar)
delta = new_score - current_score
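            # Metropolis criterion: always accept improvements; accept a worse assignment
            # with probability exp(delta / temperature).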
acceptance_prob = 1.0 if delta > 0 else math.exp(delta / temperature)
if random.random() > acceptance_prob:
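                # Reject the proposal: drop only the new example; any label repairs made by
                # fix_inconsistencies above are kept, and all cached log-probs are marked dirty.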
self.labeled_data[dimension].pop()
self.pair_tracker[dimension].remove(new_example.get_pair_key())
self._mark_all_dirty(dimension)
if save_callback and len(self.labeled_data[dimension]) - last_save_count >= self.config.save_interval:
save_callback(dimension, self.labeled_data[dimension])
last_save_count = len(self.labeled_data[dimension])
pbar.close()
if save_callback and self.labeled_data[dimension]:
save_callback(dimension, self.labeled_data[dimension])
logger.info(f"Completed {dimension}: {len(self.labeled_data[dimension])} labels, {self.skipped_pairs[dimension]} skipped, {self.generation_failures[dimension]} failures.")
def create_response_pairs(
all_responses: Dict[str, List[GeneratedResponse]],
) -> List[Tuple[str, GeneratedResponse, GeneratedResponse]]:
"""Create all pairs of responses for comparison"""
pairs = []
for prompt_id, responses in all_responses.items():
for i in range(len(responses)):
for j in range(i + 1, len(responses)):
pairs.append((prompt_id, responses[i], responses[j]))
random.shuffle(pairs)
logger.info(f"Created {len(pairs)} response pairs for comparison")
return pairs
def save_labeled_data(
labeled_examples: List[LabeledExample],
response_cache: ResponseCache,
dimension: str,
config: ICMConfig,
):
"""Save labeled preference data"""
if not labeled_examples:
logger.warning(f"No labeled examples for dimension {dimension}")
return
dim_output = os.path.join(config.output_dir, f"preference_{dimension}")
os.makedirs(dim_output, exist_ok=True)
preference_data = []
for ex in labeled_examples:
context = response_cache.get_context(ex.prompt_id)
if not context:
logger.warning(f"Missing context for prompt {ex.prompt_id}")
continue
if ex.label == "A":
chosen = ex.response_a_text
rejected = ex.response_b_text
chosen_id = ex.response_a_id
rejected_id = ex.response_b_id
else:
chosen = ex.response_b_text
rejected = ex.response_a_text
chosen_id = ex.response_b_id
rejected_id = ex.response_a_id
preference_data.append(
{
"prompt_id": ex.prompt_id,
"chosen": chosen,
"rejected": rejected,
"chosen_id": chosen_id,
"rejected_id": rejected_id,
"dimension": dimension,
"timestamp": ex.timestamp,
}
)
random.shuffle(preference_data)
split_idx = int(len(preference_data) * config.train_split)
train_data = preference_data[:split_idx]
test_data = preference_data[split_idx:]
with open(os.path.join(dim_output, "train.jsonl"), "w") as f:
for item in train_data:
f.write(json.dumps(item) + "\n")
with open(os.path.join(dim_output, "test.jsonl"), "w") as f:
for item in test_data:
f.write(json.dumps(item) + "\n")
with open(os.path.join(dim_output, "all_comparisons.jsonl"), "w") as f:
for ex in labeled_examples:
comparison = {
"prompt_id": ex.prompt_id,
"response_a": ex.response_a_text,
"response_b": ex.response_b_text,
"label": ex.label,
"dimension": ex.label_dimension,
"response_a_id": ex.response_a_id,
"response_b_id": ex.response_b_id,
"timestamp": ex.timestamp,
}
f.write(json.dumps(comparison) + "\n")
with open(os.path.join(dim_output, "prompts.jsonl"), "w") as f:
unique_prompts = {ex.prompt_id for ex in labeled_examples}
for prompt_id in unique_prompts:
context = response_cache.get_context(prompt_id)
if context:
prompt_data = {
"prompt_id": prompt_id,
"prompt_text": context.prompt_text,
"few_shot_examples": context.few_shot_examples,
"target_conversation": context.target_conversation,
"original_response": context.original_response,
}
f.write(json.dumps(prompt_data) + "\n")
metadata = {
"dimension": dimension,
"train_samples": len(train_data),
"test_samples": len(test_data),
"total_samples": len(preference_data),
"rubric": DIMENSION_RUBRICS.get(dimension, ""),
"timestamp": time.time(),
}
with open(os.path.join(dim_output, "metadata.json"), "w") as f:
json.dump(metadata, f, indent=2)
logger.info(
f"Saved preference dataset for {dimension}: {len(train_data)} train, {len(test_data)} test"
)
# Global references for signal handler
_global_response_cache = None
_global_response_generator = None
_global_labeler = None
_global_executor = None
def signal_handler(sig, frame):
"""Handle Ctrl+C gracefully"""
logger.info("\nInterrupted! Cleaning up...")
if _global_labeler and hasattr(_global_labeler, "labeled_data"):
total_labels = sum(
len(labels) for labels in _global_labeler.labeled_data.values()
)
if total_labels > 0:
logger.info(f"Saved {total_labels} labels across dimensions:")
for dim, labels in _global_labeler.labeled_data.items():
if labels:
logger.info(f" • {dim}: {len(labels)} labels")
if _global_response_cache:
logger.info("Saving response cache...")
_global_response_cache.save_final()
if _global_response_generator:
_global_response_generator.shutdown()
if _global_labeler:
_global_labeler.shutdown()
if _global_executor:
_global_executor.shutdown(wait=False, cancel_futures=True)
sys.exit(0)
def cleanup_resources():
"""Cleanup function for atexit"""
if _global_response_cache:
        # Best-effort cleanup; catch only Exception so SystemExit/KeyboardInterrupt still propagate
        try: _global_response_cache.save_final()
        except Exception: pass
    if _global_response_generator:
        try: _global_response_generator.shutdown()
        except Exception: pass
    if _global_labeler:
        try: _global_labeler.shutdown()
        except Exception: pass
    if _global_executor:
        try: _global_executor.shutdown(wait=False, cancel_futures=True)
        except Exception: pass
def label_single_dimension(dimension: str, pairs: List[Tuple], labeler, save_callback, config):
"""Label a single dimension with better error recovery"""
try:
labeler.label_dimension(pairs, dimension, save_callback=save_callback, unified_mode=config.unified_mode)
labels = labeler.labeled_data.get(dimension, [])
skipped = labeler.skipped_pairs.get(dimension, 0)
failures = labeler.generation_failures.get(dimension, 0)
logger.info(f"Completed {dimension}: {len(labels)} labeled, {skipped} skipped, {failures} errors")
return True, len(labels)
except Exception as e:
logger.exception(f"Critical error processing {dimension}: {e}")
partial_labels = labeler.labeled_data.get(dimension, [])
if config.save_partial_on_error and partial_labels:
logger.info(f"Saving {len(partial_labels)} partial results for {dimension}")
save_callback(dimension, partial_labels)
return False, len(partial_labels)
def main():
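    """End-to-end pipeline: load prompts from the configured dataset, generate candidate
    responses, save them to a Hugging Face dataset on disk, run ICM labeling for each
    selected dimension, and write per-dimension preference datasets plus a run summary CSV."""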
global _global_response_cache, _global_response_generator, _global_labeler, _global_executor
atexit.register(cleanup_resources)
signal.signal(signal.SIGINT, signal_handler)
# Configuration for a small test run
config = ICMConfig(
# Data source
dataset_name="allenai/tulu-3-sft-mixture",
api_url="http://localhost:8093",
# Dataset sampling
num_prompts=5000,
n_shot_examples=3,
responses_per_prompt=25,
shuffle_buffer_size=100_000,
# Response generation parameters
response_temperature=1.0,
response_top_p=1.0,
response_top_k=0,
response_frequency_penalty=0.0,
response_presence_penalty=0.0,
response_repeat_penalty=1.0,
response_max_tokens=4096,
# ICM core algorithm
initial_k=8,
alpha=50.0,
        initial_temp=5.0,  # Starting MCMC temperature; higher values accept more downhill moves early on
final_temp=0.01,
beta=0.95, # RELAXED: Cools slightly faster
max_iterations=10000, # Realistic ceiling for the test run
max_labels_per_dimension=25000,
fix_inconsistencies_max_iterations=100,
inconsistency_weight_multiplier=100.0,
# Dimension configuration
dimensions=None,
dimension_subset="minimal",
unified_mode=True,
unified_dimension_name="unified_quality",
# Model context management
max_context_tokens=32768,
response_reserve_tokens=4096,
token_estimation_ratio=0.75,
max_context_examples=25,
# Pair and data management
max_pairs_per_dimension=1500000,
max_generation_failures=500,
# Output configuration
output_dir="preference_datasets",
train_split=0.9,
cache_file="response_cache.jsonl",
save_interval=100,
generated_dataset_path="generated_responses_dataset",
# Execution configuration
max_workers=32,
random_seed=43,
# Network and retry configuration
retry_max_attempts=3,
retry_base_delay=1.0,
timeout_short=15.0,
timeout_medium=30.0,
timeout_long=90.0,
timeout_generation=180.0,
# Deduplication parameters
duplicate_prefix_length=200,
duplicate_fuzzy_length=500,
duplicate_fuzzy_prefix_min_length=50,
# Generation diversity parameters
diversity_attempt_threshold=3,
diversity_strong_threshold=10,
max_generation_attempts=20,
generation_oversampling_factor=1.25,
# Resilience parameters
max_consecutive_failures=50,
skip_indecisive_pairs=True,
min_labels_per_dimension=100,
save_partial_on_error=True,
# Misc parameters
log_prob_fallback=-100.0,
cache_report_interval=300,
num_logprobs=20,
streaming=False,
)
shared_executor = ThreadPoolExecutor(max_workers=config.max_workers)
_global_executor = shared_executor
llm = LLMInterface(config)
response_cache = ResponseCache(config)
_global_response_cache = response_cache
logger.info(f"Loading dataset: {config.dataset_name}")
dataset = load_dataset(config.dataset_name, split="train", streaming=config.streaming)
if config.streaming:
iterable = dataset.shuffle(buffer_size=config.shuffle_buffer_size, seed=config.random_seed).take(config.num_prompts * (config.n_shot_examples + 5))
else:
iterable = iter(dataset.shuffle(seed=config.random_seed))
response_generator = ResponseGenerator(llm, config, response_cache, shared_executor)
_global_response_generator = response_generator
logger.info("Preparing prompt contexts...")
contexts = response_generator.prepare_prompt_contexts(iterable, config.num_prompts)
logger.info("Generating responses...")
all_responses = response_generator.generate_all_responses(contexts)
# --- NEW: Save generated responses to a dataset ---
if all_responses:
save_responses_to_hf_dataset(
all_responses, contexts, config.generated_dataset_path
)
# ---------------------------------------------------
analyze_response_diversity(all_responses, config)
if config.unified_mode:
dimensions_to_label = [config.unified_dimension_name]
elif config.dimensions:
dimensions_to_label = config.dimensions
else:
subset_map = {
"all": list(VALID_LABELS),
"minimal": ["harmlessness", "factual_accuracy", "logical_coherence", "task_completion", "latent_task_identification", "uncertainty_calibration"],
"safety": ["policy_compliance", "harmlessness"],
"task": ["task_completion", "constraint_adherence", "latent_task_identification", "pattern_generalization", "information_density", "relevance_focus", "structural_clarity"],
"epistemic": ["uncertainty_calibration", "assumption_transparency", "perspective_awareness", "insight_synthesis", "problem_decomposition", "conceptual_synthesis"],
}
dimensions_to_label = subset_map.get(config.dimension_subset, [])
logger.info(f"Will label {len(dimensions_to_label)} dimensions: {dimensions_to_label}")
response_pairs = create_response_pairs(all_responses)
labeler = ICMLabeler(llm, config, response_cache, shared_executor)
_global_labeler = labeler
def save_dimension_progress(dimension: str, labeled_examples: List[LabeledExample]):
save_labeled_data(labeled_examples, response_cache, dimension, config)
overall_pbar = tqdm(dimensions_to_label, desc="Overall Progress", position=0)
for dimension in overall_pbar:
overall_pbar.set_postfix_str(f"Current: {dimension[:20]}...")
print(f"\n{'='*60}\nProcessing dimension: {dimension}\n{'='*60}", flush=True)
pairs_copy = response_pairs.copy()
success, label_count = label_single_dimension(dimension, pairs_copy, labeler, save_dimension_progress, config)
overall_pbar.set_postfix_str(f"{'✓' if success else '✗'} {dimension} ({label_count} labels)")
overall_pbar.close()
response_cache.save_final()
shared_executor.shutdown(wait=True)
# --- Final Summary ---
print("\n" + "=" * 80 + "\n" + " " * 30 + "LABELING COMPLETE!" + " " * 30 + "\n" + "=" * 80)
print(f"\n💾 Output Files:")
print(f" • Generated responses dataset: {config.generated_dataset_path}")
print(f" • Preference datasets: {config.output_dir}/preference_*/")
print(f"\n📋 Per-Dimension Results:")
print(f" {'Dimension':<30} {'Labels':>8} {'Skipped':>8} {'Adaptive Ex':>12} {'A':>6} {'B':>6} {'A/B Ratio':>10}")
print(f" {'-'*30} {'-'*8} {'-'*8} {'-'*12} {'-'*6} {'-'*6} {'-'*10}")
summary_rows = []
for dimension in dimensions_to_label:
if dimension in labeler.labeled_data:
labels = labeler.labeled_data[dimension]
skipped = labeler.skipped_pairs.get(dimension, 0)
adaptive_ex = labeler.adaptive_example_count.get(dimension, 'N/A')
a_count = sum(1 for l in labels if l.label == "A")
b_count = sum(1 for l in labels if l.label == "B")
ratio = a_count / b_count if b_count > 0 else float("inf")
print(f" {dimension:<30} {len(labels):>8} {skipped:>8} {str(adaptive_ex):>12} {a_count:>6} {b_count:>6} {f'{ratio:.2f}':>10}")
summary_rows.append({"dimension": dimension, "total_labels": len(labels), "skipped_pairs": skipped, "final_adaptive_examples": adaptive_ex, "a_count": a_count, "b_count": b_count, "ratio": ratio, "timestamp": time.time()})
with open("run_summary.csv", "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=["dimension", "total_labels", "skipped_pairs", "final_adaptive_examples", "a_count", "b_count", "ratio", "timestamp"])
writer.writeheader()
writer.writerows(summary_rows)
logger.info("Saved run summary to run_summary.csv")
if __name__ == "__main__":
main()