Created
August 14, 2022 17:06
-
-
Save Steboss89/d7104ec39745621b64402e3ff9836d72 to your computer and use it in GitHub Desktop.
Convolutional neural network
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::result::Result; | |
use std::error::Error; | |
use mnist::*; | |
use tch::{kind, Kind, Tensor, nn, nn::ModuleT, nn::OptimizerConfig, Device}; | |
use ndarray::{Array3, Array2}; | |
/// Number of distinct labels (the ten digit classes).
const LABELS: i64 = 10;
/// Image height in pixels; the network's forward pass reshapes inputs to 28x28.
const HEIGHT: usize = 28;
/// Image width in pixels; the network's forward pass reshapes inputs to 28x28.
const WIDTH: usize = 28;

// NOTE(review): the uses of the constants below are outside this part of the
// file; the descriptions are inferred from their names -- confirm at the call sites.

/// Number of samples in the training split.
const TRAIN_SIZE: usize = 10000;
/// Number of samples in the validation split.
const VAL_SIZE: usize = 1000;
/// Number of samples in the test split.
const TEST_SIZE: usize = 1000;
/// Number of training epochs.
const N_EPOCHS: i64 = 50;
/// Threshold used by the training code (presumably a stopping tolerance -- TODO confirm).
const THRES: f64 = 0.001;
/// Number of samples per mini-batch.
const BATCH_SIZE: i64 = 256;
#[derive(Debug)] | |
struct Net { | |
conv1: nn::Conv2D, | |
conv2: nn::Conv2D, | |
fc1: nn::Linear, | |
fc2: nn::Linear, | |
} | |
impl Net { | |
fn new(vs: &nn::Path) -> Net { | |
// stride -- padding -- dilation | |
let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default()); | |
let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default()); | |
let fc1 = nn::linear(vs, 1024, 1024, Default::default()); | |
let fc2 = nn::linear(vs, 1024, 10, Default::default()); | |
Net { conv1, conv2, fc1, fc2 } | |
} | |
} | |
// forward step | |
impl nn::ModuleT for Net { | |
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor { | |
xs.view([-1, 1, 28, 28]) | |
.apply(&self.conv1) | |
.max_pool2d_default(2) | |
.apply(&self.conv2) | |
.max_pool2d_default(2) | |
.view([-1, 1024]) | |
.apply(&self.fc1) | |
.relu() | |
.dropout(0.5, train) | |
.apply(&self.fc2) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment