@macleginn
Created July 9, 2025 11:31
Estimate the total probability mass of a corpus under a language model using importance sampling
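The script below estimates this quantity as follows: for each document in a Dolma sample, it draws one previously unseen token prefix uniformly at random, scores the prefix (terminated with an end-of-sequence token) with OLMo-7B, and weights the resulting probability by the prefix length. The running weighted sum, multiplied by the total number of Dolma documents and divided by the number of sampled prefixes, is appended to dolma_olmo_prob.log as the current estimate:

    estimate = DOLMA_N_DOCUMENTS * sum_i(len(prefix_i) * p_model(prefix_i + EOS)) / n_sampled_prefixes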
import json
import random
from math import ceil
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm
random.seed(42)


def compute_prefix_likelihood(model, tokenizer, prefix_text, all_prefixes, device="cuda"):
    """
    Computes the model likelihood of a random previously unseen prefix of the input.
    Returns the likelihood together with the length of the sampled prefix for
    importance-sampling weighting.
    """
    model.eval()
    inputs = tokenizer(prefix_text, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs["input_ids"][0].tolist()
    # Select a random prefix and add an EOS token to it.
    # Return the prefix length for importance-sampling weighting.
    sequence_length = len(input_ids)
    if sequence_length < 2:
        # Too short to take a proper prefix.
        return None
    # Make several attempts to find an unseen prefix.
    for _ in range(5):
        prefix_length = random.randint(1, sequence_length - 1)
        input_ids_candidate = input_ids[:prefix_length]
        input_ids_key = tuple(input_ids_candidate)
        if input_ids_key not in all_prefixes:
            all_prefixes.add(input_ids_key)
            input_ids = input_ids_candidate
            break
    else:
        # No new prefix found in this text.
        return None
    # Terminate the prefix with EOS (or PAD if the tokenizer has no EOS token).
    if tokenizer.eos_token_id is not None:
        input_ids += [tokenizer.eos_token_id]
    else:
        input_ids += [tokenizer.pad_token_id]
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_ids)
    logits = outputs.logits
    # For a causal LM, the logits at position i predict token i + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous()
    # Compute log-probabilities and gather those of the actual tokens.
    log_probs = F.log_softmax(shift_logits, dim=-1)
    per_token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
    total_log_likelihood = per_token_log_probs.sum().item()
    return np.exp(total_log_likelihood), prefix_length


olmo = AutoModelForCausalLM.from_pretrained(
    "allenai/OLMo-7B-0424-hf", cache_dir="../hf_cache/", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "allenai/OLMo-7B-0424-hf", cache_dir="../hf_cache/"
)

DOLMA_N_DOCUMENTS = 3086909876

# Running importance-weighted sum of prefix probabilities.
x = np.longdouble(0.0)
seen_prefixes = set()
with open('../../corpora/dolma/old/sample.jsonl', 'r', encoding='utf-8') as inp:
    for i, line in tqdm(enumerate(inp), total=10**5):
        estimation_result = compute_prefix_likelihood(
            olmo, tokenizer, json.loads(line.strip())['text'], seen_prefixes)
        if estimation_result is None:
            continue
        prefix_likelihood, prefix_length = estimation_result
        x += prefix_likelihood * prefix_length
        # Log the running estimate extrapolated to the full Dolma corpus.
        with open('dolma_olmo_prob.log', 'a') as out:
            print(f"{len(seen_prefixes)},{x * DOLMA_N_DOCUMENTS / len(seen_prefixes)}", file=out)