Skip to content

Instantly share code, notes, and snippets.

@snellingio
Last active November 28, 2025 20:53
Show Gist options
  • Select an option

  • Save snellingio/6f76ace60d82fa2c8f8ebbfb44411c45 to your computer and use it in GitHub Desktop.

Select an option

Save snellingio/6f76ace60d82fa2c8f8ebbfb44411c45 to your computer and use it in GitHub Desktop.
find_model_configs.py
Finding configs for ~1,000,000 params (±5%)
Vocab size: 8,192
====================================================================================================
depth | dim | heads | kv_heads | mlp_m | mlp_hid | params | diff | pct
----------------------------------------------------------------------------------------------------
9 | 48 | 6 | 1 | 4.0 | 192 | 1,000,704 | 704 | +0.07% (GQA)
4 | 56 | 4 | 1 | 2.0 | 112 | 999,040 | 960 | -0.10% (GQA)
18 | 40 | 1 | 1 | 4.0 | 160 | 1,000,960 | 960 | +0.10%
18 | 40 | 1 | 1 | 4.0 | 160 | 1,000,960 | 960 | +0.10%
18 | 40 | 1 | 1 | 4.0 | 160 | 1,000,960 | 960 | +0.10%
18 | 40 | 2 | 2 | 4.0 | 160 | 1,000,960 | 960 | +0.10%
18 | 40 | 4 | 4 | 4.0 | 160 | 1,000,960 | 960 | +0.10%
11 | 48 | 4 | 1 | 3.0 | 144 | 1,001,856 | 1,856 | +0.19% (GQA)
11 | 48 | 8 | 2 | 3.0 | 144 | 1,001,856 | 1,856 | +0.19% (GQA)
11 | 48 | 12 | 3 | 3.0 | 144 | 1,001,856 | 1,856 | +0.19% (GQA)
3 | 56 | 2 | 1 | 3.0 | 168 | 1,002,176 | 2,176 | +0.22% (GQA)
3 | 56 | 2 | 1 | 3.0 | 168 | 1,002,176 | 2,176 | +0.22% (GQA)
3 | 56 | 4 | 2 | 3.0 | 168 | 1,002,176 | 2,176 | +0.22% (GQA)
11 | 48 | 6 | 1 | 3.0 | 144 | 997,632 | 2,368 | -0.24% (GQA)
3 | 56 | 4 | 1 | 3.0 | 168 | 997,472 | 2,528 | -0.25% (GQA)
13 | 48 | 2 | 1 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
13 | 48 | 2 | 1 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
13 | 48 | 4 | 2 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
13 | 48 | 6 | 3 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
13 | 48 | 8 | 4 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
13 | 48 | 12 | 6 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
14 | 48 | 4 | 1 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
14 | 48 | 8 | 2 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
14 | 48 | 12 | 3 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA)
9 | 48 | 4 | 1 | 4.0 | 192 | 1,004,160 | 4,160 | +0.42% (GQA)
9 | 48 | 8 | 2 | 4.0 | 192 | 1,004,160 | 4,160 | +0.42% (GQA)
9 | 48 | 12 | 3 | 4.0 | 192 | 1,004,160 | 4,160 | +0.42% (GQA)
4 | 56 | 2 | 1 | 2.0 | 112 | 1,005,312 | 5,312 | +0.53% (GQA)
4 | 56 | 2 | 1 | 2.0 | 112 | 1,005,312 | 5,312 | +0.53% (GQA)
4 | 56 | 4 | 2 | 2.0 | 112 | 1,005,312 | 5,312 | +0.53% (GQA)
====================================================================================================
Example CLI commands for top 3 configs:
----------------------------------------------------------------------------------------------------
1. 1,000,704 params:
uv run python -m scripts.base_train --depth=9 --model_dim=48 --num_heads=6 --num_kv_heads=1 --mlp_hidden_mult=4.0
2. 999,040 params:
uv run python -m scripts.base_train --depth=4 --model_dim=56 --num_heads=4 --num_kv_heads=1 --mlp_hidden_mult=2.0
3. 1,000,960 params:
uv run python -m scripts.base_train --depth=18 --model_dim=40 --num_heads=1 --mlp_hidden_mult=4.0
"""
Find model configurations that match a target parameter count.
Usage:
uv run python -m scripts.find_model_configs --target=1000000
uv run python -m scripts.find_model_configs --target=10000000 --vocab_size=32000
uv run python -m scripts.find_model_configs --target=1000000 --tolerance=0.05
"""
import argparse
def count_params(vocab_size, n_layer, n_embd, n_head, n_kv_head, mlp_hidden_dim):
    """Return the total number of weight-matrix parameters for a model shape.

    The total is the token embedding plus the output head plus ``n_layer``
    identical blocks. Each block contributes four attention projections
    (query, key, value, output) and two MLP projections. The key and value
    projections are sized by ``n_kv_head`` so grouped-query attention is
    accounted for. No bias or normalisation parameters are included.
    """
    head_dim = n_embd // n_head
    kv_width = n_kv_head * head_dim

    # Token embedding table and output projection: vocab_size * n_embd each.
    embedding_params = 2 * vocab_size * n_embd

    # Query and output projections are n_embd x n_embd; key and value
    # projections map n_embd -> kv_width.
    attention_params = 2 * n_embd * n_embd + 2 * n_embd * kv_width

    # MLP up-projection and down-projection through the hidden layer.
    mlp_params = 2 * n_embd * mlp_hidden_dim

    return embedding_params + n_layer * (attention_params + mlp_params)
def find_configs(target_params, vocab_size, tolerance=0.05, max_results=30):
    """Find model configurations whose parameter count is near a target.

    A grid of depth, model width, head count, KV-head count and MLP multiplier
    is enumerated; every combination whose parameter count (as computed by
    ``count_params``) lies within ``target_params * (1 ± tolerance)`` is kept.

    Args:
        target_params: Desired total parameter count.
        vocab_size: Vocabulary size used for the embedding and output head.
        tolerance: Allowed relative deviation from the target (0.05 = 5%).
        max_results: Maximum number of configurations to return.

    Returns:
        A list of at most ``max_results`` dicts, sorted by absolute distance
        from the target (closest first). Each dict has the keys ``depth``,
        ``model_dim``, ``num_heads``, ``num_kv_heads``, ``mlp_mult``,
        ``mlp_hidden``, ``params``, ``diff`` (absolute difference) and
        ``pct_diff`` (signed difference in percent). Every configuration
        appears at most once.
    """
    min_params = target_params * (1 - tolerance)
    max_params = target_params * (1 + tolerance)

    # Determine search ranges based on target size
    if target_params < 1_000_000:
        depth_range = range(2, 16)
        dim_range = range(16, 256, 4)
    elif target_params < 10_000_000:
        depth_range = range(2, 24)
        dim_range = range(32, 512, 8)
    elif target_params < 100_000_000:
        depth_range = range(4, 32)
        dim_range = range(64, 1024, 16)
    else:
        depth_range = range(6, 48)
        dim_range = range(128, 2048, 32)

    results = []
    for depth in depth_range:
        for dim in dim_range:
            for n_heads in (1, 2, 4, 6, 8, 12, 16):
                if dim % n_heads != 0:
                    continue
                head_dim = dim // n_heads
                # Only even head dimensions are considered (presumably a
                # requirement of the model's attention implementation --
                # confirm). An even head_dim also implies an even dim, so no
                # separate parity check on dim is needed.
                if head_dim % 2 != 0:
                    continue

                # Full multi-head attention plus two GQA variants. For small
                # head counts the three candidates coincide (n_heads=1 gives
                # 1, 1, 1), so de-duplicate them in order; otherwise the same
                # configuration would be emitted several times and crowd out
                # distinct results.
                kv_candidates = dict.fromkeys(
                    (n_heads, max(1, n_heads // 2), max(1, n_heads // 4))
                )
                for n_kv_heads in kv_candidates:
                    if n_kv_heads > n_heads or n_heads % n_kv_heads != 0:
                        continue
                    for mlp_mult in (2.0, 3.0, 4.0):
                        mlp_hidden = int(dim * mlp_mult)
                        params = count_params(vocab_size, depth, dim, n_heads, n_kv_heads, mlp_hidden)
                        if not (min_params <= params <= max_params):
                            continue
                        results.append({
                            "depth": depth,
                            "model_dim": dim,
                            "num_heads": n_heads,
                            "num_kv_heads": n_kv_heads,
                            "mlp_mult": mlp_mult,
                            "mlp_hidden": mlp_hidden,
                            "params": params,
                            "diff": abs(params - target_params),
                            "pct_diff": 100 * (params - target_params) / target_params,
                        })

    # Sort by absolute difference from target (stable, so ties keep grid order)
    results.sort(key=lambda entry: entry["diff"])
    return results[:max_results]
def main():
    """Command-line entry point: search for configs and print a report.

    Parses ``--target``, ``--vocab_size``, ``--tolerance`` and
    ``--max_results``, runs ``find_configs`` with them, prints the matching
    configurations as a table, and finally prints ready-to-run training
    commands for the three closest matches.
    """
    parser = argparse.ArgumentParser(description="Find model configurations for a target parameter count")
    cli_options = (
        ("--target", int, 1_000_000, "Target number of parameters"),
        ("--vocab_size", int, 8192, "Vocabulary size"),
        ("--tolerance", float, 0.05, "Tolerance as fraction (0.05 = 5%%)"),
        ("--max_results", int, 30, "Maximum number of results to show"),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    heavy_rule = "=" * 100
    light_rule = "-" * 100

    print(f"Finding configs for ~{args.target:,} params (±{args.tolerance*100:.0f}%)")
    print(f"Vocab size: {args.vocab_size:,}")
    print(heavy_rule)

    results = find_configs(args.target, args.vocab_size, args.tolerance, args.max_results)
    if not results:
        print("No configurations found. Try increasing --tolerance or adjusting --target.")
        return

    # Results table: one row per configuration, GQA rows are tagged.
    print(f"{'depth':>5} | {'dim':>5} | {'heads':>5} | {'kv_heads':>8} | {'mlp_m':>5} | {'mlp_hid':>7} | {'params':>12} | {'diff':>10} | {'pct':>6}")
    print(light_rule)
    for r in results:
        if r["num_heads"] == r["num_kv_heads"]:
            gqa_str = ""
        else:
            gqa_str = " (GQA)"
        print(f"{r['depth']:>5} | {r['model_dim']:>5} | {r['num_heads']:>5} | {r['num_kv_heads']:>8} | {r['mlp_mult']:>5.1f} | {r['mlp_hidden']:>7} | {r['params']:>12,} | {r['diff']:>10,} | {r['pct_diff']:>+5.2f}%{gqa_str}")
    print(heavy_rule)

    # Ready-to-paste training commands for the three closest matches.
    print("\nExample CLI commands for top 3 configs:")
    print(light_rule)
    for rank, r in enumerate(results[:3], start=1):
        cmd = f"uv run python -m scripts.base_train --depth={r['depth']} --model_dim={r['model_dim']} --num_heads={r['num_heads']}"
        # --num_kv_heads is only emitted when it differs from --num_heads.
        if r["num_kv_heads"] != r["num_heads"]:
            cmd += f" --num_kv_heads={r['num_kv_heads']}"
        cmd += f" --mlp_hidden_mult={r['mlp_mult']}"
        print(f"{rank}. {r['params']:,} params:")
        print(f" {cmd}")


if __name__ == "__main__":
    main()
"""
Sweep script to explore small model configurations under 1M parameters.
Reports tok/sec, validation bpb, and training loss (mean of the last 10 steps) after 50 training steps.
Usage:
uv run python -m scripts.small_model_sweep
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext
from dataclasses import dataclass
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
from nanochat.common import compute_init, compute_cleanup, print0, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.loss_eval import evaluate_bpb
@dataclass
class SweepConfig:
    """Shape of one candidate model evaluated by the sweep."""

    # Number of transformer blocks.
    depth: int
    # Embedding / residual width of the model.
    model_dim: int
    # Number of attention heads.
    num_heads: int
    # MLP hidden width expressed as a multiple of model_dim.
    mlp_hidden_mult: float
def count_params(vocab_size, n_layer, n_embd, n_head, mlp_hidden_dim, n_kv_head=None):
    """Count the total weight-matrix parameters for a model shape.

    Args:
        vocab_size: Vocabulary size (rows of the embedding and output head).
        n_layer: Number of transformer blocks.
        n_embd: Model width.
        n_head: Number of attention (query) heads.
        mlp_hidden_dim: Width of the MLP hidden layer.
        n_kv_head: Number of key/value heads. ``None`` (the default) means
            the same as ``n_head``, i.e. plain multi-head attention.

    Returns:
        Embedding + output head + ``n_layer`` blocks, where each block holds
        the query/key/value/output attention projections and the two MLP
        projections. No bias or normalisation parameters are counted.
    """
    if n_kv_head is None:
        n_kv_head = n_head
    head_dim = n_embd // n_head
    kv_dim = n_kv_head * head_dim

    wte = vocab_size * n_embd
    lm_head = n_embd * vocab_size

    # Per block: c_q, c_k, c_v, c_proj (attention) + c_fc, c_proj (MLP)
    c_q = n_embd * n_embd
    c_k = n_embd * kv_dim
    c_v = n_embd * kv_dim
    c_proj = n_embd * n_embd
    mlp_fc = n_embd * mlp_hidden_dim
    mlp_proj = mlp_hidden_dim * n_embd
    block_params = c_q + c_k + c_v + c_proj + mlp_fc + mlp_proj

    return wte + lm_head + n_layer * block_params
def generate_sweep_configs(vocab_size, max_params=1_000_000, max_configs=12):
    """Generate valid model configurations under the parameter budget.

    Args:
        vocab_size: Vocabulary size passed to ``count_params``.
        max_params: Inclusive upper bound on the parameter count.
        max_configs: Maximum number of configurations to return.

    Returns:
        A list of ``(params, SweepConfig)`` tuples, largest parameter count
        first, containing at most one entry per ``(depth, model_dim)`` pair
        and at most ``max_configs`` entries overall.
    """
    candidates = []
    for depth in (3, 4, 5, 6, 7, 8, 9, 10):
        for dim in range(32, 128, 4):
            for n_heads in (1, 2, 4):
                if dim % n_heads != 0:
                    continue
                head_dim = dim // n_heads
                # Only even head dimensions are considered; an even head_dim
                # already implies an even dim, so dim needs no parity check.
                if head_dim % 2 != 0:
                    continue
                for mlp_mult in (2.0, 3.0, 4.0):
                    mlp_hidden = int(dim * mlp_mult)
                    params = count_params(vocab_size, depth, dim, n_heads, mlp_hidden)
                    if params <= max_params:
                        candidates.append((params, SweepConfig(depth, dim, n_heads, mlp_mult)))

    # Sort by param count descending (prefer larger models within budget).
    # The sort is stable, so candidates with equal counts keep grid order.
    candidates.sort(key=lambda item: -item[0])

    # Keep only the first (largest) candidate for each (depth, dim) pair so
    # the sweep covers diverse shapes rather than variants of one shape.
    seen = set()
    unique = []
    for params, cfg in candidates:
        key = (cfg.depth, cfg.model_dim)
        if key in seen:
            continue
        seen.add(key)
        unique.append((params, cfg))

    return unique[:max_configs]
def train_config(cfg: SweepConfig, device, device_type, vocab_size, token_bytes,
                 num_iterations=50, max_seq_len=512, device_batch_size=4, total_batch_size=2048,
                 eval_tokens=8192, data_dir=None):
    """Train a single configuration briefly and return its metrics.

    Args:
        cfg: Model shape to train.
        device: Device the model and data are placed on.
        device_type: ``"cuda"``, ``"mps"`` or anything else (treated as no
            autocast / no compile).
        vocab_size: Vocabulary size for the model.
        token_bytes: Forwarded unchanged to ``evaluate_bpb``.
        num_iterations: Number of optimizer steps to run.
        max_seq_len: Sequence length of each training row.
        device_batch_size: Rows per forward/backward pass.
        total_batch_size: Target tokens per optimizer step; determines the
            number of gradient-accumulation micro-steps.
        eval_tokens: Approximate number of validation tokens to evaluate.
        data_dir: Forwarded unchanged to the data loaders.

    Returns:
        A dict with ``num_params`` (counted before compilation),
        ``tok_per_sec`` (tokens processed by this process per second,
        excluding the first five warm-up steps; 0 if nothing was timed),
        ``val_bpb`` and ``train_loss`` (mean loss of the last up-to-10 steps,
        NaN if no step was run).
    """
    mlp_hidden_dim = int(cfg.model_dim * cfg.mlp_hidden_mult)
    model_config = GPTConfig(
        sequence_len=max_seq_len,
        vocab_size=vocab_size,
        n_layer=cfg.depth,
        n_head=cfg.num_heads,
        n_kv_head=cfg.num_heads,
        n_embd=cfg.model_dim,
        mlp_hidden_mult=cfg.mlp_hidden_mult,
        mlp_hidden_dim=mlp_hidden_dim,
    )
    # Build the module on the meta device, then allocate and initialise the
    # real storage on the target device.
    with torch.device("meta"):
        model = GPT(model_config)
    model.to_empty(device=device)
    model.init_weights()
    num_params = sum(p.numel() for p in model.parameters())

    # Compile if possible
    if device_type == "cuda" and not os.environ.get("NANOCHAT_NO_COMPILE", ""):
        model = torch.compile(model, dynamic=False)

    # Setup optimizer
    optimizers = model.setup_optimizers(unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0)

    # Setup dataloader
    train_loader = tokenizing_distributed_data_loader_with_state(
        device_batch_size, max_seq_len, split="train", device=device,
        resume_state_dict=None, data_dir=data_dir,
    )

    # Gradient accumulation
    tokens_per_fwdbwd = device_batch_size * max_seq_len
    grad_accum_steps = max(1, total_batch_size // tokens_per_fwdbwd)
    # Tokens this process really consumes per optimizer step. This differs
    # from total_batch_size whenever total_batch_size is not an exact multiple
    # of tokens_per_fwdbwd, so throughput is measured with this value.
    tokens_per_step = grad_accum_steps * tokens_per_fwdbwd

    # Autocast context: bfloat16 on cuda and mps, disabled elsewhere.
    if device_type in ("cuda", "mps"):
        autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    else:
        autocast_ctx = nullcontext()
    # Make wall-clock timing meaningful on CUDA, where kernels run async.
    synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None

    # Training loop
    x, y, _ = next(train_loader)  # prefetch the first batch
    total_tokens = 0
    total_time = 0.0
    train_losses = []
    model.train()
    for step in range(num_iterations):
        synchronize()
        t0 = time.time()
        for micro_step in range(grad_accum_steps):
            with autocast_ctx:
                loss = model(x, y)
            train_loss = loss.detach().item()
            # Scale so the accumulated gradient is the mean over micro-steps.
            loss = loss / grad_accum_steps
            loss.backward()
            x, y, _ = next(train_loader)
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Step optimizers
        for opt in optimizers:
            opt.step()
        model.zero_grad(set_to_none=True)
        synchronize()
        t1 = time.time()
        dt = t1 - t0
        if step >= 5:  # Skip warmup steps
            total_time += dt
            total_tokens += tokens_per_step
        train_losses.append(train_loss)

    # Evaluate
    model.eval()
    val_loader = tokenizing_distributed_data_loader(
        device_batch_size, max_seq_len, split="val", device=device, data_dir=data_dir,
    )
    eval_steps = max(1, eval_tokens // (device_batch_size * max_seq_len))
    with autocast_ctx:
        val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)

    # Compute metrics
    tok_per_sec = total_tokens / total_time if total_time > 0 else 0
    recent_losses = train_losses[-10:]  # Last 10 steps
    # With num_iterations == 0 there is nothing to average; report NaN
    # instead of dividing by zero.
    avg_train_loss = sum(recent_losses) / len(recent_losses) if recent_losses else float("nan")
    return {
        "num_params": num_params,
        "tok_per_sec": tok_per_sec,
        "val_bpb": val_bpb,
        "train_loss": avg_train_loss,
    }
def main():
    """Run the sweep: train every generated config briefly and print a summary.

    Sets up the compute environment and tokenizer, generates candidate
    configurations under a 1,000,000-parameter budget, trains each one via
    ``train_config`` (a failure in one configuration is reported and the
    sweep continues), and finally prints all successful results sorted by
    validation bpb.
    """
    # Initialize compute; only the device is needed from the returned tuple.
    device_type = autodetect_device_type()
    _, _, _, _, device = compute_init(device_type)

    # Tokenizer
    tokenizer = get_tokenizer()
    token_bytes = get_token_bytes(device=device)
    vocab_size = tokenizer.get_vocab_size()

    rule = "=" * 80
    print0(f"Device: {device}, Vocab size: {vocab_size}")
    print0(rule)

    # Generate sweep configs
    configs = generate_sweep_configs(vocab_size, max_params=1_000_000)
    print0(f"Running sweep over {len(configs)} configurations...")
    print0(rule)

    results = []
    for i, (est_params, cfg) in enumerate(configs):
        print0(f"\n[{i+1}/{len(configs)}] depth={cfg.depth}, dim={cfg.model_dim}, heads={cfg.num_heads}, mlp_mult={cfg.mlp_hidden_mult}")
        print0(f" Estimated params: {est_params:,}")
        try:
            metrics = train_config(
                cfg, device, device_type, vocab_size, token_bytes,
                num_iterations=50,
                max_seq_len=512,
                device_batch_size=4,
                total_batch_size=2048,
                eval_tokens=8192,
            )
            results.append((cfg, metrics))
            print0(f" Actual params: {metrics['num_params']:,}")
            print0(f" tok/sec: {metrics['tok_per_sec']:,.0f}")
            print0(f" val_bpb: {metrics['val_bpb']:.4f}")
            print0(f" train_loss: {metrics['train_loss']:.4f}")
        except Exception as e:
            # Best effort: one failing configuration must not abort the sweep.
            print0(f" FAILED: {e}")
        # Clear memory between configurations
        if device_type == "cuda":
            torch.cuda.empty_cache()

    # Print summary table
    print0("\n" + rule)
    print0("SWEEP RESULTS SUMMARY")
    print0(rule)
    print0(f"{'depth':>5} | {'dim':>4} | {'heads':>5} | {'mlp_m':>5} | {'params':>10} | {'tok/s':>10} | {'val_bpb':>8} | {'train_loss':>10}")
    print0("-" * 80)

    # Sort by val_bpb (lower is better)
    results.sort(key=lambda pair: pair[1]['val_bpb'])
    for cfg, metrics in results:
        print0(f"{cfg.depth:>5} | {cfg.model_dim:>4} | {cfg.num_heads:>5} | {cfg.mlp_hidden_mult:>5.1f} | {metrics['num_params']:>10,} | {metrics['tok_per_sec']:>10,.0f} | {metrics['val_bpb']:>8.4f} | {metrics['train_loss']:>10.4f}")
    print0(rule)
    print0("(Sorted by val_bpb, lower is better)")

    compute_cleanup()


if __name__ == "__main__":
    main()
Device: mps, Vocab size: 8192
================================================================================
Running sweep over 12 configurations...
================================================================================
[1/12] depth=9, dim=48, heads=1, mlp_mult=3.0
Estimated params: 993,792
Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000
Actual params: 993,792
tok/sec: 45,309
val_bpb: 2.2323
train_loss: 6.8131
[2/12] depth=3, dim=56, heads=1, mlp_mult=2.0
Estimated params: 992,768
Scaling the LR for the AdamW parameters ∝1/√(56/768) = 3.703280
Actual params: 992,768
tok/sec: 88,444
val_bpb: 2.2162
train_loss: 6.7881
[3/12] depth=5, dim=52, heads=1, mlp_mult=3.0
Estimated params: 987,168
Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076
Actual params: 987,168
tok/sec: 66,254
val_bpb: 2.2279
train_loss: 6.7844
[4/12] depth=4, dim=52, heads=1, mlp_mult=4.0
Estimated params: 981,760
Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076
Actual params: 981,760
tok/sec: 73,589
val_bpb: 2.2221
train_loss: 6.7687
[5/12] depth=6, dim=52, heads=1, mlp_mult=2.0
Estimated params: 981,760
Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076
Actual params: 981,760
tok/sec: 57,035
val_bpb: 2.2245
train_loss: 6.8004
[6/12] depth=7, dim=48, heads=1, mlp_mult=4.0
Estimated params: 979,968
Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000
Actual params: 979,968
tok/sec: 51,458
val_bpb: 2.2289
train_loss: 6.8111
[7/12] depth=8, dim=48, heads=1, mlp_mult=3.0
Estimated params: 970,752
Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000
Actual params: 970,752
tok/sec: 48,367
val_bpb: 2.2273
train_loss: 6.8150
[8/12] depth=10, dim=48, heads=1, mlp_mult=2.0
Estimated params: 970,752
Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000
Actual params: 970,752
tok/sec: 42,270
val_bpb: 2.2462
train_loss: 6.8419
[9/12] depth=10, dim=44, heads=1, mlp_mult=4.0
Estimated params: 953,216
Scaling the LR for the AdamW parameters ∝1/√(44/768) = 4.177864
Actual params: 953,216
tok/sec: 41,609
val_bpb: 2.2446
train_loss: 6.8540
[10/12] depth=6, dim=48, heads=1, mlp_mult=4.0
Estimated params: 952,320
Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000
Actual params: 952,320
tok/sec: 57,539
val_bpb: 2.2293
train_loss: 6.8035
[11/12] depth=3, dim=52, heads=1, mlp_mult=4.0
Estimated params: 949,312
Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076
Actual params: 949,312
tok/sec: 83,993
val_bpb: 2.2236
train_loss: 6.8127
[12/12] depth=9, dim=44, heads=1, mlp_mult=4.0
Estimated params: 929,984
Scaling the LR for the AdamW parameters ∝1/√(44/768) = 4.177864
Actual params: 929,984
tok/sec: 43,448
val_bpb: 2.2401
train_loss: 6.8298
================================================================================
SWEEP RESULTS SUMMARY
================================================================================
depth | dim | heads | mlp_m | params | tok/s | val_bpb | train_loss
--------------------------------------------------------------------------------
3 | 56 | 1 | 2.0 | 992,768 | 88,444 | 2.2162 | 6.7881
4 | 52 | 1 | 4.0 | 981,760 | 73,589 | 2.2221 | 6.7687
3 | 52 | 1 | 4.0 | 949,312 | 83,993 | 2.2236 | 6.8127
6 | 52 | 1 | 2.0 | 981,760 | 57,035 | 2.2245 | 6.8004
8 | 48 | 1 | 3.0 | 970,752 | 48,367 | 2.2273 | 6.8150
5 | 52 | 1 | 3.0 | 987,168 | 66,254 | 2.2279 | 6.7844
7 | 48 | 1 | 4.0 | 979,968 | 51,458 | 2.2289 | 6.8111
6 | 48 | 1 | 4.0 | 952,320 | 57,539 | 2.2293 | 6.8035
9 | 48 | 1 | 3.0 | 993,792 | 45,309 | 2.2323 | 6.8131
9 | 44 | 1 | 4.0 | 929,984 | 43,448 | 2.2401 | 6.8298
10 | 44 | 1 | 4.0 | 953,216 | 41,609 | 2.2446 | 6.8540
10 | 48 | 1 | 2.0 | 970,752 | 42,270 | 2.2462 | 6.8419
================================================================================
(Sorted by val_bpb, lower is better)
2025-11-28 14:47:16,472 - nanochat.common - INFO - Distributed world size: 1
W1128 14:47:17.937000 26162 .venv/lib/python3.10/site-packages/torch/_inductor/utils.py:1558] [0/0]
Not enough SMs to use max_autotune_gemm mode
⏺ Results are in. The best val_bpb after 50 steps: 02:47 PM claude-opus-4-5-20251101
| Rank | Depth | Dim | MLP | Params | tok/s | val_bpb |
|------|-------|-----|-----|---------|--------|---------|
| 1 | 3 | 56 | 2.0 | 992,768 | 88,444 | 2.2162 |
| 2 | 4 | 52 | 4.0 | 981,760 | 73,589 | 2.2221 |
| 3 | 3 | 52 | 4.0 | 949,312 | 83,993 | 2.2236 |
Interesting findings:
- Shallow & wide wins - depth=3 with dim=56 got the best val_bpb
and fastest tok/sec (88k)
- Deeper models (9-10 layers) performed worse despite similar
param counts
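- All configs used 1 head. This is an artifact of the sweep, not a finding:
with n_kv_head == n_head the parameter count does not depend on the number
of heads, so heads=1, 2 and 4 tie for every (depth, dim), and the
de-duplication keeps only the first of each tie, which is always heads=1.
The effect of head count was therefore not actually tested.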
The winner: --depth=3 --model_dim=56 --num_heads=1
--mlp_hidden_mult=2.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment