Skip to content

Instantly share code, notes, and snippets.

@rust-play
Created February 19, 2026 23:38
Show Gist options
  • Select an option

  • Save rust-play/1a92b767db126e85d58e7948ed83b11d to your computer and use it in GitHub Desktop.

Select an option

Save rust-play/1a92b767db126e85d58e7948ed83b11d to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
use nalgebra::{DMatrix, DVector};
// Import the Rng trait so that .gen() is available on thread_rng()
use rand_0_8_5::Rng;
/// 1. Computes standard Softmax Attention Y
///
/// Builds the score matrix `q * kᵀ`, turns every row into a probability
/// distribution with a softmax, and returns those weights multiplied by `v`.
///
/// Note that the scores are used as-is: no `1 / sqrt(d)` scaling is applied
/// before the softmax.
///
/// # Panics
///
/// Panics (inside nalgebra) if the shapes of `q`, `k` and `v` are not
/// compatible for the two matrix products.
fn compute_softmax_attention(q: &DMatrix<f64>, k: &DMatrix<f64>, v: &DMatrix<f64>) -> DMatrix<f64> {
    let mut weights = q * k.transpose();

    for mut logits in weights.row_iter_mut() {
        // Shift by the row maximum so the largest exponent is exactly 0.
        let peak = logits.max();
        for w in logits.iter_mut() {
            *w = (*w - peak).exp();
        }

        // Normalize the row so that it sums to one.
        let total: f64 = logits.sum();
        for w in logits.iter_mut() {
            *w /= total;
        }
    }

    weights * v
}
/// 2. Generates linearized feature matrix X
///
/// Entry `(i, j)` of the result is the softmax, taken over the feature index
/// `j`, of the logit `q_i · ck_j + beta[j]`. For finite logits every row is
/// therefore non-negative and sums to 1.
///
/// The row maximum is subtracted from the logits before exponentiating. The
/// softmax is invariant under such a shift, so the result is mathematically
/// unchanged, but `exp` can no longer overflow to `inf` (which would turn the
/// whole row into NaN after normalization) when the logits are large.
///
/// # Panics
///
/// Panics (inside nalgebra) if `q` and `ck` do not have the same number of
/// columns, or if `beta` has fewer entries than `ck` has rows while `q` has at
/// least one row.
fn compute_feature_matrix_x(q: &DMatrix<f64>, ck: &DMatrix<f64>, beta: &DVector<f64>) -> DMatrix<f64> {
    let mut x = q * ck.transpose();

    for i in 0..x.nrows() {
        let mut row = x.row_mut(i);

        // Add the per-feature bias to obtain the raw logits.
        for j in 0..row.ncols() {
            row[j] += beta[j];
        }

        // Numerically stable softmax: shift so the largest exponent is 0.
        let max_logit = row.max();
        row.iter_mut().for_each(|val| *val = (*val - max_logit).exp());

        let sum: f64 = row.sum();
        row.iter_mut().for_each(|val| *val /= sum);
    }

    x
}
/// Demo driver: draws random Q, K, V, feature centers and biases, computes the
/// softmax attention output Y and the feature matrix X, solves the
/// least-squares problem `X * Cv ≈ Y` through an SVD, and prints the residual
/// norm together with the top-left 3x3 corner of the solution.
fn main() {
    // Dimensions: number of rows of Q/K/V, their width, and the feature count.
    let seq_len = 10;
    let model_dim = 4;
    let num_features = 6;

    let mut rng = rand_0_8_5::thread_rng();

    // --- INITIALIZE MATRICES ---
    // `gen` is a reserved keyword in newer editions, so the method has to be
    // called through its raw identifier `r#gen`.
    let mut random_matrix =
        |rows: usize, cols: usize| DMatrix::from_fn(rows, cols, |_, _| rng.r#gen::<f64>());
    let q = random_matrix(seq_len, model_dim);
    let k = random_matrix(seq_len, model_dim);
    let v = random_matrix(seq_len, model_dim);
    let ck = random_matrix(num_features, model_dim);
    let beta = DVector::from_fn(num_features, |_, _| rng.r#gen::<f64>());

    println!("--- Step 1: Computing Standard Attention (Y) ---");
    let y = compute_softmax_attention(&q, &k, &v);

    println!("--- Step 2: Computing Linearized Basis (X) ---");
    let x = compute_feature_matrix_x(&q, &ck, &beta);

    println!("--- Step 3: Solving OLS Solution for Cv* ---");
    // `svd` consumes its receiver, and `x` is still needed for the
    // verification step below, hence the clone.
    let decomposition = x.clone().svd(true, true);
    let cv_star = decomposition
        .solve(&y, 1e-10)
        .expect("SVD Solve failed");

    // --- VERIFICATION ---
    let reconstruction = &x * &cv_star;
    let residual_norm = (&y - &reconstruction).norm();

    println!("\n--- RESULTS ---");
    println!("Approximation Error: {:.10}", residual_norm);
    println!("Cv* (Top-left 3x3):\n{}", cv_star.fixed_view::<3, 3>(0, 0));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment