@rawsh
Last active September 18, 2025 12:13
WIP inference engine divergence testing

relevant blog posts and docs

start a container:

sudo docker run \
  --name rob-div \
  -it \
  --shm-size=64g \
  --gpus all \
  --ulimit memlock=-1:-1 \
  --ulimit stack=67108864 \
  -e UV_CACHE_DIR=/tmp/uv_cache \
  -v /tmp:/tmp \
  nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 \
  bash

install uv

apt-get update
apt-get install -y --no-install-recommends curl ca-certificates

curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env

uv venv
source .venv/bin/activate

deps

apt-get install -y libnuma-dev
uv pip install sglang==0.5.2 sgl-kernel==0.3.9.post2 vllm==0.10.2 matplotlib datasets torchao accelerate ninja
uv pip install flash_attn --no-build-isolation --verbose
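
optional sanity check (not part of the original setup): confirm torch sees the GPU and flash-attn imports cleanly before kicking off long runs

python - <<'EOF'
import torch
import flash_attn  # fails here if the flash_attn build above did not succeed

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("flash_attn", flash_attn.__version__)
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
EOF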

dataset: https://huggingface.co/datasets/AI-MO/aimo-validation-aime

figures generated with 90 prompts x 8 completions per prompt (720 rollouts total)
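
the 90 prompts are the problems in the dataset's train split; a quick sketch to confirm the count (assuming the same split compare.py uses):

python - <<'EOF'
from datasets import load_dataset

ds = load_dataset("AI-MO/aimo-validation-aime", split="train")
print(len(ds), "problems")  # expected: 90
EOF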

run sglang

python compare.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --engine sglang --n 8 --max-new-tokens 32768 --out out_sglang

run vllm

python compare.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --engine vllm --n 8 --max-new-tokens 32768 --out out_vllm

run vllm with inductor=true

python compare.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --engine vllm --n 8 --max-new-tokens 32768 --out out_vllm --vllm-use-inductor

sglang

{
  "mae_logprob": 0.01667347177863121,
  "rmse_logprob": 0.037837449461221695,
  "pearson_r": 0.9997906483213266,
  "kl_divergence": 0.00027361718095684465,
  "rollout_probs_diff_mean": -0.00017899970093034225,
  "rollout_probs_diff_std": 0.00930520652831646,
  "avg_completion_length": 11864.4,
  "min_completion_length": 1280,
  "max_completion_length": 23774,
  "n_tokens": 8542368
}
[figures: diff_raw, diff_log]

[figure: longest_prob_diff (longest generation)]

vllm (inductor true)

{
  "mae_logprob": 0.017033929005265236,
  "rmse_logprob": 0.03777551278471947,
  "pearson_r": 0.9997940070543079,
  "kl_divergence": 0.00027386811367469913,
  "rollout_probs_diff_mean": -0.00021243607617292383,
  "rollout_probs_diff_std": 0.00940233598260049,
  "avg_completion_length": 11782.836111111112,
  "min_completion_length": 2152,
  "max_completion_length": 21178,
  "n_tokens": 8483642
}
[figures: diff_raw, diff_log]

[figure: longest_prob_diff (longest generation)]

vllm (inductor false)

{
  "mae_logprob": 0.014655027538537979,
  "rmse_logprob": 0.03175614774227142,
  "pearson_r": 0.999848608900816,
  "kl_divergence": 0.00017749466835524134,
  "rollout_probs_diff_mean": -8.441660138512609e-05,
  "rollout_probs_diff_std": 0.007982945783328473,
  "avg_completion_length": 11615.763888888889,
  "min_completion_length": 1484,
  "max_completion_length": 22054,
  "n_tokens": 8363350
}
[figures: diff_raw, diff_log]

[figure: longest_prob_diff (longest generation)]
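
for reference, a condensed restatement of how compare.py (full script below) computes the summary metrics above: h and e are the aligned chosen-token log-probs from the HF recompute and the inference engine, respectively

import numpy as np

def summarize(h, e):
    # h, e: aligned arrays of chosen-token log-probs (HF vs. engine), one entry per generated token
    mae = float(np.mean(np.abs(h - e)))
    rmse = float(np.sqrt(np.mean((h - e) ** 2)))
    pearson_r = float(np.corrcoef(h, e)[0, 1])
    p, q = np.exp(np.clip(h, -80.0, 0.0)), np.exp(np.clip(e, -80.0, 0.0))
    diff = p - q  # rollout_probs_diff_mean / _std come from this
    eps = 1e-12
    pc, qc = np.clip(p, eps, 1 - eps), np.clip(q, eps, 1 - eps)
    # Bernoulli KL between chosen-token probabilities, averaged over tokens
    kl = float(np.mean(pc * (np.log(pc) - np.log(qc))
                       + (1 - pc) * (np.log1p(-pc) - np.log1p(-qc))))
    return mae, rmse, pearson_r, float(diff.mean()), float(diff.std()), kl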

discussion: https://x.com/rawsh0/status/1967462360333107364

#!/usr/bin/env python3
import argparse, os, gc, json, random, csv

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from matplotlib import pyplot as plt

# Optional, used only when --engine vllm
try:
    from vllm import LLM, SamplingParams
except Exception:
    LLM = None
    SamplingParams = None

SUFFIX = "Let's think step by step and answer in \\boxed{}."


def ensure_dir(d):
    os.makedirs(d, exist_ok=True)


def set_seed(s):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(s)


def _get_base_module_and_head(model: AutoModelForCausalLM):
    head = getattr(model, "lm_head", None) or getattr(model, "embed_out", None)
    base = None
    base_prefix = getattr(model, "base_model_prefix", None)
    if isinstance(base_prefix, str) and hasattr(model, base_prefix):
        base = getattr(model, base_prefix)
    for name in ["model", "transformer", "language_model", "backbone", "base_model"]:
        if base is None and hasattr(model, name):
            base = getattr(model, name)
    if base is None or head is None:
        raise RuntimeError("Could not locate base transformer or lm_head on the HF model.")
    return base, head


@torch.no_grad()
def _chunked_token_logprobs_from_hidden(hiddens: torch.Tensor,
                                        head_weight: torch.Tensor,
                                        targets: torch.Tensor,
                                        time_chunk: int = 256) -> torch.Tensor:
    """
    hiddens: [B, L, H] (on CPU)
    head_weight: [V, H] on its own (likely CUDA) device
    targets: [B, L-1] (on CPU)
    returns lp: [B, L-1] (float32, on CPU)
    """
    B, L, H = hiddens.shape
    T = L - 1
    V, Hw = head_weight.shape
    assert H == Hw, f"Hidden size mismatch: {H} vs {Hw}"
    out_device = hiddens.device          # CPU
    weight_device = head_weight.device   # e.g., cuda:0/1
    # Ensure we run on the correct CUDA device to avoid cross-device kernel weirdness
    if weight_device.type == 'cuda':
        torch.cuda.set_device(weight_device.index)
    # keep W on the head device, dtype matched with hiddens (bf16 on CPU is fine)
    W = head_weight.to(dtype=hiddens.dtype, device=weight_device)
    lp = torch.empty((B, T), dtype=torch.float32, device=out_device)
    t = 0
    cur_chunk = max(1, int(time_chunk))
    while t < T:
        cur = min(cur_chunk, T - t)
        try:
            # Slice, make contiguous, then copy CPU->CUDA (blocking copy for stability)
            h = hiddens[:, t:t+cur, :].contiguous().view(-1, H).to(weight_device, non_blocking=False)
            y = targets[:, t:t+cur].contiguous().view(-1).to(weight_device, non_blocking=False)
            logits = F.linear(h, W)  # [B*cur, V] on weight_device
            # Numerically stable log softmax for chosen tokens
            m = logits.max(dim=-1).values
            lse = torch.logsumexp((logits - m.unsqueeze(1)).to(torch.float32), dim=-1) + m.to(torch.float32)
            chosen = logits.gather(1, y.unsqueeze(1)).squeeze(1).to(torch.float32)
            lp_chunk = (chosen - lse).view(B, cur).to(out_device, non_blocking=False)
            lp[:, t:t+cur] = lp_chunk
            # cleanup
            del h, y, logits, m, lse, chosen, lp_chunk
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            t += cur
            cur_chunk = time_chunk
        except RuntimeError as e:
            # Back off time chunk on OOM
            if "out of memory" in str(e).lower() and cur > 1:
                if torch.cuda.is_available(): torch.cuda.empty_cache()
                cur_chunk = max(1, cur // 2)
                continue
            raise
    return lp


def infer_log_probs_batch(model: AutoModelForCausalLM, sequences, device_hint: str, time_chunk: int = 256):
    """
    Run base forward (sharded on GPUs), offload [B,L,H] to CPU, then do chunked head on the head device.
    """
    use_cpu_io = hasattr(model, "hf_device_map") or hasattr(model, "device_map")
    tgt_device = 'cpu' if use_cpu_io else device_hint
    lens = [len(s) for s in sequences]
    Lm = max(lens) if lens else 0
    pad_id = model.config.pad_token_id if model.config.pad_token_id is not None else (model.config.eos_token_id or 0)
    inp = torch.full((len(sequences), Lm), pad_id, dtype=torch.long, device=tgt_device)
    for i, s in enumerate(sequences):
        if len(s) > 0:
            inp[i, :len(s)] = torch.tensor(s, dtype=torch.long, device=tgt_device)
    attn = (inp != pad_id).long()
    try:
        model.config.use_cache = False
    except Exception:
        pass
    base, head = _get_base_module_and_head(model)
    # If the model was loaded with a device_map (even single GPU), ensure inputs
    # are moved to the same device as the base module's first parameter (embedding device)
    if use_cpu_io:
        try:
            first_param_device = next(base.parameters()).device
            if inp.device != first_param_device:
                inp = inp.to(first_param_device)
                attn = attn.to(first_param_device)
        except Exception:
            pass
    with torch.inference_mode():
        outputs = base(
            input_ids=inp,
            attention_mask=attn,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
        )
        hidden_states = outputs[0]  # [B, Lm, H] on last shard GPU
        del outputs
        # Offload activations to CPU, then make contiguous + (optional) pin
        if hidden_states.is_cuda:
            hidden_states = hidden_states.to('cpu', non_blocking=False)
            torch.cuda.synchronize(); torch.cuda.empty_cache()
        hidden_states = hidden_states.contiguous()
        try:
            hidden_states = hidden_states.pin_memory()
        except Exception:
            pass
    tgt = inp[:, 1:]  # [B, Lm-1] on CPU
    lp = _chunked_token_logprobs_from_hidden(
        hiddens=hidden_states,
        head_weight=head.weight,
        targets=tgt,
        time_chunk=time_chunk,
    )
    del hidden_states
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    return lp, lens


# ------------------------- plotting -------------------------
def plot_correlation(eng_logp, hf_logp, out_png, log_space=False):
    p_min, p_max = (-40, 0) if log_space else (0, 1)
    X = (eng_logp if log_space else eng_logp.exp()).float().cpu().numpy()
    Y = (hf_logp if log_space else hf_logp.exp()).float().cpu().numpy()
    fig, axes = plt.subplots(2, 1, figsize=(10, 10), gridspec_kw={'height_ratios': [4, 2]})
    axes[0].set_aspect('equal')
    axes[0].set_xlim(p_min, p_max); axes[0].set_ylim(p_min, p_max)
    axes[1].set_xlim(p_min, p_max)
    hist, xe, ye = np.histogram2d(X, Y, bins=100, range=[[p_min, p_max], [p_min, p_max]], density=False)
    hist = np.log(hist + 1e-10)
    Xm, Ym = np.meshgrid(xe[:-1], ye[:-1])
    im = axes[0].pcolormesh(Xm, Ym, hist.T, shading='auto')
    axes[0].plot([p_min, p_max], [p_min, p_max], linestyle='--', linewidth=1)
    axes[0].set_xlabel('Engine ' + ('log-prob' if log_space else 'probability'))
    axes[0].set_ylabel('HF ' + ('log-prob' if log_space else 'probability'))
    fig.colorbar(im, ax=axes[0], label='Log Frequency')
    hx, xe1 = np.histogram(X, bins=100, range=[p_min, p_max], density=True)
    hy, ye1 = np.histogram(Y, bins=100, range=[p_min, p_max], density=True)
    axes[1].plot(xe1[:-1], np.log(hx + 1e-12), label='Engine')
    axes[1].plot(ye1[:-1], np.log(hy + 1e-12), label='HF')
    axes[1].legend(); axes[1].set_ylabel('Log Density'); axes[1].set_xlabel('log-prob' if log_space else 'probability')
    plt.tight_layout(); plt.savefig(out_png, dpi=150); plt.close()


def plot_sample_prob_diff(hf_logp, eng_logp, out_png):
    diff = hf_logp.exp() - eng_logp.exp()
    xs = np.arange(len(diff))
    plt.figure(figsize=(10, 4))
    plt.plot(xs, diff.cpu().numpy())
    plt.xlabel('Response token index'); plt.ylabel('Δ prob (HF − Engine)')
    plt.tight_layout(); plt.savefig(out_png, dpi=150); plt.close()


# ------------------------- engines -------------------------
def run_vllm(prompt_ids_list, model, batch_size, max_new_tokens, seed=0, use_inductor=False):
    assert LLM is not None, 'vLLM is not installed.'
    llm = LLM(
        model=model,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        compilation_config={
            "use_inductor": use_inductor,
        }
    )
    # NOTE: script not valid for temperature != 1.0; vLLM needs a patch because logprobs are returned before sampling
    sp = SamplingParams(
        temperature=1.0,
        top_p=1.0,
        top_k=-1,
        max_tokens=int(max_new_tokens),
        logprobs=0,
        detokenize=True,
        seed=seed
    )
    prompt_ids, gen_ids, gen_logp, texts = [], [], [], []
    for i in range(0, len(prompt_ids_list), batch_size):
        batch_prompts = prompt_ids_list[i:i+batch_size]
        # vLLM expects a list of PromptInputs; use tokenized prompts schema
        batch_inputs = [{"prompt_token_ids": ids} for ids in batch_prompts]
        outs = llm.generate(batch_inputs, sampling_params=sp)
        for pid_sent, o in zip(batch_prompts, outs):
            sample = o.outputs[0]
            p_ids = list(pid_sent)
            g_ids = list(sample.token_ids)
            if sample.logprobs is None:
                raise RuntimeError("vLLM returned no logprobs; set SamplingParams.logprobs >= 1.")
            chosen_lp = []
            for t, tok_id in enumerate(g_ids):
                lp_dict = sample.logprobs[t]  # dict[token_id -> Logprob]
                lp_obj = lp_dict.get(tok_id)
                if lp_obj is None:
                    raise RuntimeError("Chosen token not in returned top-k logprobs.")
                chosen_lp.append(float(lp_obj.logprob))
            g_lp = torch.tensor(chosen_lp, dtype=torch.float32)
            prompt_ids.append(p_ids); gen_ids.append(g_ids); gen_logp.append(g_lp); texts.append(sample.text)
    del llm; gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    return prompt_ids, gen_ids, gen_logp, texts


def run_sglang(prompt_ids_list, model, batch_size, max_new_tokens):
    from sglang.srt.entrypoints.engine import Engine
    engine = Engine(model_path=model, dtype='bfloat16', tp_size=1, trust_remote_code=True, load_format='auto', log_level='INFO', max_running_requests=1024)
    prompt_ids, gen_ids, gen_logp, texts = [], [], [], []
    for i in range(0, len(prompt_ids_list), batch_size):
        batch_ids = prompt_ids_list[i:i+batch_size]
        sp = {
            "n": 1, "max_new_tokens": int(max_new_tokens), "temperature": 1.0, "top_p": 1.0,
            "top_k": -1, "ignore_eos": False, "min_new_tokens": 0,
            "skip_special_tokens": True, "spaces_between_special_tokens": True,
        }
        outs = engine.generate(prompt=None, sampling_params=sp, return_logprob=True, input_ids=batch_ids, image_data=None)
        for pid, out in zip(batch_ids, outs):
            tpl = out["meta_info"]["output_token_logprobs"]
            g_lp = torch.tensor([float(t[0]) for t in tpl], dtype=torch.float32)
            g_ii = [int(t[1]) for t in tpl]
            prompt_ids.append(pid); gen_ids.append(g_ii); gen_logp.append(g_lp); texts.append(out["text"])
    try:
        engine.shutdown()
    except Exception:
        pass
    del engine; gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    return prompt_ids, gen_ids, gen_logp, texts


def build_hf_model(model_name):
    # Give accelerate a memory budget so it offloads instead of spiking one GPU
    max_memory = {}
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(i).total_memory
            giB = int((total * 0.95) // (1024**3))
            max_memory[i] = f"{giB}GiB"
        max_memory["cpu"] = "256GiB"
    # Flash-Attention 2 only
    hf = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        device_map='auto' if torch.cuda.is_available() else None,
        max_memory=max_memory if torch.cuda.is_available() else None,
    )
    hf.eval()
    try:
        hf.config.use_cache = False
    except Exception:
        pass
    if not torch.cuda.is_available():
        hf.to("cpu")
    return hf, ('cuda' if torch.cuda.is_available() else 'cpu')


# ------------------------- main -------------------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--engine', choices=['vllm','sglang'], required=True)
    ap.add_argument('--model', required=True)
    ap.add_argument('--batch-size', type=int, default=256)
    ap.add_argument('--hf-batch-size', type=int, default=64, help='Batch size for HF scoring (keep small to avoid OOM)')
    ap.add_argument('--time-chunk-size', type=int, default=128, help='Time chunk (tokens) for head projection')
    ap.add_argument('--max-new-tokens', type=int, default=32768)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--n', type=int, default=16)
    ap.add_argument('--out', default='out')
    ap.add_argument('--dataset', default='AI-MO/aimo-validation-aime')  # or MathArena/aime_2025
    ap.add_argument('--vllm-use-inductor', action='store_true')
    args = ap.parse_args()
    set_seed(args.seed); ensure_dir(args.out)
    if args.vllm_use_inductor:
        assert args.engine == 'vllm', 'vLLM with inductor requires vLLM engine'
    ds = load_dataset(args.dataset, split='train')
    user_texts = [f"{row['problem']}\n\n{SUFFIX}" for row in ds]
    user_texts = user_texts * args.n
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # Build chat prompts using the model's chat template
    messages_list = [[{"role": "user", "content": text}] for text in user_texts]
    chat_prompt_ids = [
        tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=True, return_tensors=None)
        for msgs in messages_list
    ]
    # Ensure lists of ints for sglang
    chat_prompt_ids = [
        ids.tolist() if hasattr(ids, 'tolist') else (list(ids[0]) if isinstance(ids, (tuple, list)) and hasattr(ids[0], '__iter__') else list(ids))
        for ids in chat_prompt_ids
    ]
    if args.engine == 'vllm':
        p_ids, g_ids, g_lp, texts = run_vllm(chat_prompt_ids, args.model, args.batch_size, args.max_new_tokens, seed=args.seed, use_inductor=args.vllm_use_inductor)
    else:
        p_ids, g_ids, g_lp, texts = run_sglang(chat_prompt_ids, args.model, args.batch_size, args.max_new_tokens)
    hf, device = build_hf_model(args.model)
    # Sanity checks
    V = hf.get_input_embeddings().weight.size(0)
    for idx, (pi, gi, elp) in enumerate(zip(p_ids, g_ids, g_lp)):
        assert len(gi) == len(elp), f"Engine IDs/logprobs length mismatch at sample {idx} (len(ids)={len(gi)} vs len(lp)={len(elp)})"
        if (pi and max(pi) >= V) or (gi and max(gi) >= V):
            raise ValueError(f"Token id out of range at sample {idx}: vocab={V}, max_id={max(pi+gi)}")
    sequences = [pi + gi for pi, gi in zip(p_ids, g_ids)]
    max_pos = getattr(hf.config, "max_position_embeddings", None)
    if max_pos is not None:
        too_long = [i for i, s in enumerate(sequences) if len(s) > max_pos]
        if len(too_long) > 0:
            print(f"Warning: {len(too_long)} sequences exceed model max_position_embeddings={max_pos} and may be truncated or slow.")
    # ---------- HF scoring (length-bucket to cut padding) ----------
    idxs = list(range(len(sequences)))
    idxs.sort(key=lambda i: len(sequences[i]))  # shortest -> longest
    hf_rows = [None] * len(sequences)
    for start in range(0, len(sequences), args.hf_batch_size):
        batch_idx = idxs[start:start+args.hf_batch_size]
        seq_batch = [sequences[i] for i in batch_idx]
        lp_batch, lens_batch = infer_log_probs_batch(
            hf, seq_batch, device, time_chunk=args.time_chunk_size
        )
        for j, (i_orig, L) in enumerate(zip(batch_idx, lens_batch)):
            hf_rows[i_orig] = lp_batch[j, :L-1].detach().cpu()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    # Align engine vs HF on generated tokens
    eng_slices, hf_slices, slice_indices = [], [], []
    for idx, (pi, gi, eng_lp) in enumerate(zip(p_ids, g_ids, g_lp)):
        Lp, Lg = len(pi), len(gi)
        hf_row = hf_rows[idx]
        start = max(Lp - 1, 0)
        hf_slice = hf_row[start:start+Lg]
        m = min(len(eng_lp), len(hf_slice))
        if m > 0:
            eng_slices.append(eng_lp[:m])
            hf_slices.append(hf_slice[:m])
            slice_indices.append(idx)
    eng_all = torch.cat(eng_slices) if eng_slices else torch.empty(0)
    hf_all = torch.cat(hf_slices) if hf_slices else torch.empty(0)
    # Save raw per-item
    # Re-ensure output directory and use absolute path to avoid cwd ambiguity
    out_dir = os.path.abspath(args.out)
    os.makedirs(out_dir, exist_ok=True)
    out_jsonl = os.path.join(out_dir, 'engine_outputs.jsonl')
    with open(out_jsonl, 'w', encoding='utf-8') as f:
        for pi, gi, elp, txt in zip(p_ids, g_ids, g_lp, texts):
            f.write(json.dumps({'prompt_ids': pi, 'gen_ids': gi, 'gen_logprobs': [float(x) for x in elp.tolist()], 'text': txt})+'\n')
    # Summary metrics + plots
    if len(eng_all) > 0:
        e = eng_all.float().cpu().numpy()
        h = hf_all.float().cpu().numpy()
        mae = float(np.mean(np.abs(h - e)))
        rmse = float(np.sqrt(np.mean((h - e) ** 2)))
        corr = float(np.corrcoef(h, e)[0, 1]) if len(h) > 1 else float('nan')
        # Additional metrics in probability space
        lnp = np.clip(h.astype(np.float64), -80.0, 0.0)
        lnq = np.clip(e.astype(np.float64), -80.0, 0.0)
        p_raw = np.exp(lnp)
        q_raw = np.exp(lnq)
        diff = p_raw - q_raw
        rollout_probs_diff_mean = float(np.mean(diff))
        rollout_probs_diff_std = float(np.std(diff))
        # Bernoulli KL between chosen-token probabilities (clip to avoid log(0))
        eps = 1e-12
        p = np.clip(p_raw, eps, 1.0 - eps)
        q = np.clip(q_raw, eps, 1.0 - eps)
        kl_vals = p * (np.log(p) - np.log(q)) + (1.0 - p) * (np.log1p(-p) - np.log1p(-q))
        kl_divergence = float(np.mean(kl_vals))
        # Completion length stats (in tokens)
        completion_lengths = np.array([len(gen_ids) for gen_ids in g_ids], dtype=np.int64)
        avg_completion_length = float(np.mean(completion_lengths)) if completion_lengths.size > 0 else 0.0
        min_completion_length = int(np.min(completion_lengths)) if completion_lengths.size > 0 else 0
        max_completion_length = int(np.max(completion_lengths)) if completion_lengths.size > 0 else 0
        with open(os.path.join(out_dir, 'summary_metrics.json'), 'w', encoding='utf-8') as f:
            json.dump({
                "mae_logprob": mae,
                "rmse_logprob": rmse,
                "pearson_r": corr,
                "kl_divergence": kl_divergence,
                "rollout_probs_diff_mean": rollout_probs_diff_mean,
                "rollout_probs_diff_std": rollout_probs_diff_std,
                "avg_completion_length": avg_completion_length,
                "min_completion_length": min_completion_length,
                "max_completion_length": max_completion_length,
                "n_tokens": int(len(h))
            }, f, indent=2)
        plot_correlation(eng_all, hf_all, os.path.join(out_dir, 'diff_raw.png'), log_space=False)
        plot_correlation(eng_all, hf_all, os.path.join(out_dir, 'diff_log.png'), log_space=True)
        j = len(eng_slices) // 2 if eng_slices else 0
        if eng_slices:
            plot_sample_prob_diff(hf_slices[j], eng_slices[j], os.path.join(out_dir, 'sample_prob_diff.png'))
            orig_j = slice_indices[j]
            with open(os.path.join(out_dir, 'sample_completion.txt'), 'w', encoding='utf-8') as f:
                f.write(texts[orig_j])
            toks = tokenizer.convert_ids_to_tokens(g_ids[orig_j][:len(eng_slices[j])])
            with open(os.path.join(out_dir, 'sample_token_diffs.csv'), 'w', newline='', encoding='utf-8') as cf:
                w = csv.writer(cf); w.writerow(['idx','token','prob_hf','prob_engine','delta'])
                for i, (hlp, elp, t) in enumerate(zip(hf_slices[j], eng_slices[j], toks)):
                    ph, pe = float(hlp.exp()), float(elp.exp())
                    w.writerow([i, t, ph, pe, ph-pe])
            j_longest = max(range(len(slice_indices)), key=lambda k: len(g_ids[slice_indices[k]]))
            orig_longest = slice_indices[j_longest]
            plot_sample_prob_diff(hf_slices[j_longest], eng_slices[j_longest], os.path.join(out_dir, 'longest_prob_diff.png'))
            with open(os.path.join(out_dir, 'longest_completion.txt'), 'w', encoding='utf-8') as f:
                f.write(texts[orig_longest])
            toks = tokenizer.convert_ids_to_tokens(g_ids[orig_longest][:len(eng_slices[j_longest])])
            with open(os.path.join(out_dir, 'longest_token_diffs.csv'), 'w', newline='', encoding='utf-8') as cf:
                w = csv.writer(cf); w.writerow(['idx','token','prob_hf','prob_engine','delta'])
                for i, (hlp, elp, t) in enumerate(zip(hf_slices[j_longest], eng_slices[j_longest], toks)):
                    ph, pe = float(hlp.exp()), float(elp.exp())
                    w.writerow([i, t, ph, pe, ph-pe])
    # Clean up torch.distributed process group if initialized (avoids NCCL warning)
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
    except Exception:
        pass
    print('Saved to', out_dir)


if __name__ == '__main__':
    main()