Skip to content

Instantly share code, notes, and snippets.

@Steboss89
Created August 14, 2022 17:28
Show Gist options
  • Save Steboss89/66efb511fb9382e5241be900b052c702 to your computer and use it in GitHub Desktop.
Save Steboss89/66efb511fb9382e5241be900b052c702 to your computer and use it in GitHub Desktop.
Train a convolutional neural network
for epoch in 1..N_EPOCHS {
// generate random idxs for batch size
// run all the images divided in batches -> for loop
for i in 1..n_it {
let batch_idxs = generate_random_index(TRAIN_SIZE as i64, BATCH_SIZE);
let batch_images = train_data.index_select(0, &batch_idxs).to_device(vs.device()).to_kind(Kind::Float);
let batch_lbls = train_lbl.index_select(0, &batch_idxs).to_device(vs.device()).to_kind(Kind::Int64);
// compute the loss
let loss = net.forward_t(&batch_images, true).cross_entropy_for_logits(&batch_lbls);
opt.backward_step(&loss);
}
// compute accuracy
let val_accuracy =
net.batch_accuracy_for_logits(&val_data, &val_lbl, vs.device(), 1024);
println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * val_accuracy,);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment