Created
August 12, 2022 16:38
-
-
Save Steboss89/f0c06b879e077460e040fac4ea8a5fc6 to your computer and use it in GitHub Desktop.
Sequential neural network implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::result::Result; | |
use std::error::Error; | |
use mnist::*; | |
use tch::{kind, Kind, Tensor, nn, nn::Module, nn::OptimizerConfig, Device}; | |
use ndarray::{Array3, Array2}; | |
/// Number of distinct labels; also the output width of the network's last layer.
const LABELS: i64 = 10;
/// Image height in pixels (assumed from `HEIGHT * WIDTH == 784` — confirm against the loader).
const HEIGHT: usize = 28;
/// Image width in pixels (same assumption as `HEIGHT`).
const WIDTH: usize = 28;
/// Length of one flattened image; input width of the network's first layer.
///
/// Derived from `HEIGHT` and `WIDTH` so the three values cannot drift apart.
/// The `as` cast is lossless here: 28 * 28 = 784 fits comfortably in an `i64`.
const IMAGE_DIM: i64 = (HEIGHT * WIDTH) as i64;
/// Width of the hidden layer.
const HIDDEN_NODES: i64 = 128;
/// Number of training samples (usage not visible in this part of the file).
const TRAIN_SIZE: usize = 50000;
/// Number of validation samples (usage not visible in this part of the file).
const VAL_SIZE: usize = 10000;
/// Number of test samples (usage not visible in this part of the file).
const TEST_SIZE: usize = 10000;
/// Presumably the maximum number of training epochs — confirm against the training loop.
const N_EPOCHS: i64 = 200;
/// Presumably an early-stopping / convergence threshold — confirm against the training loop.
const THRES: f64 = 0.001;
/// Presumably the mini-batch size — confirm against the training loop.
const BATCH_SIZE: i64 = 256;
fn net(vs: &nn::Path) -> impl Module{ | |
nn::seq() | |
.add(nn::linear(vs/"layer1", IMAGE_DIM, HIDDEN_NODES, Default::default() )) | |
.add_fn(|xs| xs.relu()) | |
.add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default())) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment