-
-
Save rust-play/2c0f9621c0c824e0fc61d86ded4f88a9 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::env; | |
| use std::fs::File; | |
| use std::io::{BufRead, BufReader, Read, Write, Result}; | |
/* ------------------------------------------------------------------ */
/* Model hyper-parameters */
/* ------------------------------------------------------------------ */
const N_EMBD: usize = 32;                // width of the residual stream / embeddings
const N_HEAD: usize = 4;                 // attention heads per layer
const N_LAYER: usize = 1;                // number of transformer blocks
const BLOCK_SIZE: usize = 8;             // maximum context length (positions)
const HEAD_DIM: usize = N_EMBD / N_HEAD; // per-head dimension (32 / 4 = 8)
const MLP_DIM: usize = 4 * N_EMBD;       // hidden width of the feed-forward block
const MAX_DOCS: usize = 85000;           // document cap -- unused in this excerpt; presumably for dataset loading
const MAX_CHARS: usize = 128;            // bound used to size the logits/probs buffers (vocab must fit MAX_CHARS + 1)
| /* ------------------------------------------------------------------ */ | |
| /* Minimal xorshift PRNG */ | |
| /* ------------------------------------------------------------------ */ | |
/// Minimal xorshift64 pseudo-random number generator.
///
/// Deterministic for a given seed; NOT cryptographically secure.
struct Rng {
    state: u64,
}

impl Rng {
    /// Creates a new generator from `seed`.
    ///
    /// A xorshift state of 0 is a fixed point (every subsequent output would
    /// be 0), so a zero seed is remapped to a fixed non-zero constant to keep
    /// the stream usable. Non-zero seeds behave exactly as before.
    fn new(seed: u64) -> Self {
        Self { state: if seed == 0 { 0x9E3779B97F4A7C15 } else { seed } }
    }

    /// Advances the xorshift64 state and returns the next raw 64-bit value.
    fn next(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Uniform f64 in [0, 1): the top 53 bits of the state scaled by 2^-53.
    fn uniform(&mut self) -> f64 {
        (self.next() >> 11) as f64 * (1.0 / 9007199254740992.0)
    }

    /// Gaussian sample via the Box-Muller transform.
    fn gauss(&mut self, mean: f32, std: f32) -> f32 {
        let mut u1 = self.uniform();
        let u2 = self.uniform();
        // Clamp away from zero so ln(u1) stays finite.
        if u1 < 1e-30 {
            u1 = 1e-30;
        }
        let mag = ((-2.0 * u1.ln()).sqrt()) as f32;
        mean + std * mag * ((2.0 * std::f64::consts::PI * u2).cos() as f32)
    }

    /// In-place Fisher-Yates shuffle.
    fn shuffle<T>(&mut self, arr: &mut [T]) {
        for i in (1..arr.len()).rev() {
            let j = (self.uniform() * (i + 1) as f64) as usize;
            arr.swap(i, j);
        }
    }
}
| /* ------------------------------------------------------------------ */ | |
| /* Weights and Activations */ | |
| /* ------------------------------------------------------------------ */ | |
/// All intermediate activations for one sequence position, cached so a
/// backward pass can use them without re-running the forward computation.
#[derive(Clone)]
struct PosActs {
    x_embed: [f32; N_EMBD],                     // token + positional embedding sum
    rms_scale_init: f32,                        // reciprocal-RMS scale of the first norm
    x_in: [[f32; N_EMBD]; N_LAYER],             // residual-stream input to each layer
    xn_attn: [[f32; N_EMBD]; N_LAYER],          // normalised input to the attention sub-block
    rms_scale_attn: [f32; N_LAYER],             // reciprocal-RMS scale before attention
    q: [[f32; N_EMBD]; N_LAYER],                // query vectors (all heads concatenated)
    aw: [[[f32; BLOCK_SIZE]; N_HEAD]; N_LAYER], // attention weights per layer/head/position
    attn_out: [[f32; N_EMBD]; N_LAYER],         // attention output before the wo projection
    x_mid: [[f32; N_EMBD]; N_LAYER],            // residual stream after the attention sub-block
    xn_mlp: [[f32; N_EMBD]; N_LAYER],           // normalised input to the MLP sub-block
    rms_scale_mlp: [f32; N_LAYER],              // reciprocal-RMS scale before the MLP
    mlp_pre: [[f32; MLP_DIM]; N_LAYER],         // MLP hidden pre-activation (fc1 output)
    mlp_post: [[f32; MLP_DIM]; N_LAYER],        // MLP hidden post-activation (squared ReLU)
    x_out: [f32; N_EMBD],                       // final residual-stream output
}
impl PosActs {
    /// Returns a zero-initialised activation buffer.
    fn new() -> Self {
        Self {
            x_embed: [0.0; N_EMBD], rms_scale_init: 0.0, x_in: [[0.0; N_EMBD]; N_LAYER],
            xn_attn: [[0.0; N_EMBD]; N_LAYER], rms_scale_attn: [0.0; N_LAYER], q: [[0.0; N_EMBD]; N_LAYER],
            aw: [[[0.0; BLOCK_SIZE]; N_HEAD]; N_LAYER], attn_out: [[0.0; N_EMBD]; N_LAYER],
            x_mid: [[0.0; N_EMBD]; N_LAYER], xn_mlp: [[0.0; N_EMBD]; N_LAYER], rms_scale_mlp: [0.0; N_LAYER],
            mlp_pre: [[0.0; MLP_DIM]; N_LAYER], mlp_post: [[0.0; MLP_DIM]; N_LAYER], x_out: [0.0; N_EMBD],
        }
    }
}
/// Parameters and training state for one transformer layer (flat row-major
/// matrices stored as `Vec<f32>`).
struct LayerWeights {
    // Attention projections (wq/wk/wv/wo) and MLP weights (fc1/fc2).
    wq: Vec<f32>, wk: Vec<f32>, wv: Vec<f32>, wo: Vec<f32>,
    fc1: Vec<f32>, fc2: Vec<f32>,
    // Gradient accumulators, one per weight tensor above. Unused in this
    // excerpt's demo path -- presumably consumed by a training loop elsewhere.
    d_wq: Vec<f32>, d_wk: Vec<f32>, d_wv: Vec<f32>, d_wo: Vec<f32>,
    d_fc1: Vec<f32>, d_fc2: Vec<f32>,
    // Optimiser moment buffers (m/v pairs per tensor, Adam-style judging by
    // the naming -- TODO confirm against the optimiser step if present).
    m_wq: Vec<f32>, v_wq: Vec<f32>, m_wk: Vec<f32>, v_wk: Vec<f32>,
    m_wv: Vec<f32>, v_wv: Vec<f32>, m_wo: Vec<f32>, v_wo: Vec<f32>,
    m_fc1: Vec<f32>, v_fc1: Vec<f32>, m_fc2: Vec<f32>, v_fc2: Vec<f32>,
}
| /* ------------------------------------------------------------------ */ | |
| /* Math primitives */ | |
| /* ------------------------------------------------------------------ */ | |
/// Matrix-vector product `out = W x`, where `w` is an `nout` x `nin`
/// row-major matrix and `x` has at least `nin` elements.
fn linear_fwd(x: &[f32], w: &[f32], nout: usize, nin: usize, out: &mut [f32]) {
    for r in 0..nout {
        // Dot product of row r with x, accumulated left-to-right.
        out[r] = (0..nin).map(|c| w[r * nin + c] * x[c]).sum();
    }
}
/// RMS-normalises `x[..n]` into `out[..n]` and returns the reciprocal-RMS
/// scale factor (cached by callers for the backward pass).
fn rmsnorm_fwd(x: &[f32], n: usize, out: &mut [f32]) -> f32 {
    let ms: f32 = x[..n].iter().map(|v| v * v).sum();
    // Epsilon keeps the denominator finite for an all-zero input.
    let scale = 1.0 / ((ms / n as f32) + 1e-5).sqrt();
    for (o, v) in out[..n].iter_mut().zip(&x[..n]) {
        *o = v * scale;
    }
    scale
}
/// Temperature-scaled softmax over `logits[..n]`, written into `probs[..n]`.
/// The running maximum is subtracted before exponentiating so large logits
/// cannot overflow.
fn softmax_fwd(logits: &[f32], n: usize, probs: &mut [f32], temp: f32) {
    let mut mx = logits[0] / temp;
    for v in logits[1..n].iter() {
        let scaled = v / temp;
        if scaled > mx {
            mx = scaled;
        }
    }
    let mut sum = 0.0;
    for i in 0..n {
        probs[i] = ((logits[i] / temp) - mx).exp();
        sum += probs[i];
    }
    let inv = 1.0 / sum;
    for p in probs[..n].iter_mut() {
        *p *= inv;
    }
}
| /* ------------------------------------------------------------------ */ | |
| /* Core Logic */ | |
| /* ------------------------------------------------------------------ */ | |
/// Runs one transformer forward pass for a single token `tok` at position
/// `pos`, filling `act` with every intermediate activation (for a later
/// backward pass) and writing this position's keys/values into the
/// `kv_k`/`kv_v` caches. Returns the pre-softmax logits; only the first
/// `vocab_size` entries are meaningful.
///
/// Indexing requires `pos < BLOCK_SIZE` and `tok * N_EMBD` in range of `wte`.
/// NOTE(review): `logits[..vocab_size]` panics if vocab_size > MAX_CHARS + 1;
/// that holds for the ASCII demo corpus but is never checked -- worth guarding.
fn forward(
    tok: usize, pos: usize,
    wte: &[f32], wpe: &[f32], layers: &[LayerWeights], lm_head: &[f32],
    vocab_size: usize, act: &mut PosActs,
    kv_k: &mut [[[f32; N_EMBD]; BLOCK_SIZE]; N_LAYER],
    kv_v: &mut [[[f32; N_EMBD]; BLOCK_SIZE]; N_LAYER],
) -> [f32; MAX_CHARS + 1] {
    let mut logits = [0.0f32; MAX_CHARS + 1];
    // Embedding: token embedding + learned positional embedding.
    for i in 0..N_EMBD { act.x_embed[i] = wte[tok * N_EMBD + i] + wpe[pos * N_EMBD + i]; }
    let mut x = act.x_embed;
    // Arrays are Copy; the copy lets rmsnorm_fwd borrow input and output
    // without aliasing the same array.
    let x_copy = x; // Fix for borrow checker
    act.rms_scale_init = rmsnorm_fwd(&x_copy, N_EMBD, &mut x);
    for li in 0..N_LAYER {
        act.x_in[li] = x;
        // --- Attention sub-block: x = x + wo * Attn(RMSNorm(x)) ---
        let mut xn = [0.0f32; N_EMBD];
        let x_li_copy = x; // Fix for borrow checker
        act.rms_scale_attn[li] = rmsnorm_fwd(&x_li_copy, N_EMBD, &mut xn);
        act.xn_attn[li] = xn;
        let (mut q, mut k, mut v) = ([0.0f32; N_EMBD], [0.0f32; N_EMBD], [0.0f32; N_EMBD]);
        linear_fwd(&xn, &layers[li].wq, N_EMBD, N_EMBD, &mut q);
        linear_fwd(&xn, &layers[li].wk, N_EMBD, N_EMBD, &mut k);
        linear_fwd(&xn, &layers[li].wv, N_EMBD, N_EMBD, &mut v);
        // Cache K/V for this position so later positions can attend back.
        act.q[li] = q; kv_k[li][pos] = k; kv_v[li][pos] = v;
        let mut ao = [0.0f32; N_EMBD];
        // Standard 1/sqrt(d_head) attention scaling.
        let scale = 1.0 / (HEAD_DIM as f32).sqrt();
        // Causal multi-head attention over positions 0..=pos via the KV cache.
        for h in 0..N_HEAD {
            let hs = h * HEAD_DIM;
            let mut al = [0.0f32; BLOCK_SIZE];
            for tt in 0..=pos {
                let mut dot = 0.0;
                for j in 0..HEAD_DIM { dot += q[hs+j] * kv_k[li][tt][hs+j]; }
                al[tt] = dot * scale;
            }
            let mut sm = [0.0f32; BLOCK_SIZE];
            softmax_fwd(&al, pos + 1, &mut sm, 1.0);
            for tt in 0..=pos {
                act.aw[li][h][tt] = sm[tt];
                // Weighted sum of cached values for this head's slice.
                for j in 0..HEAD_DIM { ao[hs+j] += sm[tt] * kv_v[li][tt][hs+j]; }
            }
        }
        act.attn_out[li] = ao;
        let mut tmp = [0.0f32; N_EMBD];
        linear_fwd(&ao, &layers[li].wo, N_EMBD, N_EMBD, &mut tmp);
        // Residual connection around the attention output projection.
        for i in 0..N_EMBD { x[i] = tmp[i] + act.x_in[li][i]; }
        act.x_mid[li] = x;
        // --- MLP sub-block: x = x + fc2 * act(fc1 * RMSNorm(x)) ---
        let mut xn_m = [0.0f32; N_EMBD];
        let x_mid_copy = x; // Fix for borrow checker
        act.rms_scale_mlp[li] = rmsnorm_fwd(&x_mid_copy, N_EMBD, &mut xn_m);
        act.xn_mlp[li] = xn_m;
        let mut h1 = [0.0f32; MLP_DIM];
        linear_fwd(&xn_m, &layers[li].fc1, MLP_DIM, N_EMBD, &mut h1);
        act.mlp_pre[li] = h1;
        let mut h2 = [0.0f32; MLP_DIM];
        // Squared-ReLU nonlinearity: max(0, h)^2.
        for i in 0..MLP_DIM { h2[i] = if h1[i] > 0.0 { h1[i] * h1[i] } else { 0.0 }; }
        act.mlp_post[li] = h2;
        let mut f2_out = [0.0f32; N_EMBD];
        linear_fwd(&h2, &layers[li].fc2, N_EMBD, MLP_DIM, &mut f2_out);
        // Residual connection around the MLP.
        for i in 0..N_EMBD { x[i] = f2_out[i] + act.x_mid[li][i]; }
    }
    act.x_out = x;
    // Final projection to vocabulary logits (no final norm in this model).
    linear_fwd(&x, lm_head, vocab_size, N_EMBD, &mut logits[..vocab_size]);
    logits
}
| fn generate_top_p( | |
| wte: &[f32], wpe: &[f32], layers: &[LayerWeights], lm: &[f32], | |
| vocab_size: usize, bos_id: usize, uchars: &[char], rng: &mut Rng, temp: f32, top_p: f32 | |
| ) { | |
| let mut token_id = bos_id; | |
| let mut res = String::new(); | |
| let mut kv_k = [[[0.0f32; N_EMBD]; BLOCK_SIZE]; N_LAYER]; | |
| let mut kv_v = [[[0.0f32; N_EMBD]; BLOCK_SIZE]; N_LAYER]; | |
| let mut act = PosActs::new(); | |
| for pos in 0..BLOCK_SIZE { | |
| let logits = forward(token_id, pos, wte, wpe, layers, lm, vocab_size, &mut act, &mut kv_k, &mut kv_v); | |
| let mut probs = [0.0f32; MAX_CHARS + 1]; | |
| softmax_fwd(&logits, vocab_size, &mut probs[..vocab_size], temp); | |
| // Top-p sampling | |
| let mut sorted_indices: Vec<usize> = (0..vocab_size).collect(); | |
| sorted_indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap()); | |
| let mut cumulative_prob = 0.0; | |
| let mut last_idx = vocab_size - 1; | |
| for (i, &idx) in sorted_indices.iter().enumerate() { | |
| cumulative_prob += probs[idx]; | |
| if cumulative_prob > top_p { | |
| last_idx = i; | |
| break; | |
| } | |
| } | |
| let mut new_sum = 0.0; | |
| for i in 0..=last_idx { new_sum += probs[sorted_indices[i]]; } | |
| let mut r = rng.uniform() as f32 * new_sum; | |
| let mut choice = sorted_indices[last_idx]; | |
| for i in 0..=last_idx { | |
| r -= probs[sorted_indices[i]]; | |
| if r <= 0.0 { choice = sorted_indices[i]; break; } | |
| } | |
| token_id = choice; | |
| if token_id == bos_id { break; } | |
| res.push(uchars[token_id]); | |
| } | |
| println!("Top-P Generate: {}", res); | |
| } | |
| fn main() { | |
| let mut rng = Rng::new(42); | |
| let mut docs = vec!["hello world".to_string(), "rust playground gpt".to_string()]; | |
| // 1. Vocab | |
| let mut uchars: Vec<char> = Vec::new(); | |
| let mut seen = [false; 256]; | |
| for d in &docs { for b in d.as_bytes() { seen[*b as usize] = true; } } | |
| for i in 0..256 { if seen[i] { uchars.push(i as u8 as char); } } | |
| uchars.sort(); | |
| let vocab_size = uchars.len() + 1; | |
| let bos_id = uchars.len(); | |
| // 2. Weights | |
| let make_p = |sz, std, r: &mut Rng| (0..sz).map(|_| r.gauss(0.0, std)).collect::<Vec<f32>>(); | |
| let mut wte = make_p(vocab_size * N_EMBD, 0.02, &mut rng); | |
| let mut wpe = make_p(BLOCK_SIZE * N_EMBD, 0.02, &mut rng); | |
| let mut lm = make_p(vocab_size * N_EMBD, 0.02, &mut rng); | |
| let mut layers: Vec<LayerWeights> = (0..N_LAYER).map(|_| LayerWeights { | |
| wq: make_p(N_EMBD * N_EMBD, 0.02, &mut rng), wk: make_p(N_EMBD * N_EMBD, 0.02, &mut rng), | |
| wv: make_p(N_EMBD * N_EMBD, 0.02, &mut rng), wo: make_p(N_EMBD * N_EMBD, 0.0, &mut rng), | |
| fc1: make_p(MLP_DIM * N_EMBD, 0.02, &mut rng), fc2: make_p(MLP_DIM * N_EMBD, 0.0, &mut rng), | |
| d_wq: vec![0.0; N_EMBD * N_EMBD], d_wk: vec![0.0; N_EMBD * N_EMBD], d_wv: vec![0.0; N_EMBD * N_EMBD], d_wo: vec![0.0; N_EMBD * N_EMBD], | |
| d_fc1: vec![0.0; MLP_DIM * N_EMBD], d_fc2: vec![0.0; MLP_DIM * N_EMBD], | |
| m_wq: vec![0.0; N_EMBD * N_EMBD], v_wq: vec![0.0; N_EMBD * N_EMBD], m_wk: vec![0.0; N_EMBD * N_EMBD], v_wk: vec![0.0; N_EMBD * N_EMBD], | |
| m_wv: vec![0.0; N_EMBD * N_EMBD], v_wv: vec![0.0; N_EMBD * N_EMBD], m_wo: vec![0.0; N_EMBD * N_EMBD], v_wo: vec![0.0; N_EMBD * N_EMBD], | |
| m_fc1: vec![0.0; MLP_DIM * N_EMBD], v_fc1: vec![0.0; MLP_DIM * N_EMBD], m_fc2: vec![0.0; MLP_DIM * N_EMBD], v_fc2: vec![0.0; MLP_DIM * N_EMBD], | |
| }).collect(); | |
| println!("GPT Initialized. Starting Playground demo..."); | |
| generate_top_p(&wte, &wpe, &layers, &lm, vocab_size, bos_id, &uchars, &mut rng, 0.8, 0.9); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment