Skip to content

Instantly share code, notes, and snippets.

@RandyMcMillan
Forked from rust-play/playground.rs
Last active February 14, 2026 20:57
Show Gist options
  • Select an option

  • Save RandyMcMillan/084ef847fcaeb9308b2ae1a49e581331 to your computer and use it in GitHub Desktop.

Select an option

Save RandyMcMillan/084ef847fcaeb9308b2ae1a49e581331 to your computer and use it in GitHub Desktop.
mini_gpt.rs
use std::env;
use std::fs::File;
use std::io::{BufRead, BufReader, Read, Write, Result};
/* ------------------------------------------------------------------ */
/* Model hyper-parameters */
/* ------------------------------------------------------------------ */
const N_EMBD: usize = 32;                 // embedding / residual-stream width
const N_HEAD: usize = 4;                  // attention heads per layer
const N_LAYER: usize = 1;                 // number of transformer blocks
const BLOCK_SIZE: usize = 8;              // maximum context length in tokens
const HEAD_DIM: usize = N_EMBD / N_HEAD;  // per-head dimension (N_EMBD must divide evenly)
const MLP_DIM: usize = 4 * N_EMBD;        // hidden width of the feed-forward block
const MAX_DOCS: usize = 85000;            // NOTE(review): not referenced in this snippet — TODO confirm intended use
const MAX_CHARS: usize = 128;             // upper bound on vocabulary size (logit buffer is MAX_CHARS + 1)
/* ------------------------------------------------------------------ */
/* Minimal xorshift PRNG */
/* ------------------------------------------------------------------ */
/// Minimal xorshift64 PRNG (Marsaglia). Fully deterministic for a given seed.
struct Rng {
    state: u64,
}
impl Rng {
    /// Creates a generator from `seed`.
    ///
    /// A seed of 0 is remapped to a fixed nonzero constant: zero is an
    /// absorbing state for xorshift (every shift/xor of 0 is 0), so the
    /// original code would emit 0 forever when seeded with 0.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
        Self { state }
    }
    /// Advances the state with the classic 13/7/17 xorshift triple and
    /// returns the new state.
    fn next(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }
    /// Uniform f64 in [0, 1): the top 53 bits scaled by 2^-53.
    fn uniform(&mut self) -> f64 { (self.next() >> 11) as f64 * (1.0 / 9007199254740992.0) }
    /// Gaussian sample via Box-Muller; u1 is clamped away from 0 so ln() stays finite.
    fn gauss(&mut self, mean: f32, std: f32) -> f32 {
        let u1 = self.uniform().max(1e-30);
        let u2 = self.uniform();
        let mag = ((-2.0 * u1.ln()).sqrt()) as f32;
        mean + std * mag * ((2.0 * std::f64::consts::PI * u2).cos() as f32)
    }
    /// In-place Fisher-Yates shuffle.
    fn shuffle<T>(&mut self, arr: &mut [T]) {
        for i in (1..arr.len()).rev() {
            // uniform() is strictly < 1.0, so j <= i always holds.
            let j = (self.uniform() * (i + 1) as f64) as usize;
            arr.swap(i, j);
        }
    }
}
/* ------------------------------------------------------------------ */
/* Weights and Activations */
/* ------------------------------------------------------------------ */
/// Per-position activation record filled in by `forward`. The saved
/// intermediates (norm scales, attention weights, pre/post activations)
/// are presumably for a backward pass — TODO confirm; only the forward
/// pass is visible in this snippet.
#[derive(Clone)]
struct PosActs {
x_embed: [f32; N_EMBD],                     // token embedding + positional embedding
rms_scale_init: f32,                        // scale from the initial RMSNorm
x_in: [[f32; N_EMBD]; N_LAYER],             // residual-stream input to each layer
xn_attn: [[f32; N_EMBD]; N_LAYER],          // normalized input to attention
rms_scale_attn: [f32; N_LAYER],             // RMSNorm scale before attention
q: [[f32; N_EMBD]; N_LAYER],                // query vectors (all heads concatenated)
aw: [[[f32; BLOCK_SIZE]; N_HEAD]; N_LAYER], // softmaxed attention weights per head/position
attn_out: [[f32; N_EMBD]; N_LAYER],         // attention output before the wo projection
x_mid: [[f32; N_EMBD]; N_LAYER],            // residual stream after attention
xn_mlp: [[f32; N_EMBD]; N_LAYER],           // normalized input to the MLP
rms_scale_mlp: [f32; N_LAYER],              // RMSNorm scale before the MLP
mlp_pre: [[f32; MLP_DIM]; N_LAYER],         // fc1 output, pre-activation
mlp_post: [[f32; MLP_DIM]; N_LAYER],        // after the squared-ReLU nonlinearity
x_out: [f32; N_EMBD],                       // final residual-stream output
}
impl PosActs {
    /// Returns an activation record with every buffer zero-filled.
    fn new() -> Self {
        Self {
            x_embed: [0.0; N_EMBD],
            rms_scale_init: 0.0,
            x_in: [[0.0; N_EMBD]; N_LAYER],
            xn_attn: [[0.0; N_EMBD]; N_LAYER],
            rms_scale_attn: [0.0; N_LAYER],
            q: [[0.0; N_EMBD]; N_LAYER],
            aw: [[[0.0; BLOCK_SIZE]; N_HEAD]; N_LAYER],
            attn_out: [[0.0; N_EMBD]; N_LAYER],
            x_mid: [[0.0; N_EMBD]; N_LAYER],
            xn_mlp: [[0.0; N_EMBD]; N_LAYER],
            rms_scale_mlp: [0.0; N_LAYER],
            mlp_pre: [[0.0; MLP_DIM]; N_LAYER],
            mlp_post: [[0.0; MLP_DIM]; N_LAYER],
            x_out: [0.0; N_EMBD],
        }
    }
}
/// Per-layer parameters plus gradient (d_*) and Adam moment (m_*, v_*)
/// buffers. Only the forward weights are read in this snippet; the
/// gradient/optimizer buffers are allocated in `main` but never updated
/// here — presumably used by a training step that is not visible.
struct LayerWeights {
// Attention projections, each N_EMBD x N_EMBD stored row-major.
wq: Vec<f32>, wk: Vec<f32>, wv: Vec<f32>, wo: Vec<f32>,
// MLP weights: fc1 is MLP_DIM x N_EMBD, fc2 is N_EMBD x MLP_DIM.
fc1: Vec<f32>, fc2: Vec<f32>,
// Gradient buffer per weight matrix.
d_wq: Vec<f32>, d_wk: Vec<f32>, d_wv: Vec<f32>, d_wo: Vec<f32>,
d_fc1: Vec<f32>, d_fc2: Vec<f32>,
// Adam first-moment (m_*) and second-moment (v_*) estimates.
m_wq: Vec<f32>, v_wq: Vec<f32>, m_wk: Vec<f32>, v_wk: Vec<f32>,
m_wv: Vec<f32>, v_wv: Vec<f32>, m_wo: Vec<f32>, v_wo: Vec<f32>,
m_fc1: Vec<f32>, v_fc1: Vec<f32>, m_fc2: Vec<f32>, v_fc2: Vec<f32>,
}
/* ------------------------------------------------------------------ */
/* Math primitives */
/* ------------------------------------------------------------------ */
/// Dense matrix-vector product: out[r] = dot(W[r, :], x), where `w` holds
/// `nout` rows of `nin` columns in row-major order. No bias term.
fn linear_fwd(x: &[f32], w: &[f32], nout: usize, nin: usize, out: &mut [f32]) {
    for (row, slot) in w.chunks(nin).take(nout).zip(out.iter_mut()) {
        *slot = row.iter().zip(x.iter()).map(|(wi, xi)| wi * xi).sum();
    }
}
/// RMS normalization: out = x * (1 / rms(x)), with a 1e-5 epsilon inside
/// the square root for stability. Returns the scale factor so callers can
/// save it (the forward pass records it per norm site).
fn rmsnorm_fwd(x: &[f32], n: usize, out: &mut [f32]) -> f32 {
    let mean_sq = x[..n].iter().map(|v| v * v).sum::<f32>() / n as f32;
    let scale = 1.0 / (mean_sq + 1e-5).sqrt();
    for (slot, v) in out[..n].iter_mut().zip(&x[..n]) {
        *slot = v * scale;
    }
    scale
}
/// Temperature-scaled softmax over `logits[..n]`, written into `probs[..n]`.
/// The running maximum of the scaled logits is subtracted before exp() for
/// numerical stability.
fn softmax_fwd(logits: &[f32], n: usize, probs: &mut [f32], temp: f32) {
    let mut mx = logits[0] / temp;
    for v in &logits[1..n] {
        let scaled = v / temp;
        if scaled > mx {
            mx = scaled;
        }
    }
    let mut sum = 0.0f32;
    for i in 0..n {
        let e = ((logits[i] / temp) - mx).exp();
        probs[i] = e;
        sum += e;
    }
    let inv = 1.0 / sum;
    for p in probs[..n].iter_mut() {
        *p *= inv;
    }
}
/* ------------------------------------------------------------------ */
/* Core Logic */
/* ------------------------------------------------------------------ */
/// Runs one transformer forward pass for a single token `tok` at position
/// `pos`, updating the per-layer KV cache in place and recording all
/// intermediate activations in `act` (presumably for a backward pass that is
/// not visible in this snippet). Returns unnormalized logits; only the first
/// `vocab_size` entries are written, the rest of the buffer stays zero.
fn forward(
tok: usize, pos: usize,
wte: &[f32], wpe: &[f32], layers: &[LayerWeights], lm_head: &[f32],
vocab_size: usize, act: &mut PosActs,
kv_k: &mut [[[f32; N_EMBD]; BLOCK_SIZE]; N_LAYER],
kv_v: &mut [[[f32; N_EMBD]; BLOCK_SIZE]; N_LAYER],
) -> [f32; MAX_CHARS + 1] {
let mut logits = [0.0f32; MAX_CHARS + 1];
// Embedding: token embedding plus learned positional embedding.
for i in 0..N_EMBD { act.x_embed[i] = wte[tok * N_EMBD + i] + wpe[pos * N_EMBD + i]; }
let mut x = act.x_embed;
// Initial RMSNorm applied directly to the embedded token.
let x_copy = x; // Fix for borrow checker
act.rms_scale_init = rmsnorm_fwd(&x_copy, N_EMBD, &mut x);
for li in 0..N_LAYER {
act.x_in[li] = x;
// Pre-attention RMSNorm.
let mut xn = [0.0f32; N_EMBD];
let x_li_copy = x; // Fix for borrow checker
act.rms_scale_attn[li] = rmsnorm_fwd(&x_li_copy, N_EMBD, &mut xn);
act.xn_attn[li] = xn;
// Q/K/V projections; K and V for this position go straight into the cache.
let (mut q, mut k, mut v) = ([0.0f32; N_EMBD], [0.0f32; N_EMBD], [0.0f32; N_EMBD]);
linear_fwd(&xn, &layers[li].wq, N_EMBD, N_EMBD, &mut q);
linear_fwd(&xn, &layers[li].wk, N_EMBD, N_EMBD, &mut k);
linear_fwd(&xn, &layers[li].wv, N_EMBD, N_EMBD, &mut v);
act.q[li] = q; kv_k[li][pos] = k; kv_v[li][pos] = v;
// Causal multi-head attention over cached positions 0..=pos.
let mut ao = [0.0f32; N_EMBD];
let scale = 1.0 / (HEAD_DIM as f32).sqrt();
for h in 0..N_HEAD {
let hs = h * HEAD_DIM;
// Scaled dot-product scores for this head against every cached key.
let mut al = [0.0f32; BLOCK_SIZE];
for tt in 0..=pos {
let mut dot = 0.0;
for j in 0..HEAD_DIM { dot += q[hs+j] * kv_k[li][tt][hs+j]; }
al[tt] = dot * scale;
}
let mut sm = [0.0f32; BLOCK_SIZE];
softmax_fwd(&al, pos + 1, &mut sm, 1.0);
// Weighted sum of cached values; weights saved in `act` as well.
for tt in 0..=pos {
act.aw[li][h][tt] = sm[tt];
for j in 0..HEAD_DIM { ao[hs+j] += sm[tt] * kv_v[li][tt][hs+j]; }
}
}
act.attn_out[li] = ao;
// Output projection plus residual connection.
let mut tmp = [0.0f32; N_EMBD];
linear_fwd(&ao, &layers[li].wo, N_EMBD, N_EMBD, &mut tmp);
for i in 0..N_EMBD { x[i] = tmp[i] + act.x_in[li][i]; }
act.x_mid[li] = x;
// Pre-MLP RMSNorm.
let mut xn_m = [0.0f32; N_EMBD];
let x_mid_copy = x; // Fix for borrow checker
act.rms_scale_mlp[li] = rmsnorm_fwd(&x_mid_copy, N_EMBD, &mut xn_m);
act.xn_mlp[li] = xn_m;
// Feed-forward block: fc1 -> squared-ReLU -> fc2, plus residual.
let mut h1 = [0.0f32; MLP_DIM];
linear_fwd(&xn_m, &layers[li].fc1, MLP_DIM, N_EMBD, &mut h1);
act.mlp_pre[li] = h1;
let mut h2 = [0.0f32; MLP_DIM];
// Squared ReLU: relu(x)^2.
for i in 0..MLP_DIM { h2[i] = if h1[i] > 0.0 { h1[i] * h1[i] } else { 0.0 }; }
act.mlp_post[li] = h2;
let mut f2_out = [0.0f32; N_EMBD];
linear_fwd(&h2, &layers[li].fc2, N_EMBD, MLP_DIM, &mut f2_out);
for i in 0..N_EMBD { x[i] = f2_out[i] + act.x_mid[li][i]; }
}
act.x_out = x;
// Final projection of the residual stream to vocabulary logits.
linear_fwd(&x, lm_head, vocab_size, N_EMBD, &mut logits[..vocab_size]);
logits
}
/// Samples up to BLOCK_SIZE tokens autoregressively with temperature `temp`
/// and nucleus (top-p) filtering, starting from the BOS token, then prints
/// the decoded string. Sampling BOS again ends generation early.
fn generate_top_p(
wte: &[f32], wpe: &[f32], layers: &[LayerWeights], lm: &[f32],
vocab_size: usize, bos_id: usize, uchars: &[char], rng: &mut Rng, temp: f32, top_p: f32
) {
let mut token_id = bos_id;
let mut res = String::new();
// Fresh KV cache and activation scratch for this generation run.
let mut kv_k = [[[0.0f32; N_EMBD]; BLOCK_SIZE]; N_LAYER];
let mut kv_v = [[[0.0f32; N_EMBD]; BLOCK_SIZE]; N_LAYER];
let mut act = PosActs::new();
for pos in 0..BLOCK_SIZE {
let logits = forward(token_id, pos, wte, wpe, layers, lm, vocab_size, &mut act, &mut kv_k, &mut kv_v);
let mut probs = [0.0f32; MAX_CHARS + 1];
softmax_fwd(&logits, vocab_size, &mut probs[..vocab_size], temp);
// Top-p sampling: rank token ids by descending probability.
let mut sorted_indices: Vec<usize> = (0..vocab_size).collect();
sorted_indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
// Find the smallest prefix whose cumulative mass exceeds top_p;
// the token that crosses the threshold is included in the nucleus.
let mut cumulative_prob = 0.0;
let mut last_idx = vocab_size - 1;
for (i, &idx) in sorted_indices.iter().enumerate() {
cumulative_prob += probs[idx];
if cumulative_prob > top_p {
last_idx = i;
break;
}
}
// Renormalize over the nucleus and draw a token via inverse CDF.
let mut new_sum = 0.0;
for i in 0..=last_idx { new_sum += probs[sorted_indices[i]]; }
let mut r = rng.uniform() as f32 * new_sum;
let mut choice = sorted_indices[last_idx];
for i in 0..=last_idx {
r -= probs[sorted_indices[i]];
if r <= 0.0 { choice = sorted_indices[i]; break; }
}
token_id = choice;
// BOS doubles as an end-of-sequence marker.
if token_id == bos_id { break; }
res.push(uchars[token_id]);
}
println!("Top-P Generate: {}", res);
}
/// Builds a byte-level vocabulary from a tiny hard-coded corpus, initializes
/// GPT weights (plus gradient/Adam buffers that the visible code never
/// updates), and runs one top-p generation demo.
///
/// Fix: dropped the unnecessary `mut` on `docs`, `wte`, `wpe`, `lm`, and
/// `layers` — none of them are mutated, and the original bindings triggered
/// `unused_mut` compiler warnings.
fn main() {
    let mut rng = Rng::new(42);
    // Tiny in-memory corpus; a real run would load training data here.
    let docs = vec!["hello world".to_string(), "rust playground gpt".to_string()];
    // 1. Vocabulary: every distinct byte that appears in the corpus.
    let mut uchars: Vec<char> = Vec::new();
    let mut seen = [false; 256];
    for d in &docs {
        for b in d.as_bytes() {
            seen[*b as usize] = true;
        }
    }
    for i in 0..256 {
        if seen[i] {
            uchars.push(i as u8 as char);
        }
    }
    // Already ascending by construction; kept as a defensive no-op.
    uchars.sort();
    let vocab_size = uchars.len() + 1; // +1 for the BOS/EOS sentinel
    let bos_id = uchars.len();
    // 2. Weights: small Gaussian init for most matrices; the residual
    //    projections (wo, fc2) use std 0.0 so each block starts as the
    //    identity on the residual stream.
    let make_p = |sz, std, r: &mut Rng| (0..sz).map(|_| r.gauss(0.0, std)).collect::<Vec<f32>>();
    let wte = make_p(vocab_size * N_EMBD, 0.02, &mut rng); // token embeddings
    let wpe = make_p(BLOCK_SIZE * N_EMBD, 0.02, &mut rng); // positional embeddings
    let lm = make_p(vocab_size * N_EMBD, 0.02, &mut rng);  // output head
    let layers: Vec<LayerWeights> = (0..N_LAYER).map(|_| LayerWeights {
        wq: make_p(N_EMBD * N_EMBD, 0.02, &mut rng), wk: make_p(N_EMBD * N_EMBD, 0.02, &mut rng),
        wv: make_p(N_EMBD * N_EMBD, 0.02, &mut rng), wo: make_p(N_EMBD * N_EMBD, 0.0, &mut rng),
        fc1: make_p(MLP_DIM * N_EMBD, 0.02, &mut rng), fc2: make_p(MLP_DIM * N_EMBD, 0.0, &mut rng),
        // Gradient and Adam-moment buffers, zero-initialized.
        d_wq: vec![0.0; N_EMBD * N_EMBD], d_wk: vec![0.0; N_EMBD * N_EMBD], d_wv: vec![0.0; N_EMBD * N_EMBD], d_wo: vec![0.0; N_EMBD * N_EMBD],
        d_fc1: vec![0.0; MLP_DIM * N_EMBD], d_fc2: vec![0.0; MLP_DIM * N_EMBD],
        m_wq: vec![0.0; N_EMBD * N_EMBD], v_wq: vec![0.0; N_EMBD * N_EMBD], m_wk: vec![0.0; N_EMBD * N_EMBD], v_wk: vec![0.0; N_EMBD * N_EMBD],
        m_wv: vec![0.0; N_EMBD * N_EMBD], v_wv: vec![0.0; N_EMBD * N_EMBD], m_wo: vec![0.0; N_EMBD * N_EMBD], v_wo: vec![0.0; N_EMBD * N_EMBD],
        m_fc1: vec![0.0; MLP_DIM * N_EMBD], v_fc1: vec![0.0; MLP_DIM * N_EMBD], m_fc2: vec![0.0; MLP_DIM * N_EMBD], v_fc2: vec![0.0; MLP_DIM * N_EMBD],
    }).collect();
    println!("GPT Initialized. Starting Playground demo...");
    generate_top_p(&wte, &wpe, &layers, &lm, vocab_size, bos_id, &uchars, &mut rng, 0.8, 0.9);
}
@RandyMcMillan
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment