Skip to content

Instantly share code, notes, and snippets.

@awni
Last active May 31, 2025 15:31
Show Gist options
  • Save awni/33a5315e0a5b91ea2cd032af39a624d8 to your computer and use it in GitHub Desktop.
Save awni/33a5315e0a5b91ea2cd032af39a624d8 to your computer and use it in GitHub Desktop.
import argparse
import math
import mlx.core as mx
import mlx.nn as nn
from tqdm import tqdm
from mlx_lm.utils import load
from pathlib import Path
def eval_ppl(model, data, batch_size=32):
    """Return the perplexity of ``model`` on the token sequences in ``data``.

    The rows of ``data`` are processed ``batch_size`` at a time. For every
    row the model is fed all tokens but the last and is scored, via
    cross-entropy, against the same row shifted left by one position
    (next-token prediction). Perplexity is ``exp`` of the mean per-token loss
    accumulated over all batches.

    Args:
        model: Callable mapping a batch of token ids to logits.
        data: 2-D array of token ids, sliced as ``data[a:b]`` and
            ``batch[:, :-1]`` (assumed ``(num_sequences, sequence_length)``,
            as produced by ``load_dataset`` -- TODO confirm for other callers).
        batch_size: Number of rows evaluated per forward pass.

    Returns:
        The perplexity as a Python float.
    """
    total_loss = 0.0
    token_count = 0
    num_rows = len(data)
    for start in range(0, num_rows, batch_size):
        chunk = data[start : start + batch_size]
        inputs = chunk[:, :-1]
        # Logits are cast to float32 before the loss is computed
        # (presumably for numerical precision with low-precision models).
        logits = model(inputs).astype(mx.float32)
        targets = chunk[:, 1:]
        token_losses = nn.losses.cross_entropy(logits, targets)
        total_loss += token_losses.sum().item()
        token_count += token_losses.size
    return math.exp(total_loss / token_count)
def load_dataset(tokenizer, num_samples: int, sequence_length: int) -> mx.array:
    """Load calibration text and return random fixed-length token chunks.

    The calibration text is cached at ``~/.cache/mlx-lm/calibration_v5.txt``
    and downloaded on first use. The whole text is tokenized once, truncated
    to a multiple of ``sequence_length``, split into non-overlapping rows of
    ``sequence_length`` tokens, and up to ``num_samples`` rows are picked at
    random (via ``mx.random.permutation``, so the selection follows the MLX
    random seed).

    Args:
        tokenizer: Object whose ``encode(text, return_tensors="mlx")`` returns
            a batch of token ids; only the first entry is used.
        num_samples: Maximum number of chunks to return.
        sequence_length: Number of tokens per chunk.

    Returns:
        An array of the selected chunks, one chunk of ``sequence_length``
        tokens per row.
    """
    # Imported locally: the file's top-level imports do not bring
    # ``urllib.request`` into scope, so the original call raised NameError.
    from urllib import request

    save_dir = Path.home() / ".cache/mlx-lm/calibration_v5.txt"
    if not save_dir.exists():
        save_dir.parent.mkdir(parents=True, exist_ok=True)
        url = "https://gist.githubusercontent.com/tristandruyen/9e207a95c7d75ddf37525d353e00659c/raw/571fda718462de863e5a0171078c175420c7649a/calibration_data_v5_rc.txt"
        # Download to a sibling temp file and rename it into place, so an
        # interrupted download never leaves a truncated file in the cache.
        partial = save_dir.with_name(save_dir.name + ".part")
        request.urlretrieve(url, partial)
        partial.replace(save_dir)

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(save_dir, encoding="utf-8") as fid:
        texts = fid.read()

    tokens = tokenizer.encode(texts, return_tensors="mlx")[0]

    # select random non-overlapping chunks
    tokens = tokens[: (tokens.size // sequence_length) * sequence_length]
    tokens = tokens.reshape(-1, sequence_length)
    segments = mx.random.permutation(tokens.shape[0])[:num_samples]
    return tokens[segments]
if __name__ == "__main__":
    # Command-line entry point: load a model, sample calibration chunks and
    # print the model's perplexity on them.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", "-m", default="Qwen/Qwen3-1.7B")
    # The remaining options are all plain integers; register them in order.
    int_options = (
        ("--num-samples", 32),
        ("--sequence-length", 512),
        ("--seed", 123),
    )
    for flag, default_value in int_options:
        cli.add_argument(flag, type=int, default=default_value)
    options = cli.parse_args()

    # Seed MLX's RNG before sampling so the chosen chunks are reproducible.
    mx.random.seed(options.seed)

    model, tokenizer = load(options.model)
    data = load_dataset(
        tokenizer,
        options.num_samples,
        options.sequence_length,
    )
    perplexity = eval_ppl(model, data)
    print(f"Original PPL: {perplexity:.3f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment