Last active
July 12, 2021 11:06
-
-
Save veer66/ded8a5b884c015729661e68f3b11e0e5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use tch::nn::Module; | |
| use tch::nn::OptimizerConfig; | |
| use tch::nn::RNN; | |
| use tch::Reduction; | |
| use tch::{nn, Device, Tensor}; | |
| use std::fs::File; | |
| use std::io::BufWriter; | |
| use std::io::Write; | |
// Number of points in the synthetic sequence.
// NOTE(review): the generator loop uses `1..SYN_LEN`, so it actually emits
// SYN_LEN - 1 values — confirm whether that off-by-one is intended.
const SYN_LEN: usize = 100000;
// Increment between consecutive synthetic values (series is 1, 2, 3, ...).
const STEP: f32 = 1f32;
// Length of each sliding window: the first WINDOW_SIZE - 1 values are the
// model input, the last value is the regression target.
const WINDOW_SIZE: usize = 100;
// One scalar feature per timestep fed into the LSTM.
const INPUT_DIM: i64 = 1i64;
// One scalar predicted per window.
const OUTPUT_DIM: i64 = 1i64;
// LSTM hidden-state width.
const HIDDEN_DIM: i64 = 256i64;
// Number of stacked LSTM layers.
const NUM_LAYERS: i64 = 2i64;
// Adam learning rate.
const LEARNING_RATE: f64 = 0.001;
// Windows per mini-batch.
const BATCH_SIZE: usize = 1000;
// Training runs for MAX_EPOCH + 1 epochs (loop is `0..=MAX_EPOCH`).
const MAX_EPOCH: usize = 3000;
// NOTE(review): despite the name, this is the fraction of records used for
// TRAINING (0.9); the remaining 0.1 is held out for the CSV evaluation dump.
const TEST_RATIO: f64 = 0.9;
// Output CSV path for "predict,answer" rows on the held-out records.
const ANS_PATH: &str = "lstm-prac2-ans.csv";
| fn main() { | |
| let mut x = 1f32; | |
| let mut dataset = vec![]; | |
| for _i in 1..SYN_LEN { | |
| dataset.push(x); | |
| x += STEP; | |
| } | |
| let device = Device::cuda_if_available(); | |
| let vs = nn::VarStore::new(device); | |
| let lstm = nn::lstm( | |
| &vs.root(), | |
| INPUT_DIM, | |
| HIDDEN_DIM, | |
| nn::RNNConfig { | |
| has_biases: true, | |
| num_layers: NUM_LAYERS, | |
| dropout: 0., | |
| train: true, | |
| bidirectional: false, | |
| batch_first: true, | |
| }, | |
| ); | |
| let linear = nn::linear(vs.root(), HIDDEN_DIM, OUTPUT_DIM, Default::default()); | |
| let mut optimizer = nn::Adam::default().build(&vs, LEARNING_RATE).unwrap(); | |
| let mut data = vec![]; | |
| let record_size = dataset.len() - WINDOW_SIZE; | |
| let training_record_size = (record_size as f64 * TEST_RATIO).round() as usize; | |
| for i in 0..record_size { | |
| for j in i..(i + WINDOW_SIZE) { | |
| data.push(dataset[j]); | |
| } | |
| } | |
| for t in 0..=MAX_EPOCH { | |
| let mut sum_loss = 0.; | |
| let mut batch_i = 0; | |
| loop { | |
| if batch_i * BATCH_SIZE >= training_record_size { | |
| break; | |
| } | |
| let first = batch_i * BATCH_SIZE; | |
| let last = ((batch_i + 1) * BATCH_SIZE).min(training_record_size); | |
| let actual_batch_size = last - first; | |
| let xy = Tensor::of_slice(&data[(first * WINDOW_SIZE)..(last * WINDOW_SIZE)]); | |
| let xy = xy.view([actual_batch_size as i64, WINDOW_SIZE as i64, 1i64]); | |
| let x = xy.narrow(1, 0, (WINDOW_SIZE - 1) as i64).to(device); | |
| let y = xy.narrow(1, (WINDOW_SIZE - 1) as i64, 1).to(device); | |
| let (lstm_out, _) = lstm.seq(&x); | |
| let y_predict = linear | |
| .forward(&lstm_out) | |
| .narrow(1, WINDOW_SIZE as i64 - 2, 1); | |
| let loss = y_predict.mse_loss(&y, Reduction::Mean); | |
| optimizer.zero_grad(); | |
| optimizer.backward_step_clip(&loss, 0.5); | |
| sum_loss += f64::from(loss); | |
| batch_i += 1; | |
| } | |
| println!("t={} loss={}", t, sum_loss); | |
| } | |
| let result_path = File::create(ANS_PATH).unwrap(); | |
| let mut write_buf = BufWriter::new(result_path); | |
| write_buf | |
| .write(format!("predict,answer\n").as_bytes()) | |
| .unwrap(); | |
| let mut batch_i = 0; | |
| loop { | |
| if batch_i * BATCH_SIZE + training_record_size >= record_size { | |
| break; | |
| } | |
| println!("WRITE-ANS {}", batch_i); | |
| let first = batch_i * BATCH_SIZE + training_record_size; | |
| let last = ((batch_i + 1) * BATCH_SIZE + training_record_size).min(record_size); | |
| let actual_batch_size = last - first; | |
| let xy = Tensor::of_slice(&data[(first * WINDOW_SIZE)..(last * WINDOW_SIZE)]); | |
| let xy = xy.view([actual_batch_size as i64, WINDOW_SIZE as i64, 1i64]); | |
| let x = xy.narrow(1, 0, (WINDOW_SIZE - 1) as i64).to(device); | |
| let y = Vec::<i64>::from(xy.narrow(1, (WINDOW_SIZE - 1) as i64, 1)); | |
| let (lstm_out, _) = lstm.seq(&x); | |
| let y_predict = Vec::<i64>::from(linear | |
| .forward(&lstm_out) | |
| .narrow(1, WINDOW_SIZE as i64 - 2, 1)); | |
| assert_eq!(y_predict.len(), y.len()); | |
| for i in 0..y.len() { | |
| write_buf | |
| .write(format!("{},{}\n", y_predict[i], y[i]).as_bytes()) | |
| .unwrap(); | |
| } | |
| batch_i += 1; | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment