-
-
Save rust-play/1a92b767db126e85d58e7948ed83b11d to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use nalgebra::{DMatrix, DVector}; | |
| // Import the Rng trait so that .gen() is available on thread_rng() | |
| use rand_0_8_5::Rng; | |
| /// 1. Computes standard Softmax Attention Y | |
| fn compute_softmax_attention(q: &DMatrix<f64>, k: &DMatrix<f64>, v: &DMatrix<f64>) -> DMatrix<f64> { | |
| let mut scores = q * k.transpose(); | |
| for i in 0..scores.nrows() { | |
| let mut row = scores.row_mut(i); | |
| let max_val = row.max(); | |
| row.iter_mut().for_each(|x| *x = (*x - max_val).exp()); | |
| let sum: f64 = row.sum(); | |
| row.iter_mut().for_each(|x| *x /= sum); | |
| } | |
| scores * v | |
| } | |
| /// 2. Generates linearized feature matrix X | |
| fn compute_feature_matrix_x(q: &DMatrix<f64>, ck: &DMatrix<f64>, beta: &DVector<f64>) -> DMatrix<f64> { | |
| let mut x = q * ck.transpose(); | |
| for i in 0..x.nrows() { | |
| let mut row = x.row_mut(i); | |
| for j in 0..row.ncols() { | |
| row[j] = (row[j] + beta[j]).exp(); | |
| } | |
| let sum: f64 = row.sum(); | |
| row.iter_mut().for_each(|val| *val /= sum); | |
| } | |
| x | |
| } | |
| fn main() { | |
| let mut rng = rand_0_8_5::thread_rng(); | |
| // Dimensions | |
| let t = 10; | |
| let d = 4; | |
| let k_feat = 6; | |
| // --- INITIALIZE MATRICES --- | |
| // Fix: Use r#gen (Raw Identifier) to call the .gen() method | |
| // without conflicting with the reserved 'gen' keyword. | |
| let q = DMatrix::from_fn(t, d, |_, _| rng.r#gen::<f64>()); | |
| let k = DMatrix::from_fn(t, d, |_, _| rng.r#gen::<f64>()); | |
| let v = DMatrix::from_fn(t, d, |_, _| rng.r#gen::<f64>()); | |
| let ck = DMatrix::from_fn(k_feat, d, |_, _| rng.r#gen::<f64>()); | |
| let beta = DVector::from_fn(k_feat, |_, _| rng.r#gen::<f64>()); | |
| println!("--- Step 1: Computing Standard Attention (Y) ---"); | |
| let y = compute_softmax_attention(&q, &k, &v); | |
| println!("--- Step 2: Computing Linearized Basis (X) ---"); | |
| let x = compute_feature_matrix_x(&q, &ck, &beta); | |
| println!("--- Step 3: Solving OLS Solution for Cv* ---"); | |
| let svd = x.clone().svd(true, true); | |
| let cv_star = svd.solve(&y, 1e-10).expect("SVD Solve failed"); | |
| // --- VERIFICATION --- | |
| let y_approx = &x * &cv_star; | |
| let error = (&y - &y_approx).norm(); | |
| println!("\n--- RESULTS ---"); | |
| println!("Approximation Error: {:.10}", error); | |
| println!("Cv* (Top-left 3x3):\n{}", cv_star.fixed_view::<3, 3>(0, 0)); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment