Created
August 12, 2022 16:27
-
-
Save Steboss89/4a66a2b7fa46955883cfb69a48a1493a to your computer and use it in GitHub Desktop.
Run N_EPOCHS training epochs of a linear neural network using the Rust tch crate
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::result::Result; | |
use std::error::Error; | |
use mnist::*; | |
use tch::{kind, no_grad, Kind, Tensor}; | |
use ndarray::{Array3, Array2}; | |
const LABELS: i64 = 10; // number of distinct labels | |
const HEIGHT: usize = 28; | |
const WIDTH: usize = 28; | |
const TRAIN_SIZE: usize = 50000; | |
const VAL_SIZE: usize = 10000; | |
const TEST_SIZE: usize =10000; | |
const N_EPOCHS: i64 = 200; | |
const THRES: f64 = 0.001; | |
fn main()-> Result<(), Box<dyn Error>> { | |
// run epochs | |
let mut loss_diff; | |
let mut curr_loss = 0.0; | |
'train: for epoch in 1..N_EPOCHS{ | |
// neural network multiplication | |
let logits = train_data.matmul(&ws) + &bs; | |
// compute the loss as log softmax | |
let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&train_lbl); | |
// gradient | |
ws.zero_grad(); | |
bs.zero_grad(); | |
loss.backward(); | |
// back propgation | |
no_grad(|| { | |
ws += ws.grad()*(-1); | |
bs += bs.grad()*(-1); | |
}); | |
// validation | |
let val_logits = val_data.matmul(&ws) + &bs; | |
let val_accuracy = val_logits | |
.argmax(Some(-1), false) | |
.eq_tensor(&val_lbl) | |
.to_kind(Kind::Float) | |
.mean(Kind::Float) | |
.double_value(&[]); | |
println!( | |
"epoch: {:4} train loss: {:8.5} val acc: {:5.2}%", | |
epoch, | |
loss.double_value(&[]), | |
100. * val_accuracy | |
); | |
// early stop | |
if epoch == 1{ | |
curr_loss = loss.double_value(&[]); | |
} else { | |
loss_diff = (loss.double_value(&[]) - curr_loss).abs(); | |
curr_loss = loss.double_value(&[]); | |
// if we are less then threshold stop | |
if loss_diff < THRES { | |
println!("Target accuracy reached, early stopping"); | |
break 'train; | |
} | |
} | |
} | |
// the final weight and bias gives us the test accuracy | |
let test_logits = test_data.matmul(&ws) + &bs; | |
let test_accuracy = test_logits | |
.argmax(Some(-1), false) | |
.eq_tensor(&test_lbl) | |
.to_kind(Kind::Float) | |
.mean(Kind::Float) | |
.double_value(&[]); | |
println!("Final test accuracy {:5.2}%", 100.*test_accuracy); | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment