Last active
August 4, 2020 16:06
-
-
Save inferrna/b1b49b9d0e161a5104670ecb7870e19a to your computer and use it in GitHub Desktop.
Try to build multilayer perceptron based on https://github.com/raskr/rust-autograd
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate autograd as ag; | |
extern crate ndarray; | |
use ag::{ndarray_ext as arr, NdArray}; | |
use ag::optimizers::adam; | |
use ag::rand::seq::SliceRandom; | |
use ag::tensor::Variable; | |
use ag::ndarray::s; | |
use rand::RngCore; | |
use std::ops::{Mul, Add, Div}; | |
const TAU: f32 = 6.28318530717958647692528676655900577_f32; | |
fn get_permutation(size: usize) -> Vec<usize> { | |
let mut perm: Vec<usize> = (0..size).collect(); | |
perm.shuffle(&mut rand::thread_rng()); | |
perm | |
} | |
fn main() { | |
let rng = ag::ndarray_ext::ArrayRng::<f32>::default(); | |
let w_arr = arr::into_shared(rng.glorot_uniform(&[2, 1])); | |
//Additional layer(s) | |
let w2_arr = arr::into_shared(rng.glorot_uniform(&[1, 32])); | |
let w3_arr = arr::into_shared(rng.glorot_uniform(&[32, 1])); | |
let b_arr = arr::into_shared(arr::zeros(&[1, 1])); | |
let adam_state = adam::AdamState::new(&[&w_arr, &w2_arr, &w3_arr, &b_arr]); | |
let num_samples = 10000; | |
let max_epoch = 3000; | |
//All data | |
let (mut x_values, mut y_values): (Vec<Vec<f32>>, Vec<f32>) = (0..num_samples) | |
.map(|i| (i as f32 / num_samples as f32)) | |
.map(|i| | |
(vec![i.mul(TAU).sin().add(1.0).div(2.0), | |
i.mul(TAU).cos().add(1.0).div(2.0)], | |
i)) | |
.unzip(); | |
//Test data | |
let (x_test, y_test) = { | |
let mut rnd = rand::thread_rng(); | |
let mut x_test: Vec<Vec<f32>> = vec![]; | |
let mut y_test: Vec<f32> = vec![]; | |
for _ in 0..num_samples/100 { | |
let idx = rnd.next_u64() as usize % x_values.len(); | |
x_test.push(x_values.remove(idx)); | |
y_test.push(y_values.remove(idx)); | |
} | |
(x_test, y_test) | |
}; | |
let train_samples = y_values.len(); | |
let test_samples = y_test.len(); | |
let x_train: Vec<f32> = x_values.into_iter().flatten().collect(); | |
let x_test: Vec<f32> = x_test.into_iter().flatten().collect(); | |
//Check data correctness | |
let sanity_test_idx = 69; | |
assert_eq!(y_values[sanity_test_idx].mul(TAU).sin().add(1.0).div(2.0), x_train[sanity_test_idx*2]); | |
assert_eq!(y_values[sanity_test_idx].mul(TAU).cos().add(1.0).div(2.0), x_train[sanity_test_idx*2+1]); | |
assert_eq!(y_test[sanity_test_idx].mul(TAU).sin().add(1.0).div(2.0), x_test[sanity_test_idx*2]); | |
assert_eq!(y_test[sanity_test_idx].mul(TAU).cos().add(1.0).div(2.0), x_test[sanity_test_idx*2+1]); | |
//Convert data to ndarray | |
let as_arr = NdArray::from_shape_vec; | |
let y_train = as_arr(ag::ndarray::IxDyn(&[train_samples, 1]), y_values).unwrap(); | |
let x_train = as_arr(ag::ndarray::IxDyn(&[train_samples, 2]), x_train).unwrap(); | |
let y_test = as_arr(ag::ndarray::IxDyn(&[test_samples, 1]), y_test).unwrap(); | |
let x_test = as_arr(ag::ndarray::IxDyn(&[test_samples, 2]), x_test).unwrap(); | |
//Train | |
for epoch in 0..max_epoch { | |
ag::with(|g| { | |
let w = g.variable(w_arr.clone()); | |
let w2 = g.variable(w2_arr.clone()); | |
let w3 = g.variable(w3_arr.clone()); | |
let b = g.variable(b_arr.clone()); | |
let x = g.placeholder(&[-1,2]); | |
let y = g.placeholder(&[-1,1]); | |
let xw = g.matmul(x, w); | |
let xww2 = g.matmul(xw, w2); | |
let z = g.sigmoid(g.matmul(xww2, w3) + b); | |
let mean_loss = g.reduce_mean(g.square(g.sub(z, &y)), &[0,1], false); | |
if epoch % 1000 == 0 || (epoch < 1000 && epoch % 10 == 0) { | |
let acc = mean_loss.eval(&[x.given(x_train.view()), y.given(y_train.view())]).unwrap(); | |
println!( | |
"Epoch {}, train error: {:?}", | |
epoch, acc.view() | |
); | |
} | |
let grads = &g.grad(&[&mean_loss], &[w, w2, w3, b]); | |
let update_ops: &[ag::Tensor<f32>] = | |
&adam::Adam::default().compute_updates(&[w, w2, w3, b], grads, &adam_state, g); | |
let batch_size = 50isize; | |
let num_batches = train_samples / batch_size as usize; | |
for i in get_permutation(num_batches) { | |
let i = i as isize * batch_size; | |
let y_batch = y_train.slice(s![i..i + batch_size, ..]).into_dyn(); | |
let x_batch = x_train.slice(s![i..i + batch_size, ..]).into_dyn(); | |
g.eval(update_ops, &[x.given(x_batch), y.given(y_batch)]); | |
} | |
}); | |
} | |
//Test | |
ag::with(|g| { | |
let w = g.variable(w_arr.clone()); | |
let w2 = g.variable(w2_arr.clone()); | |
let w3 = g.variable(w3_arr.clone()); | |
let b = g.variable(b_arr.clone()); | |
let x = g.placeholder(&[-1,2]); | |
let y = g.placeholder(&[-1,1]); | |
// -- test -- | |
let xw = g.matmul(x, w); | |
let xww2 = g.matmul(xw, w2); | |
let z = g.sigmoid(g.matmul(xww2, w3) + b); | |
let predictions = z; | |
let accuracy = g.reduce_mean(g.square(g.sub(predictions, &y)), &[0,1], false); | |
let acc = accuracy.eval(&[x.given(x_test.view()), y.given(y_test.view())]).unwrap(); | |
let values = z.eval(&[x.given(x_test.view()), y.given(y_test.view())]).unwrap(); | |
println!( | |
"test error: {:?}, result values = \n{:?}\noriginal values = \n{:?}", | |
acc.view(), values.view(), y_test.view() | |
); | |
}) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment