Skip to content

Instantly share code, notes, and snippets.

@Steboss89
Created August 14, 2022 17:06
Show Gist options
  • Save Steboss89/d7104ec39745621b64402e3ff9836d72 to your computer and use it in GitHub Desktop.
Save Steboss89/d7104ec39745621b64402e3ff9836d72 to your computer and use it in GitHub Desktop.
Convolutional neural network for MNIST digit classification, written in Rust with tch-rs
use std::result::Result;
use std::error::Error;
use mnist::*;
use tch::{kind, Kind, Tensor, nn, nn::ModuleT, nn::OptimizerConfig, Device};
use ndarray::{Array3, Array2};
/// Number of distinct labels.
const LABELS: i64 = 10;
/// Height of one input image (matches the 28x28 reshape in `Net::forward_t`).
const HEIGHT: usize = 28;
/// Width of one input image (matches the 28x28 reshape in `Net::forward_t`).
const WIDTH: usize = 28;
/// Number of training samples to use; presumably passed to the MNIST loader — confirm.
const TRAIN_SIZE: usize = 10_000;
/// Number of validation samples to use; presumably passed to the MNIST loader — confirm.
const VAL_SIZE: usize = 1_000;
/// Number of test samples to use; presumably passed to the MNIST loader — confirm.
const TEST_SIZE: usize = 1_000;
/// Number of training epochs (the loop that reads it is not shown in this section).
const N_EPOCHS: i64 = 50;
/// Threshold value. Its consumer is not visible in this section.
/// NOTE(review): likely a learning rate or stopping tolerance — confirm and document.
const THRES: f64 = 0.001;
/// Batch size (the loop that reads it is not shown in this section).
const BATCH_SIZE: i64 = 256;
/// Convolutional classifier: two convolution layers followed by two fully
/// connected layers.
///
/// The layers are constructed in `Net::new` and chained together in the
/// `nn::ModuleT::forward_t` implementation.
#[derive(Debug)]
struct Net {
    /// First convolution: 1 input channel -> 32 feature maps, 5x5 kernel.
    conv1: nn::Conv2D,
    /// Second convolution: 32 -> 64 feature maps, 5x5 kernel.
    conv2: nn::Conv2D,
    /// Hidden fully connected layer over the flattened convolution features.
    fc1: nn::Linear,
    /// Output layer producing one score per class.
    fc2: nn::Linear,
}
impl Net {
fn new(vs: &nn::Path) -> Net {
// stride -- padding -- dilation
let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
let fc1 = nn::linear(vs, 1024, 1024, Default::default());
let fc2 = nn::linear(vs, 1024, 10, Default::default());
Net { conv1, conv2, fc1, fc2 }
}
}
// forward step
impl nn::ModuleT for Net {
    /// Computes unnormalised class scores for a batch of images.
    ///
    /// `xs` is reshaped to `(-1, 1, 28, 28)`, so its element count must be a
    /// multiple of 28 * 28. `train` controls dropout: it is applied only when
    /// `train` is `true`. No softmax is applied; the raw `fc2` output is returned.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        // Interpret the input as single-channel 28x28 images.
        let images = xs.view([-1, 1, 28, 28]);

        // Two convolution + 2x2 max-pool stages. Note: no activation is
        // applied between a convolution and its pooling step.
        let stage1 = images.apply(&self.conv1).max_pool2d_default(2);
        let stage2 = stage1.apply(&self.conv2).max_pool2d_default(2);

        // Flatten to 1024 features per image; must match fc1's input size
        // (64 channels * 4 * 4 for a 28x28 input with unpadded 5x5 kernels).
        let features = stage2.view([-1, 1024]);

        // Hidden layer: linear -> ReLU -> dropout with p = 0.5.
        let hidden = features.apply(&self.fc1).relu().dropout(0.5, train);

        hidden.apply(&self.fc2)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment