Skip to content

Instantly share code, notes, and snippets.

@Steboss89
Created August 12, 2022 16:27
Show Gist options
  • Save Steboss89/4a66a2b7fa46955883cfb69a48a1493a to your computer and use it in GitHub Desktop.
Save Steboss89/4a66a2b7fa46955883cfb69a48a1493a to your computer and use it in GitHub Desktop.
Run N_EPOCHS training epochs of a linear neural network with the Rust `tch` crate
use std::result::Result;
use std::error::Error;
use mnist::*;
use tch::{kind, no_grad, Kind, Tensor};
use ndarray::{Array3, Array2};
/// Number of distinct class labels.
const LABELS: i64 = 10;
/// Image height in pixels (presumably the MNIST 28x28 images — confirm against the loader).
const HEIGHT: usize = 28;
/// Image width in pixels (presumably the MNIST 28x28 images — confirm against the loader).
const WIDTH: usize = 28;
/// Number of training samples (assumed to configure the dataset split — TODO confirm).
const TRAIN_SIZE: usize = 50_000;
/// Number of validation samples (assumed to configure the dataset split — TODO confirm).
const VAL_SIZE: usize = 10_000;
/// Number of test samples (assumed to configure the dataset split — TODO confirm).
const TEST_SIZE: usize = 10_000;
/// Upper bound on the number of training epochs.
const N_EPOCHS: i64 = 200;
/// Early-stopping threshold on the absolute change in training loss between epochs.
const THRES: f64 = 1e-3;
fn main()-> Result<(), Box<dyn Error>> {
// run epochs
let mut loss_diff;
let mut curr_loss = 0.0;
'train: for epoch in 1..N_EPOCHS{
// neural network multiplication
let logits = train_data.matmul(&ws) + &bs;
// compute the loss as log softmax
let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&train_lbl);
// gradient
ws.zero_grad();
bs.zero_grad();
loss.backward();
// back propgation
no_grad(|| {
ws += ws.grad()*(-1);
bs += bs.grad()*(-1);
});
// validation
let val_logits = val_data.matmul(&ws) + &bs;
let val_accuracy = val_logits
.argmax(Some(-1), false)
.eq_tensor(&val_lbl)
.to_kind(Kind::Float)
.mean(Kind::Float)
.double_value(&[]);
println!(
"epoch: {:4} train loss: {:8.5} val acc: {:5.2}%",
epoch,
loss.double_value(&[]),
100. * val_accuracy
);
// early stop
if epoch == 1{
curr_loss = loss.double_value(&[]);
} else {
loss_diff = (loss.double_value(&[]) - curr_loss).abs();
curr_loss = loss.double_value(&[]);
// if we are less then threshold stop
if loss_diff < THRES {
println!("Target accuracy reached, early stopping");
break 'train;
}
}
}
// the final weight and bias gives us the test accuracy
let test_logits = test_data.matmul(&ws) + &bs;
let test_accuracy = test_logits
.argmax(Some(-1), false)
.eq_tensor(&test_lbl)
.to_kind(Kind::Float)
.mean(Kind::Float)
.double_value(&[]);
println!("Final test accuracy {:5.2}%", 100.*test_accuracy);
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment