layernorm.rs
@peterwillcn, forked from ZhangHanDong/layernorm.rs (created April 17, 2024)
Code shared from the Rust Playground
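This is a Rust port of a LayerNorm forward and backward pass, checked element-wise against reference tensors read from ln.bin; the function signatures and tensor order mirror the layernorm reference in Andrej Karpathy's llm.c, which this gist appears to be adapted from. layernorm_forward normalizes each (batch, time) row over the channel dimension and caches the per-row mean and reciprocal standard deviation; layernorm_backward reuses those caches to accumulate gradients for the input, weight, and bias.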
use std::fs::File;
use std::io::prelude::*;
fn layernorm_forward(output: &mut [f32], mean: &mut [f32], rstd: &mut [f32],
                     input: &[f32], weight: &[f32], bias: &[f32],
                     batch_size: usize, time_steps: usize, channels: usize) {
    let epsilon = 1e-5;
    for b in 0..batch_size {
        for t in 0..time_steps {
            // Input row for this (batch, time) position: `channels` floats.
            let x = &input[b * time_steps * channels + t * channels..][..channels];
            // Mean and (biased) variance over the channel dimension.
            let m: f32 = x.iter().sum::<f32>() / channels as f32;
            let v: f32 = x.iter().map(|&xi| (xi - m).powi(2)).sum::<f32>() / channels as f32;
            // Reciprocal standard deviation, with epsilon for numerical stability.
            let s: f32 = 1.0 / (v + epsilon).sqrt();
            let output_bt = &mut output[b * time_steps * channels + t * channels..][..channels];
            for i in 0..channels {
                let n = s * (x[i] - m);                 // normalize
                output_bt[i] = n * weight[i] + bias[i]; // scale and shift
            }
            // Cache the per-row statistics for the backward pass.
            mean[b * time_steps + t] = m;
            rstd[b * time_steps + t] = s;
        }
    }
}
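For reference, with C = channels and eps = 1e-5 as above, each (b, t) row is transformed as

    \mu = \frac{1}{C}\sum_i x_i, \qquad \sigma^2 = \frac{1}{C}\sum_i (x_i - \mu)^2, \qquad y_i = w_i \cdot \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + b_i

and the per-row \mu and \mathrm{rstd} = 1/\sqrt{\sigma^2 + \epsilon} are cached so the backward pass does not have to recompute them.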
fn layernorm_backward(d_input: &mut [f32], d_weight: &mut [f32], d_bias: &mut [f32],
                      d_output: &[f32], input: &[f32], weight: &[f32], mean: &[f32], rstd: &[f32],
                      batch_size: usize, time_steps: usize, channels: usize) {
    for b in 0..batch_size {
        for t in 0..time_steps {
            let d_output_bt = &d_output[b * time_steps * channels + t * channels..][..channels];
            let input_bt = &input[b * time_steps * channels + t * channels..][..channels];
            let d_input_bt = &mut d_input[b * time_steps * channels + t * channels..][..channels];
            let mean_bt = mean[b * time_steps + t];
            let rstd_bt = rstd[b * time_steps + t];

            // First pass: reduce the two means needed by the input gradient.
            let mut dnorm_mean = 0.0;
            let mut dnorm_norm_mean = 0.0;
            for i in 0..channels {
                let norm_bti = (input_bt[i] - mean_bt) * rstd_bt;
                let dnorm_i = weight[i] * d_output_bt[i];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean /= channels as f32;
            dnorm_norm_mean /= channels as f32;

            // Second pass: accumulate all three gradients.
            for i in 0..channels {
                let norm_bti = (input_bt[i] - mean_bt) * rstd_bt;
                let dnorm_i = weight[i] * d_output_bt[i];
                d_bias[i] += d_output_bt[i];              // gradient w.r.t. bias
                d_weight[i] += norm_bti * d_output_bt[i]; // gradient w.r.t. weight
                // gradient w.r.t. input
                let mut dval = dnorm_i - dnorm_mean - norm_bti * dnorm_norm_mean;
                dval *= rstd_bt;
                d_input_bt[i] += dval;
            }
        }
    }
}
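Writing n_i = (x_i - \mu) \cdot \mathrm{rstd} for the normalized value and g_i = w_i \cdot dy_i (dnorm_i in the code), the two reductions in the first inner loop are exactly the means that appear in the input gradient

    \frac{\partial L}{\partial x_i} = \mathrm{rstd} \cdot \left( g_i - \frac{1}{C}\sum_j g_j - n_i \cdot \frac{1}{C}\sum_j g_j n_j \right)

while d_weight accumulates n_i \cdot dy_i and d_bias accumulates dy_i across all (b, t) positions.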
fn check_tensor(a: &[f32], b: &[f32], label: &str) -> bool {
    let mut is_ok = true;
    println!("{}", label);
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        if (ai - bi).abs() <= 1e-5 {
            print!("OK ");
        } else {
            print!("NOT OK ");
            is_ok = false;
        }
        println!("{} {}", ai, bi);
    }
    is_ok
}
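check_tensor compares element by element with an absolute tolerance of 1e-5 and prints both values next to an OK / NOT OK marker. Note that main below ignores the returned bool, so a mismatch is only visible in the printed output, not in the exit status.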
fn main() {
    let batch_size = 2;
    let time_steps = 3;
    let channels = 4;

    // Reference tensors, filled from ln.bin below.
    let mut x = vec![0.0f32; batch_size * time_steps * channels];
    let mut w = vec![0.0f32; channels];
    let mut b = vec![0.0f32; channels];
    let mut out = vec![0.0f32; batch_size * time_steps * channels];
    let mut mean = vec![0.0f32; batch_size * time_steps];
    let mut rstd = vec![0.0f32; batch_size * time_steps];
    let mut dout = vec![0.0f32; batch_size * time_steps * channels];
    let mut dx = vec![0.0f32; batch_size * time_steps * channels];
    let mut dw = vec![0.0f32; channels];
    let mut db = vec![0.0f32; channels];

    // ln.bin stores the tensors above as raw little-endian f32s, concatenated
    // in the order of the list below.
    let mut file = File::open("ln.bin").expect("Error opening file");
    let mut bytes = Vec::new();
    file.read_to_end(&mut bytes).expect("Error reading file");

    let mut offset = 0;
    for tensor in [&mut x, &mut w, &mut b, &mut out, &mut mean, &mut rstd,
                   &mut dout, &mut dx, &mut dw, &mut db] {
        for value in tensor.iter_mut() {
            let chunk: [u8; 4] = bytes[offset..offset + 4].try_into().unwrap();
            *value = f32::from_le_bytes(chunk);
            offset += 4;
        }
    }

    // Run the forward pass and compare against the reference activations.
    let mut c_out = vec![0.0f32; batch_size * time_steps * channels];
    let mut c_mean = vec![0.0f32; batch_size * time_steps];
    let mut c_rstd = vec![0.0f32; batch_size * time_steps];
    layernorm_forward(&mut c_out, &mut c_mean, &mut c_rstd, &x, &w, &b,
                      batch_size, time_steps, channels);
    check_tensor(&out, &c_out, "out");
    check_tensor(&mean, &c_mean, "mean");
    check_tensor(&rstd, &c_rstd, "rstd");

    // Run the backward pass and compare against the reference gradients.
    let mut c_dx = vec![0.0f32; batch_size * time_steps * channels];
    let mut c_dw = vec![0.0f32; channels];
    let mut c_db = vec![0.0f32; channels];
    layernorm_backward(&mut c_dx, &mut c_dw, &mut c_db, &dout, &x, &w, &c_mean, &c_rstd,
                       batch_size, time_steps, channels);
    check_tensor(&c_dx, &dx, "dx");
    check_tensor(&c_dw, &dw, "dw");
    check_tensor(&c_db, &db, "db");
}
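main expects ln.bin next to the binary: ten tensors of raw little-endian f32 values (x, w, b, out, mean, rstd, dout, dx, dw, db) written back to back, presumably dumped by the reference implementation the gist was checked against. If you just want to exercise the plumbing without that file, a sketch like the one below can regenerate a compatible file from the same Rust code (write_ln_bin is a hypothetical helper, not part of the original gist; it relies on the std::io::prelude import at the top for Write). Because the "reference" values then come from the same functions, every check prints OK by construction, so this validates the file layout and I/O, not the math.

// Hypothetical helper: writes an ln.bin that main() above can read back.
fn write_ln_bin(path: &str) -> std::io::Result<()> {
    let (batch_size, time_steps, channels) = (2usize, 3, 4);
    let n = batch_size * time_steps * channels;

    // Deterministic pseudo-data so the file is reproducible.
    let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.37).sin()).collect();
    let w: Vec<f32> = (0..channels).map(|i| 1.0 + 0.1 * i as f32).collect();
    let b: Vec<f32> = (0..channels).map(|i| 0.05 * i as f32).collect();
    let dout: Vec<f32> = (0..n).map(|i| (i as f32 * 0.11).cos()).collect();

    // Produce the "reference" activations and gradients with the same code.
    let mut out = vec![0.0f32; n];
    let mut mean = vec![0.0f32; batch_size * time_steps];
    let mut rstd = vec![0.0f32; batch_size * time_steps];
    layernorm_forward(&mut out, &mut mean, &mut rstd, &x, &w, &b,
                      batch_size, time_steps, channels);

    let mut dx = vec![0.0f32; n];
    let mut dw = vec![0.0f32; channels];
    let mut db = vec![0.0f32; channels];
    layernorm_backward(&mut dx, &mut dw, &mut db, &dout, &x, &w, &mean, &rstd,
                       batch_size, time_steps, channels);

    // Write all ten tensors in the exact order main() reads them.
    let mut file = File::create(path)?;
    for tensor in [&x, &w, &b, &out, &mean, &rstd, &dout, &dx, &dw, &db] {
        for &v in tensor.iter() {
            file.write_all(&v.to_le_bytes())?;
        }
    }
    Ok(())
}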