Estimate the total probability mass of a corpus under a language model using importance sampling
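The script is an instance of a standard Monte Carlo / importance-sampling estimate of a large sum: draw items from a proposal distribution, weight each scored item by the inverse of its proposal probability, and average. The toy sketch below only illustrates that idea with a uniform proposal (it is not part of the original gist, and the function name is made up); the actual script that follows scores one previously unseen random prefix per sampled Dolma document and scales the accumulated, weighted prefix probabilities to the full corpus.

# Toy illustration only: estimate sum_i f(i) over a large population by sampling
# items uniformly and weighting each draw by the population size (1 / proposal probability).
import random

def estimate_sum_by_sampling(f, population, n_samples, rng=random):
    total = 0.0
    for _ in range(n_samples):
        item = rng.choice(population)
        total += f(item) * len(population)
    return total / n_samples

# Example: the exact sum of squares of 1..1000 is 333833500; the estimate should be close.
print(estimate_sum_by_sampling(lambda k: k * k, list(range(1, 1001)), n_samples=20000))

The gist itself follows.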
import json
import random

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm

random.seed(42)

def compute_prefix_likelihood(model, tokenizer, prefix_text, all_prefixes, device="cuda"):
    """
    Computes the model likelihood of a random previously unseen prefix of the input
    and returns it together with the prefix length for importance-sampling weighting.
    Returns None if no unseen prefix can be found.
    """
    model.eval()
    inputs = tokenizer(prefix_text, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs["input_ids"][0].tolist()
    sequence_length = len(input_ids)
    if sequence_length < 2:
        # Too short to take a proper prefix.
        return None
    # Select a random previously unseen prefix; make several attempts.
    for _ in range(5):
        prefix_length = random.randint(1, sequence_length - 1)
        input_ids_candidate = input_ids[:prefix_length]
        input_ids_key = tuple(input_ids_candidate)
        if input_ids_key not in all_prefixes:
            all_prefixes.add(input_ids_key)
            input_ids = input_ids_candidate
            break
    else:
        # No new prefix in this text.
        return None
    # Terminate the prefix with an EOS token (falling back to the pad token).
    if tokenizer.eos_token_id is not None:
        input_ids += [tokenizer.eos_token_id]
    else:
        input_ids += [tokenizer.pad_token_id]
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        logits = outputs.logits
    # For a causal LM, token i+1 is predicted from tokens 0...i.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous()
    # Compute log-probabilities and gather those of the actual tokens.
    log_probs = F.log_softmax(shift_logits, dim=-1)
    per_token_log_probs = log_probs.gather(
        2,
        shift_labels.unsqueeze(-1)
    ).squeeze(-1)
    total_log_likelihood = per_token_log_probs.sum().item()
    return np.exp(total_log_likelihood), prefix_length

# Load OLMo-7B and its tokenizer.
olmo = AutoModelForCausalLM.from_pretrained(
    "allenai/OLMo-7B-0424-hf", cache_dir="../hf_cache/", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "allenai/OLMo-7B-0424-hf", cache_dir="../hf_cache/"
)

# Total number of documents in Dolma, used to scale the sample mean to the full corpus.
DOLMA_N_DOCUMENTS = 3086909876

x = np.longdouble()
seen_prefixes = set()
with open('../../corpora/dolma/old/sample.jsonl', 'r', encoding='utf-8') as inp:
    for i, line in tqdm(enumerate(inp), total=10**5):
        estimation_result = compute_prefix_likelihood(
            olmo, tokenizer, json.loads(line.strip())['text'], seen_prefixes)
        if estimation_result is None:
            continue
        prefix_likelihood, prefix_length = estimation_result
        x += prefix_likelihood * prefix_length
        # Append the running corpus-level estimate after every scored prefix.
        with open('dolma_olmo_prob.log', 'a') as out:
            print(f"{len(seen_prefixes)},{x * DOLMA_N_DOCUMENTS / len(seen_prefixes)}", file=out)