-
-
Save peterwillcn/9ddb219f8e64aba560372094e176bb0f to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fs::File; | |
use std::io::prelude::*; | |
use std::mem; | |
/// Layer normalization forward pass over the channel axis.
///
/// `input` and `output` are treated as `batch_size * time_steps` rows of
/// `channels` contiguous values; row `b * time_steps + t` starts at
/// `(b * time_steps + t) * channels`. For each row this:
/// - computes the row mean and the population variance (divided by `channels`),
/// - computes `rstd = 1 / sqrt(var + 1e-5)`,
/// - writes `output[i] = rstd * (x[i] - mean) * weight[i] + bias[i]`,
/// - stores the row's mean and rstd in `mean[row]` and `rstd[row]`.
///
/// # Panics
/// If any slice is shorter than the sizes implied by the dimensions.
fn layernorm_forward(
    output: &mut [f32],
    mean: &mut [f32],
    rstd: &mut [f32],
    input: &[f32],
    weight: &[f32],
    bias: &[f32],
    batch_size: usize,
    time_steps: usize,
    channels: usize,
) {
    const EPSILON: f32 = 1e-5;
    let width = channels as f32;
    // Rows are visited in the same (b, t) order as a nested b/t loop.
    for row in 0..batch_size * time_steps {
        let start = row * channels;
        let xs = &input[start..][..channels];
        let mu = xs.iter().sum::<f32>() / width;
        let var = xs.iter().map(|&v| (v - mu).powi(2)).sum::<f32>() / width;
        let inv_std = 1.0 / (var + EPSILON).sqrt();
        let ys = &mut output[start..][..channels];
        for (i, (y, &v)) in ys.iter_mut().zip(xs).enumerate() {
            *y = inv_std * (v - mu) * weight[i] + bias[i];
        }
        mean[row] = mu;
        rstd[row] = inv_std;
    }
}
/// Layer normalization backward pass over the channel axis.
///
/// Works on `batch_size * time_steps` rows of `channels` contiguous values,
/// using the per-row `mean[row]` / `rstd[row]` (as produced by
/// `layernorm_forward`), where `row = b * time_steps + t`.
///
/// With `norm[i] = (input[i] - mean) * rstd` and `dnorm[i] = weight[i] * d_output[i]`,
/// every gradient is *accumulated* (`+=`), so callers must zero the outputs
/// first if they want fresh gradients:
/// - `d_bias[i]   += d_output[i]`
/// - `d_weight[i] += norm[i] * d_output[i]`
/// - `d_input[i]  += (dnorm[i] - avg(dnorm) - norm[i] * avg(dnorm * norm)) * rstd`
///
/// where `avg` is the mean over the row's `channels` values.
///
/// # Panics
/// If any slice is shorter than the sizes implied by the dimensions.
fn layernorm_backward(
    d_input: &mut [f32],
    d_weight: &mut [f32],
    d_bias: &mut [f32],
    d_output: &[f32],
    input: &[f32],
    weight: &[f32],
    mean: &[f32],
    rstd: &[f32],
    batch_size: usize,
    time_steps: usize,
    channels: usize,
) {
    let rows = (0..batch_size).flat_map(|b| (0..time_steps).map(move |t| (b, t)));
    for (b, t) in rows {
        let start = b * time_steps * channels + t * channels;
        let dy = &d_output[start..][..channels];
        let xs = &input[start..][..channels];
        let dx = &mut d_input[start..][..channels];
        let row = b * time_steps + t;
        let mu = mean[row];
        let inv_std = rstd[row];
        let normalized = |i: usize| (xs[i] - mu) * inv_std;

        // First sweep: row sums of dnorm and dnorm * norm, accumulated from 0.0.
        let (sum_dnorm, sum_dnorm_norm) =
            (0..channels).fold((0.0f32, 0.0f32), |(acc, acc_norm), i| {
                let g = weight[i] * dy[i];
                (acc + g, acc_norm + g * normalized(i))
            });
        let dnorm_avg = sum_dnorm / channels as f32;
        let dnorm_norm_avg = sum_dnorm_norm / channels as f32;

        // Second sweep: accumulate parameter and input gradients.
        for i in 0..channels {
            let n = normalized(i);
            let g = weight[i] * dy[i];
            d_bias[i] += dy[i];
            d_weight[i] += n * dy[i];
            dx[i] += (g - dnorm_avg - n * dnorm_norm_avg) * inv_std;
        }
    }
}
/// Compares two tensors element-wise and prints a per-element report.
///
/// Prints `label`, then one line per compared element: `OK ` or `NOT OK `
/// followed by the two values (`a` then `b`). Two elements match when their
/// absolute difference is at most `1e-5`; a NaN on either side never matches.
///
/// Returns `true` only if both slices have the same length and every element
/// matches. A length mismatch is printed and counts as a failure, because
/// `zip` alone would silently skip the extra elements.
fn check_tensor(a: &[f32], b: &[f32], label: &str) -> bool {
    const TOLERANCE: f32 = 1e-5;
    println!("{}", label);
    let mut is_ok = a.len() == b.len();
    if !is_ok {
        println!("NOT OK length mismatch: {} vs {}", a.len(), b.len());
    }
    for (&ai, &bi) in a.iter().zip(b) {
        if (ai - bi).abs() <= TOLERANCE {
            print!("OK ");
        } else {
            print!("NOT OK ");
            is_ok = false;
        }
        println!("{} {}", ai, bi);
    }
    is_ok
}
fn main() { | |
let batch_size = 2; | |
let time_steps = 3; | |
let channels = 4; | |
let mut x = vec![0.0; batch_size * time_steps * channels]; | |
let mut w = vec![0.0; channels]; | |
let mut b = vec![0.0; channels]; | |
let mut out = vec![0.0; batch_size * time_steps * channels]; | |
let mut mean = vec![0.0; batch_size * time_steps]; | |
let mut rstd = vec![0.0; batch_size * time_steps]; | |
let mut dout = vec![0.0; batch_size * time_steps * channels]; | |
let mut dx = vec![0.0; batch_size * time_steps * channels]; | |
let mut dw = vec![0.0; channels]; | |
let mut db = vec![0.0; channels]; | |
let mut file = File::open("ln.bin").expect("Error opening file"); | |
let mut data = vec![0.0; batch_size * time_steps * channels + 2 * batch_size * time_steps + 2 * channels]; | |
file.read_exact(unsafe { mem::transmute(&mut data[..]) }).unwrap(); | |
let mut offset = 0; | |
for tensor in [&mut x, &mut w, &mut b, &mut out, &mut mean, &mut rstd, &mut dout, &mut dx, &mut dw, &mut db].iter_mut() { | |
let size = tensor.len() * std::mem::size_of::<f32>(); | |
let slice = &data[offset..offset + size]; | |
unsafe { | |
std::ptr::copy_nonoverlapping(slice.as_ptr() as *const f32, tensor.as_mut_ptr(), size / std::mem::size_of::<f32>()); | |
} | |
offset += size; | |
} | |
let mut c_out = vec![0.0; batch_size * time_steps * channels]; | |
let mut c_mean = vec![0.0; batch_size * time_steps]; | |
let mut c_rstd = vec![0.0; batch_size * time_steps]; | |
layernorm_forward(&mut c_out, &mut c_mean, &mut c_rstd, &x, &w, &b, batch_size, time_steps, channels); | |
check_tensor(&out, &c_out, "out"); | |
check_tensor(&mean, &c_mean, "mean"); | |
check_tensor(&rstd, &c_rstd, "rstd"); | |
let mut c_dx = vec![0.0; batch_size * time_steps * channels]; | |
let mut c_dw = vec![0.0; channels]; | |
let mut c_db = vec![0.0; channels]; | |
layernorm_backward(&mut c_dx, &mut c_dw, &mut c_db, &dout, &x, &w, &c_mean, &c_rstd, batch_size, time_steps, channels); | |
check_tensor(&c_dx, &dx, "dx"); | |
check_tensor(&c_dw, &dw, "dw"); | |
check_tensor(&c_db, &db, "db"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment