Last active
November 28, 2025 20:53
-
-
Save snellingio/6f76ace60d82fa2c8f8ebbfb44411c45 to your computer and use it in GitHub Desktop.
find_model_configs.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Finding configs for ~1,000,000 params (±5%) | |
| Vocab size: 8,192 | |
| ==================================================================================================== | |
| depth | dim | heads | kv_heads | mlp_m | mlp_hid | params | diff | pct | |
| ---------------------------------------------------------------------------------------------------- | |
| 9 | 48 | 6 | 1 | 4.0 | 192 | 1,000,704 | 704 | +0.07% (GQA) | |
| 4 | 56 | 4 | 1 | 2.0 | 112 | 999,040 | 960 | -0.10% (GQA) | |
| 18 | 40 | 1 | 1 | 4.0 | 160 | 1,000,960 | 960 | +0.10% | |
| 18 | 40 | 1 | 1 | 4.0 | 160 | 1,000,960 | 960 | +0.10% | |
| 18 | 40 | 1 | 1 | 4.0 | 160 | 1,000,960 | 960 | +0.10% | |
| 18 | 40 | 2 | 2 | 4.0 | 160 | 1,000,960 | 960 | +0.10% | |
| 18 | 40 | 4 | 4 | 4.0 | 160 | 1,000,960 | 960 | +0.10% | |
| 11 | 48 | 4 | 1 | 3.0 | 144 | 1,001,856 | 1,856 | +0.19% (GQA) | |
| 11 | 48 | 8 | 2 | 3.0 | 144 | 1,001,856 | 1,856 | +0.19% (GQA) | |
| 11 | 48 | 12 | 3 | 3.0 | 144 | 1,001,856 | 1,856 | +0.19% (GQA) | |
| 3 | 56 | 2 | 1 | 3.0 | 168 | 1,002,176 | 2,176 | +0.22% (GQA) | |
| 3 | 56 | 2 | 1 | 3.0 | 168 | 1,002,176 | 2,176 | +0.22% (GQA) | |
| 3 | 56 | 4 | 2 | 3.0 | 168 | 1,002,176 | 2,176 | +0.22% (GQA) | |
| 11 | 48 | 6 | 1 | 3.0 | 144 | 997,632 | 2,368 | -0.24% (GQA) | |
| 3 | 56 | 4 | 1 | 3.0 | 168 | 997,472 | 2,528 | -0.25% (GQA) | |
| 13 | 48 | 2 | 1 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 13 | 48 | 2 | 1 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 13 | 48 | 4 | 2 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 13 | 48 | 6 | 3 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 13 | 48 | 8 | 4 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 13 | 48 | 12 | 6 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 14 | 48 | 4 | 1 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 14 | 48 | 8 | 2 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 14 | 48 | 12 | 3 | 2.0 | 96 | 996,096 | 3,904 | -0.39% (GQA) | |
| 9 | 48 | 4 | 1 | 4.0 | 192 | 1,004,160 | 4,160 | +0.42% (GQA) | |
| 9 | 48 | 8 | 2 | 4.0 | 192 | 1,004,160 | 4,160 | +0.42% (GQA) | |
| 9 | 48 | 12 | 3 | 4.0 | 192 | 1,004,160 | 4,160 | +0.42% (GQA) | |
| 4 | 56 | 2 | 1 | 2.0 | 112 | 1,005,312 | 5,312 | +0.53% (GQA) | |
| 4 | 56 | 2 | 1 | 2.0 | 112 | 1,005,312 | 5,312 | +0.53% (GQA) | |
| 4 | 56 | 4 | 2 | 2.0 | 112 | 1,005,312 | 5,312 | +0.53% (GQA) | |
| ==================================================================================================== | |
| Example CLI commands for top 3 configs: | |
| ---------------------------------------------------------------------------------------------------- | |
| 1. 1,000,704 params: | |
| uv run python -m scripts.base_train --depth=9 --model_dim=48 --num_heads=6 --num_kv_heads=1 --mlp_hidden_mult=4.0 | |
| 2. 999,040 params: | |
| uv run python -m scripts.base_train --depth=4 --model_dim=56 --num_heads=4 --num_kv_heads=1 --mlp_hidden_mult=2.0 | |
| 3. 1,000,960 params: | |
| uv run python -m scripts.base_train --depth=18 --model_dim=40 --num_heads=1 --mlp_hidden_mult=4.0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Find model configurations that match a target parameter count. | |
| Usage: | |
| uv run python -m scripts.find_model_configs --target=1000000 | |
| uv run python -m scripts.find_model_configs --target=10000000 --vocab_size=32000 | |
| uv run python -m scripts.find_model_configs --target=1000000 --tolerance=0.05 | |
| """ | |
| import argparse | |
def count_params(vocab_size, n_layer, n_embd, n_head, n_kv_head, mlp_hidden_dim):
    """Return the total parameter count of the model.

    Counts the token embedding and the LM head as two separate
    ``vocab_size x n_embd`` matrices, plus, for each of the ``n_layer`` blocks:
    the query and output projections (``n_embd x n_embd`` each), the key and
    value projections (``n_embd x (n_kv_head * head_dim)`` each, where
    ``head_dim = n_embd // n_head``), and the two MLP matrices
    (``n_embd x mlp_hidden_dim`` each). No bias or norm weights are counted.
    """
    head_dim = n_embd // n_head
    kv_width = n_kv_head * head_dim  # shrinks under grouped-query attention

    embedding_params = 2 * vocab_size * n_embd  # wte + lm_head
    attention_params = 2 * n_embd * n_embd + 2 * n_embd * kv_width  # c_q, c_proj + c_k, c_v
    mlp_params = 2 * n_embd * mlp_hidden_dim  # fc + proj
    per_block = attention_params + mlp_params

    return embedding_params + n_layer * per_block
def find_configs(target_params, vocab_size, tolerance=0.05, max_results=30):
    """Find model configurations near the target parameter count.

    Sweeps depth, model width, attention-head count, KV-head count (full
    multi-head attention plus GQA variants with 1/2 and 1/4 as many KV heads)
    and MLP expansion factor. Every configuration whose ``count_params``
    estimate lies within ``target_params * (1 +/- tolerance)`` is kept.

    Args:
        target_params: Desired total parameter count.
        vocab_size: Vocabulary size (embedding and LM head are both counted).
        tolerance: Allowed relative deviation from the target (0.05 = +/-5%).
        max_results: Maximum number of configurations to return.

    Returns:
        Up to ``max_results`` dicts sorted by absolute distance from the target,
        each with keys ``depth``, ``model_dim``, ``num_heads``, ``num_kv_heads``,
        ``mlp_mult``, ``mlp_hidden``, ``params``, ``diff`` and ``pct_diff``.
        Each distinct configuration appears at most once.
    """
    min_params = target_params * (1 - tolerance)
    max_params = target_params * (1 + tolerance)

    # Determine search ranges based on target size
    if target_params < 1_000_000:
        depth_range = range(2, 16)
        dim_range = range(16, 256, 4)
    elif target_params < 10_000_000:
        depth_range = range(2, 24)
        dim_range = range(32, 512, 8)
    elif target_params < 100_000_000:
        depth_range = range(4, 32)
        dim_range = range(64, 1024, 16)
    else:
        depth_range = range(6, 48)
        dim_range = range(128, 2048, 32)

    results = []
    for depth in depth_range:
        for dim in dim_range:
            if dim % 2 != 0:
                continue
            for n_heads in (1, 2, 4, 6, 8, 12, 16):
                if dim % n_heads != 0:
                    continue
                head_dim = dim // n_heads
                if head_dim % 2 != 0:
                    continue
                # MHA plus GQA variants. For small head counts these candidates
                # coincide (n_heads=1 -> 1, 1, 1; n_heads=2 -> 2, 1, 1), which
                # previously emitted duplicate rows. dict.fromkeys dedupes while
                # preserving order.
                kv_candidates = dict.fromkeys(
                    (n_heads, max(1, n_heads // 2), max(1, n_heads // 4))
                )
                for n_kv_heads in kv_candidates:
                    if n_heads % n_kv_heads != 0:
                        continue
                    for mlp_mult in (2.0, 3.0, 4.0):
                        mlp_hidden = int(dim * mlp_mult)
                        params = count_params(vocab_size, depth, dim, n_heads, n_kv_heads, mlp_hidden)
                        if not (min_params <= params <= max_params):
                            continue
                        results.append({
                            "depth": depth,
                            "model_dim": dim,
                            "num_heads": n_heads,
                            "num_kv_heads": n_kv_heads,
                            "mlp_mult": mlp_mult,
                            "mlp_hidden": mlp_hidden,
                            "params": params,
                            "diff": abs(params - target_params),
                            "pct_diff": 100 * (params - target_params) / target_params,
                        })

    # Closest to target first; the sort is stable, so ties keep enumeration order.
    results.sort(key=lambda r: r["diff"])
    return results[:max_results]
def main():
    """CLI entry point: search for configs near ``--target`` and print them.

    Prints a table of up to ``--max_results`` matches (closest first), followed
    by example ``scripts.base_train`` commands for the best three. If nothing
    falls within the tolerance, prints a hint and returns.
    """
    parser = argparse.ArgumentParser(description="Find model configurations for a target parameter count")
    parser.add_argument("--target", type=int, default=1_000_000, help="Target number of parameters")
    parser.add_argument("--vocab_size", type=int, default=8192, help="Vocabulary size")
    parser.add_argument("--tolerance", type=float, default=0.05, help="Tolerance as fraction (0.05 = 5%%)")
    parser.add_argument("--max_results", type=int, default=30, help="Maximum number of results to show")
    args = parser.parse_args()

    thick_rule = "=" * 100
    thin_rule = "-" * 100

    print(f"Finding configs for ~{args.target:,} params (±{args.tolerance*100:.0f}%)")
    print(f"Vocab size: {args.vocab_size:,}")
    print(thick_rule)

    matches = find_configs(args.target, args.vocab_size, args.tolerance, args.max_results)
    if not matches:
        print("No configurations found. Try increasing --tolerance or adjusting --target.")
        return

    print(f"{'depth':>5} | {'dim':>5} | {'heads':>5} | {'kv_heads':>8} | {'mlp_m':>5} | {'mlp_hid':>7} | {'params':>12} | {'diff':>10} | {'pct':>6}")
    print(thin_rule)
    for row in matches:
        gqa_suffix = " (GQA)" if row["num_heads"] != row["num_kv_heads"] else ""
        print(
            f"{row['depth']:>5} | {row['model_dim']:>5} | {row['num_heads']:>5} | "
            f"{row['num_kv_heads']:>8} | {row['mlp_mult']:>5.1f} | {row['mlp_hidden']:>7} | "
            f"{row['params']:>12,} | {row['diff']:>10,} | {row['pct_diff']:>+5.2f}%{gqa_suffix}"
        )
    print(thick_rule)

    # Ready-to-run training commands for the three closest matches.
    print("\nExample CLI commands for top 3 configs:")
    print(thin_rule)
    for rank, row in enumerate(matches[:3], start=1):
        parts = [
            "uv run python -m scripts.base_train",
            f"--depth={row['depth']}",
            f"--model_dim={row['model_dim']}",
            f"--num_heads={row['num_heads']}",
        ]
        # --num_kv_heads is only spelled out for GQA configs.
        if row["num_kv_heads"] != row["num_heads"]:
            parts.append(f"--num_kv_heads={row['num_kv_heads']}")
        parts.append(f"--mlp_hidden_mult={row['mlp_mult']}")
        cmd = " ".join(parts)
        print(f"{rank}. {row['params']:,} params:")
        print(f" {cmd}")


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Sweep script to explore small model configurations under 1M parameters. | |
| Reports tok/sec, bpb, and val loss after ~50 steps. | |
| Usage: | |
| uv run python -m scripts.small_model_sweep | |
| """ | |
| import os | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" | |
| import time | |
| from contextlib import nullcontext | |
| from dataclasses import dataclass | |
| import torch | |
| from nanochat.gpt import GPT, GPTConfig | |
| from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state | |
| from nanochat.common import compute_init, compute_cleanup, print0, autodetect_device_type | |
| from nanochat.tokenizer import get_tokenizer, get_token_bytes | |
| from nanochat.loss_eval import evaluate_bpb | |
@dataclass
class SweepConfig:
    """One architecture point in the small-model sweep grid.

    Constructed positionally as ``SweepConfig(depth, model_dim, num_heads,
    mlp_hidden_mult)``, so the field order below must not change.
    """

    # Number of transformer blocks (becomes GPTConfig.n_layer).
    depth: int
    # Residual-stream width (becomes GPTConfig.n_embd).
    model_dim: int
    # Attention heads; the sweep uses the same value for the KV head count.
    num_heads: int
    # MLP hidden width as a multiple of model_dim.
    mlp_hidden_mult: float
def count_params(vocab_size, n_layer, n_embd, n_head, mlp_hidden_dim, n_kv_head=None):
    """Count total parameters for the model.

    Counts the token embedding and LM head (two ``vocab_size x n_embd``
    matrices) plus, per block, the attention projections c_q, c_k, c_v and
    c_proj, and the two MLP matrices. No bias or norm weights are counted.

    Args:
        vocab_size: Vocabulary size.
        n_layer: Number of transformer blocks.
        n_embd: Model width.
        n_head: Number of query heads; ``head_dim = n_embd // n_head``.
        mlp_hidden_dim: MLP hidden width.
        n_kv_head: Number of key/value heads (grouped-query attention). Defaults
            to ``n_head`` (plain multi-head attention), which preserves the
            previous behavior of this helper.

    Returns:
        Total parameter count as an int (for int inputs).
    """
    if n_kv_head is None:
        n_kv_head = n_head
    head_dim = n_embd // n_head
    wte = vocab_size * n_embd
    lm_head = n_embd * vocab_size
    # Per block: c_q, c_k, c_v, c_proj (attention) + c_fc, c_proj (MLP)
    c_q = n_embd * n_embd
    c_k = n_embd * (n_kv_head * head_dim)
    c_v = n_embd * (n_kv_head * head_dim)
    c_proj = n_embd * n_embd
    mlp_fc = n_embd * mlp_hidden_dim
    mlp_proj = mlp_hidden_dim * n_embd
    block_params = c_q + c_k + c_v + c_proj + mlp_fc + mlp_proj
    return wte + lm_head + n_layer * block_params
def generate_sweep_configs(vocab_size, max_params=1_000_000, max_configs=12):
    """Generate valid model configurations under the parameter budget.

    Enumerates depth 3-10, width 32-124 (step 4), 1/2/4 heads with an even
    head_dim, and MLP multipliers 2/3/4. Keeps those whose ``count_params``
    estimate is at most ``max_params``, ranks them by parameter count (largest
    first), and keeps only the largest candidate for each (depth, model_dim)
    pair.

    Note: with n_kv_head == n_head the estimate does not depend on the number
    of heads, so head variants of the same (depth, dim, mlp_mult) tie. The sort
    is stable and ``n_heads=1`` is enumerated first, so every returned config
    uses a single head.

    Args:
        vocab_size: Vocabulary size used for the parameter estimate.
        max_params: Inclusive upper bound on the estimated parameter count.
        max_configs: Maximum number of configurations to return. The default of
            12 matches the previously hard-coded value; ``None`` returns all.

    Returns:
        List of ``(estimated_params, SweepConfig)`` tuples, largest first.
    """
    candidates = []
    for depth in range(3, 11):
        # dim starts at 32 and steps by 4, so it is always even.
        for dim in range(32, 128, 4):
            for n_heads in (1, 2, 4):
                if dim % n_heads != 0 or (dim // n_heads) % 2 != 0:
                    continue
                for mlp_mult in (2.0, 3.0, 4.0):
                    mlp_hidden = int(dim * mlp_mult)
                    params = count_params(vocab_size, depth, dim, n_heads, mlp_hidden)
                    if params <= max_params:
                        candidates.append((params, SweepConfig(depth, dim, n_heads, mlp_mult)))

    # Prefer larger models within budget; the stable sort keeps ties in enumeration order.
    candidates.sort(key=lambda item: -item[0])

    # Keep the first (largest) candidate per (depth, model_dim) so the sweep covers diverse shapes.
    best_per_shape = {}
    for params, cfg in candidates:
        best_per_shape.setdefault((cfg.depth, cfg.model_dim), (params, cfg))

    return list(best_per_shape.values())[:max_configs]
def train_config(cfg: SweepConfig, device, device_type, vocab_size, token_bytes,
                 num_iterations=50, max_seq_len=512, device_batch_size=4, total_batch_size=2048,
                 eval_tokens=8192, data_dir=None):
    """Train a single configuration for a few steps and return metrics.

    Builds a GPT for ``cfg`` (with ``n_kv_head == num_heads``) and trains it for
    ``num_iterations`` optimizer steps on the "train" split. It then runs
    ``evaluate_bpb`` on roughly ``eval_tokens`` tokens of the "val" split.

    Args:
        cfg: Architecture to train.
        device: Torch device for the model and batches.
        device_type: "cuda" or "mps" enable bf16 autocast and device
            synchronization around step timing. Any other value uses neither.
        vocab_size: Vocabulary size for the embedding / LM head.
        token_bytes: Passed through to ``evaluate_bpb``.
        num_iterations: Number of optimizer steps.
        max_seq_len: Sequence length of each row.
        device_batch_size: Rows per forward/backward pass.
        total_batch_size: Target tokens per optimizer step. Gradient
            accumulation uses ``max(1, total_batch_size // (device_batch_size *
            max_seq_len))`` micro-steps.
        eval_tokens: Approximate number of validation tokens to evaluate.
        data_dir: Optional data directory forwarded to the data loaders.

    Returns:
        dict with keys ``num_params``, ``tok_per_sec``, ``val_bpb`` and
        ``train_loss``. The first ``warmup_steps`` steps are excluded from
        ``tok_per_sec`` and ``train_loss``. ``tok_per_sec`` is 0 if no step
        was timed. ``train_loss`` is the mean of the last (up to) 10 recorded
        step losses, or NaN if none were recorded.
    """
    # Steps excluded from timing/loss stats (e.g. first-call torch.compile overhead).
    warmup_steps = 5

    mlp_hidden_dim = int(cfg.model_dim * cfg.mlp_hidden_mult)
    model_config = GPTConfig(
        sequence_len=max_seq_len,
        vocab_size=vocab_size,
        n_layer=cfg.depth,
        n_head=cfg.num_heads,
        n_kv_head=cfg.num_heads,
        n_embd=cfg.model_dim,
        mlp_hidden_mult=cfg.mlp_hidden_mult,
        mlp_hidden_dim=mlp_hidden_dim,
    )
    # Build on the meta device, then materialize and initialize on the target device.
    with torch.device("meta"):
        model = GPT(model_config)
    model.to_empty(device=device)
    model.init_weights()
    num_params = sum(p.numel() for p in model.parameters())

    # Compile if possible
    if device_type == "cuda" and not os.environ.get("NANOCHAT_NO_COMPILE", ""):
        model = torch.compile(model, dynamic=False)

    optimizers = model.setup_optimizers(unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0)

    train_loader = tokenizing_distributed_data_loader_with_state(
        device_batch_size, max_seq_len, split="train", device=device,
        resume_state_dict=None, data_dir=data_dir,
    )

    tokens_per_fwdbwd = device_batch_size * max_seq_len
    grad_accum_steps = max(1, total_batch_size // tokens_per_fwdbwd)
    # Tokens actually consumed per optimizer step. This differs from
    # total_batch_size whenever it is not an exact multiple of tokens_per_fwdbwd,
    # so throughput is measured with this value.
    tokens_per_step = grad_accum_steps * tokens_per_fwdbwd

    if device_type in ("cuda", "mps"):
        autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    else:
        autocast_ctx = nullcontext()

    def synchronize():
        # Wait for queued device work so wall-clock step timings are meaningful.
        if device_type == "cuda":
            torch.cuda.synchronize()
        elif device_type == "mps":
            torch.mps.synchronize()

    # Training loop
    x, y, _ = next(train_loader)
    total_tokens = 0
    total_time = 0.0
    train_losses = []
    model.train()
    for step in range(num_iterations):
        synchronize()
        t0 = time.perf_counter()
        for _micro_step in range(grad_accum_steps):
            with autocast_ctx:
                loss = model(x, y)
            train_loss = loss.detach().item()  # loss of the last micro-step is recorded
            loss = loss / grad_accum_steps
            loss.backward()
            x, y, _ = next(train_loader)  # prefetch the next batch
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        for opt in optimizers:
            opt.step()
        model.zero_grad(set_to_none=True)
        synchronize()
        dt = time.perf_counter() - t0
        if step >= warmup_steps:
            total_time += dt
            total_tokens += tokens_per_step
            train_losses.append(train_loss)

    # Evaluate
    model.eval()
    val_loader = tokenizing_distributed_data_loader(
        device_batch_size, max_seq_len, split="val", device=device, data_dir=data_dir,
    )
    eval_steps = max(1, eval_tokens // tokens_per_fwdbwd)
    with autocast_ctx:
        val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)

    tok_per_sec = total_tokens / total_time if total_time > 0 else 0
    recent_losses = train_losses[-10:]
    # Guard: with num_iterations <= warmup_steps no loss was recorded.
    avg_train_loss = sum(recent_losses) / len(recent_losses) if recent_losses else float("nan")
    return {
        "num_params": num_params,
        "tok_per_sec": tok_per_sec,
        "val_bpb": val_bpb,
        "train_loss": avg_train_loss,
    }
def main():
    """Run the small-model sweep end to end.

    Initializes compute and the tokenizer, then takes configurations from
    ``generate_sweep_configs(vocab_size, max_params=1_000_000)``. Each one is
    trained and evaluated with ``train_config``, and a summary table sorted by
    val_bpb (lower is better) is printed.

    A configuration that raises is reported as FAILED and left out of the
    summary. ``compute_cleanup()`` runs even if the sweep itself raises or is
    interrupted.
    """
    device_type = autodetect_device_type()
    _ddp, _ddp_rank, _ddp_local_rank, _ddp_world_size, device = compute_init(device_type)
    try:
        tokenizer = get_tokenizer()
        token_bytes = get_token_bytes(device=device)
        vocab_size = tokenizer.get_vocab_size()
        print0(f"Device: {device}, Vocab size: {vocab_size}")
        print0("=" * 80)

        configs = generate_sweep_configs(vocab_size, max_params=1_000_000)
        print0(f"Running sweep over {len(configs)} configurations...")
        print0("=" * 80)

        results = []
        for i, (est_params, cfg) in enumerate(configs):
            print0(f"\n[{i+1}/{len(configs)}] depth={cfg.depth}, dim={cfg.model_dim}, heads={cfg.num_heads}, mlp_mult={cfg.mlp_hidden_mult}")
            print0(f" Estimated params: {est_params:,}")
            try:
                metrics = train_config(
                    cfg, device, device_type, vocab_size, token_bytes,
                    num_iterations=50,
                    max_seq_len=512,
                    device_batch_size=4,
                    total_batch_size=2048,
                    eval_tokens=8192,
                )
            except Exception as e:
                # Best-effort sweep: one failing configuration must not abort the rest.
                print0(f" FAILED: {e}")
            else:
                results.append((cfg, metrics))
                print0(f" Actual params: {metrics['num_params']:,}")
                print0(f" tok/sec: {metrics['tok_per_sec']:,.0f}")
                print0(f" val_bpb: {metrics['val_bpb']:.4f}")
                print0(f" train_loss: {metrics['train_loss']:.4f}")
            # Release cached allocator blocks between configurations.
            if device_type == "cuda":
                torch.cuda.empty_cache()

        # Print summary table
        print0("\n" + "=" * 80)
        print0("SWEEP RESULTS SUMMARY")
        print0("=" * 80)
        print0(f"{'depth':>5} | {'dim':>4} | {'heads':>5} | {'mlp_m':>5} | {'params':>10} | {'tok/s':>10} | {'val_bpb':>8} | {'train_loss':>10}")
        print0("-" * 80)
        # Sort by val_bpb (lower is better)
        results.sort(key=lambda item: item[1]['val_bpb'])
        for cfg, metrics in results:
            print0(f"{cfg.depth:>5} | {cfg.model_dim:>4} | {cfg.num_heads:>5} | {cfg.mlp_hidden_mult:>5.1f} | {metrics['num_params']:>10,} | {metrics['tok_per_sec']:>10,.0f} | {metrics['val_bpb']:>8.4f} | {metrics['train_loss']:>10.4f}")
        print0("=" * 80)
        print0("(Sorted by val_bpb, lower is better)")
    finally:
        compute_cleanup()


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Device: mps, Vocab size: 8192 | |
| ================================================================================ | |
| Running sweep over 12 configurations... | |
| ================================================================================ | |
| [1/12] depth=9, dim=48, heads=1, mlp_mult=3.0 | |
| Estimated params: 993,792 | |
| Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000 | |
| Actual params: 993,792 | |
| tok/sec: 45,309 | |
| val_bpb: 2.2323 | |
| train_loss: 6.8131 | |
| [2/12] depth=3, dim=56, heads=1, mlp_mult=2.0 | |
| Estimated params: 992,768 | |
| Scaling the LR for the AdamW parameters ∝1/√(56/768) = 3.703280 | |
| Actual params: 992,768 | |
| tok/sec: 88,444 | |
| val_bpb: 2.2162 | |
| train_loss: 6.7881 | |
| [3/12] depth=5, dim=52, heads=1, mlp_mult=3.0 | |
| Estimated params: 987,168 | |
| Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076 | |
| Actual params: 987,168 | |
| tok/sec: 66,254 | |
| val_bpb: 2.2279 | |
| train_loss: 6.7844 | |
| [4/12] depth=4, dim=52, heads=1, mlp_mult=4.0 | |
| Estimated params: 981,760 | |
| Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076 | |
| Actual params: 981,760 | |
| tok/sec: 73,589 | |
| val_bpb: 2.2221 | |
| train_loss: 6.7687 | |
| [5/12] depth=6, dim=52, heads=1, mlp_mult=2.0 | |
| Estimated params: 981,760 | |
| Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076 | |
| Actual params: 981,760 | |
| tok/sec: 57,035 | |
| val_bpb: 2.2245 | |
| train_loss: 6.8004 | |
| [6/12] depth=7, dim=48, heads=1, mlp_mult=4.0 | |
| Estimated params: 979,968 | |
| Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000 | |
| Actual params: 979,968 | |
| tok/sec: 51,458 | |
| val_bpb: 2.2289 | |
| train_loss: 6.8111 | |
| [7/12] depth=8, dim=48, heads=1, mlp_mult=3.0 | |
| Estimated params: 970,752 | |
| Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000 | |
| Actual params: 970,752 | |
| tok/sec: 48,367 | |
| val_bpb: 2.2273 | |
| train_loss: 6.8150 | |
| [8/12] depth=10, dim=48, heads=1, mlp_mult=2.0 | |
| Estimated params: 970,752 | |
| Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000 | |
| Actual params: 970,752 | |
| tok/sec: 42,270 | |
| val_bpb: 2.2462 | |
| train_loss: 6.8419 | |
| [9/12] depth=10, dim=44, heads=1, mlp_mult=4.0 | |
| Estimated params: 953,216 | |
| Scaling the LR for the AdamW parameters ∝1/√(44/768) = 4.177864 | |
| Actual params: 953,216 | |
| tok/sec: 41,609 | |
| val_bpb: 2.2446 | |
| train_loss: 6.8540 | |
| [10/12] depth=6, dim=48, heads=1, mlp_mult=4.0 | |
| Estimated params: 952,320 | |
| Scaling the LR for the AdamW parameters ∝1/√(48/768) = 4.000000 | |
| Actual params: 952,320 | |
| tok/sec: 57,539 | |
| val_bpb: 2.2293 | |
| train_loss: 6.8035 | |
| [11/12] depth=3, dim=52, heads=1, mlp_mult=4.0 | |
| Estimated params: 949,312 | |
| Scaling the LR for the AdamW parameters ∝1/√(52/768) = 3.843076 | |
| Actual params: 949,312 | |
| tok/sec: 83,993 | |
| val_bpb: 2.2236 | |
| train_loss: 6.8127 | |
| [12/12] depth=9, dim=44, heads=1, mlp_mult=4.0 | |
| Estimated params: 929,984 | |
| Scaling the LR for the AdamW parameters ∝1/√(44/768) = 4.177864 | |
| Actual params: 929,984 | |
| tok/sec: 43,448 | |
| val_bpb: 2.2401 | |
| train_loss: 6.8298 | |
| ================================================================================ | |
| SWEEP RESULTS SUMMARY | |
| ================================================================================ | |
| depth | dim | heads | mlp_m | params | tok/s | val_bpb | train_loss | |
| -------------------------------------------------------------------------------- | |
| 3 | 56 | 1 | 2.0 | 992,768 | 88,444 | 2.2162 | 6.7881 | |
| 4 | 52 | 1 | 4.0 | 981,760 | 73,589 | 2.2221 | 6.7687 | |
| 3 | 52 | 1 | 4.0 | 949,312 | 83,993 | 2.2236 | 6.8127 | |
| 6 | 52 | 1 | 2.0 | 981,760 | 57,035 | 2.2245 | 6.8004 | |
| 8 | 48 | 1 | 3.0 | 970,752 | 48,367 | 2.2273 | 6.8150 | |
| 5 | 52 | 1 | 3.0 | 987,168 | 66,254 | 2.2279 | 6.7844 | |
| 7 | 48 | 1 | 4.0 | 979,968 | 51,458 | 2.2289 | 6.8111 | |
| 6 | 48 | 1 | 4.0 | 952,320 | 57,539 | 2.2293 | 6.8035 | |
| 9 | 48 | 1 | 3.0 | 993,792 | 45,309 | 2.2323 | 6.8131 | |
| 9 | 44 | 1 | 4.0 | 929,984 | 43,448 | 2.2401 | 6.8298 | |
| 10 | 44 | 1 | 4.0 | 953,216 | 41,609 | 2.2446 | 6.8540 | |
| 10 | 48 | 1 | 2.0 | 970,752 | 42,270 | 2.2462 | 6.8419 | |
| ================================================================================ | |
| (Sorted by val_bpb, lower is better) | |
| 2025-11-28 14:47:16,472 - nanochat.common - INFO - Distributed world size: 1 | |
| W1128 14:47:17.937000 26162 .venv/lib/python3.10/site-packages/torch/_inductor/utils.py:1558] [0/0] | |
| Not enough SMs to use max_autotune_gemm mode | |
| ⏺ Results are in (02:47 PM, claude-opus-4-5-20251101). The best val_bpb after 50 steps: | |
| | Rank | Depth | Dim | MLP | Params | tok/s | val_bpb | | |
| |------|-------|-----|-----|---------|--------|---------| | |
| | 1 | 3 | 56 | 2.0 | 992,768 | 88,444 | 2.2162 | | |
| | 2 | 4 | 52 | 4.0 | 981,760 | 73,589 | 2.2221 | | |
| | 3 | 3 | 52 | 4.0 | 949,312 | 83,993 | 2.2236 | | |
| Interesting findings: | |
| - Shallow & wide wins - depth=3 with dim=56 got the best val_bpb | |
| and fastest tok/sec (88k) | |
| - Deeper models (9-10 layers) performed worse despite similar | |
| param counts | |
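| - All configs used 1 head. This is an artifact of the sweep, not a finding: with | |
| num_kv_heads == num_heads the parameter count does not depend on the head count, | |
| so the (depth, dim) dedup always keeps the first candidate tried (heads=1) | |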
| The winner: --depth=3 --model_dim=56 --num_heads=1 | |
| --mlp_hidden_mult=2.0 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment