import argparse
import math
from pathlib import Path
from urllib import request

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load


def eval_ppl(model, data, batch_size=32):
    """Compute perplexity of `model` over `data` of shape (num_samples, sequence_length)."""
    all_loss = 0.0
    ntoks = 0
    for s in range(0, len(data), batch_size):
        batch = data[s : s + batch_size]
        # Predict each token from its prefix and score it against the shifted targets.
        logits = model(batch[:, :-1]).astype(mx.float32)
        losses = nn.losses.cross_entropy(logits, batch[:, 1:])
        all_loss += losses.sum().item()
        ntoks += losses.size
    ppl = math.exp(all_loss / ntoks)
    return ppl


def load_dataset(tokenizer, num_samples: int, sequence_length: int) -> mx.array:
    save_dir = Path.home() / ".cache/mlx-lm/calibration_v5.txt"
    if not save_dir.exists():
        # Download the calibration text once and cache it locally.
        save_dir.parent.mkdir(parents=True, exist_ok=True)
        url = "https://gist.githubusercontent.com/tristandruyen/9e207a95c7d75ddf37525d353e00659c/raw/571fda718462de863e5a0171078c175420c7649a/calibration_data_v5_rc.txt"
        request.urlretrieve(url, save_dir)
    with open(save_dir) as fid:
        texts = fid.read()
    tokens = tokenizer.encode(texts, return_tensors="mlx")[0]
    # Select random non-overlapping chunks of `sequence_length` tokens.
    tokens = tokens[: (tokens.size // sequence_length) * sequence_length]
    tokens = tokens.reshape(-1, sequence_length)
    segments = mx.random.permutation(tokens.shape[0])[:num_samples]
    return tokens[segments]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="Qwen/Qwen3-1.7B")
    parser.add_argument("--num-samples", type=int, default=32)
    parser.add_argument("--sequence-length", type=int, default=512)
    parser.add_argument("--seed", type=int, default=123)
    args = parser.parse_args()

    mx.random.seed(args.seed)
    model, tokenizer = load(args.model)
    data = load_dataset(tokenizer, args.num_samples, args.sequence_length)
    ppl = eval_ppl(model, data)
    print(f"Original PPL: {ppl:.3f}")