
@Steboss89
Created August 14, 2022 16:53
Train a sequential neural network
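```rust
// `net`, `opt`, the data/label tensors and N_EPOCHS are defined earlier in the program.
use tch::nn::Module; // brings `forward` into scope

// `1..=` so that exactly N_EPOCHS epochs run (`1..N_EPOCHS` stops one short)
for epoch in 1..=N_EPOCHS {
    // forward pass on the whole training set, cross-entropy loss on the logits
    let loss = net.forward(&train_data).cross_entropy_for_logits(&train_lbl);
    // backward step: zero the gradients, backpropagate, update the weights
    opt.backward_step(&loss);
    // accuracy on the validation set, evaluated without tracking gradients
    let val_accuracy = tch::no_grad(|| net.forward(&val_data).accuracy_for_logits(&val_lbl));
    println!(
        "epoch: {:4} train loss: {:8.5} val acc: {:5.2}%",
        epoch,
        loss.double_value(&[]),                // scalar tensor -> f64
        100. * val_accuracy.double_value(&[]),
    );
}
```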
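To see what each step of the loop computes without installing libtorch, the following dependency-free Rust sketch runs the same cycle on a toy problem. It is an illustration under stated assumptions, not the gist's code: the `Linear` layer (softmax regression, a one-layer stand-in for the sequential net), `gradient_step` (plain full-batch gradient descent standing in for `opt`), `train`, the toy data and the `N_EPOCHS`/`LR` values are all hypothetical. `cross_entropy_for_logits` and `accuracy_for_logits` are reimplemented over plain `Vec`s to mirror what the tch-rs tensor methods of the same names compute: the batch mean of `-log_softmax(logits)[label]`, and the fraction of rows whose arg-max logit matches the label.

```rust
// Dependency-free sketch (not tch-rs): one linear layer (softmax regression)
// trained by full-batch gradient descent. Data, N_EPOCHS and LR are hypothetical.
const N_EPOCHS: usize = 100;
const LR: f64 = 0.1;

/// Single linear layer: logits = W x + b.
struct Linear {
    w: Vec<Vec<f64>>, // [n_classes][n_features]
    b: Vec<f64>,      // [n_classes]
}

impl Linear {
    fn new(n_features: usize, n_classes: usize) -> Self {
        Linear { w: vec![vec![0.0; n_features]; n_classes], b: vec![0.0; n_classes] }
    }

    /// Forward pass: one row of logits per input row.
    fn forward(&self, xs: &[Vec<f64>]) -> Vec<Vec<f64>> {
        xs.iter()
            .map(|x| {
                self.w
                    .iter()
                    .zip(&self.b)
                    .map(|(wc, bc)| bc + wc.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>())
                    .collect::<Vec<f64>>()
            })
            .collect()
    }
}

/// log(sum(exp(row))), shifted by the row max for numerical stability.
fn log_sum_exp(row: &[f64]) -> f64 {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    max + row.iter().map(|z| (z - max).exp()).sum::<f64>().ln()
}

fn argmax(row: &[f64]) -> usize {
    (0..row.len()).fold(0, |best, i| if row[i] > row[best] { i } else { best })
}

/// Batch mean of -log_softmax(logits)[label].
fn cross_entropy_for_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = logits.iter().zip(labels).map(|(row, &y)| log_sum_exp(row) - row[y]).sum();
    total / logits.len() as f64
}

/// Fraction of rows whose arg-max logit equals the label.
fn accuracy_for_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let correct = logits.iter().zip(labels).filter(|&(row, &y)| argmax(row) == y).count();
    correct as f64 / logits.len() as f64
}

/// Plays the role of `opt.backward_step(&loss)`: computes the loss, then applies one
/// gradient-descent update. Returns the loss measured before the update.
fn gradient_step(net: &mut Linear, xs: &[Vec<f64>], ys: &[usize], lr: f64) -> f64 {
    let logits = net.forward(xs);
    let loss = cross_entropy_for_logits(&logits, ys);
    let n = xs.len() as f64;
    for ((x, row), &y) in xs.iter().zip(&logits).zip(ys) {
        let lse = log_sum_exp(row);
        for c in 0..row.len() {
            // d(loss)/d(logit_c) = (softmax_c - onehot_c) / batch_size
            let onehot = if c == y { 1.0 } else { 0.0 };
            let g = ((row[c] - lse).exp() - onehot) / n;
            net.b[c] -= lr * g;
            for (w, xi) in net.w[c].iter_mut().zip(x) {
                *w -= lr * g * xi;
            }
        }
    }
    loss
}

fn train(n_epochs: usize, lr: f64) -> (Vec<f64>, f64) {
    // Hypothetical toy data, separable by the sign of the first feature.
    let train_data: Vec<Vec<f64>> = vec![
        vec![-2.0, 1.0], vec![-1.5, -0.5], vec![-1.0, 0.5],
        vec![1.0, -0.5], vec![1.5, 0.5], vec![2.0, -1.0],
    ];
    let train_lbl: Vec<usize> = vec![0, 0, 0, 1, 1, 1];
    let val_data: Vec<Vec<f64>> = vec![vec![-1.2, 0.0], vec![1.2, 0.0]];
    let val_lbl: Vec<usize> = vec![0, 1];

    let mut net = Linear::new(2, 2);
    let (mut losses, mut val_accuracy) = (Vec::new(), 0.0);
    for epoch in 1..=n_epochs {
        let loss = gradient_step(&mut net, &train_data, &train_lbl, lr);
        val_accuracy = accuracy_for_logits(&net.forward(&val_data), &val_lbl);
        println!("epoch: {:4} train loss: {:8.5} val acc: {:5.2}%", epoch, loss, 100. * val_accuracy);
        losses.push(loss);
    }
    (losses, val_accuracy)
}

fn main() {
    train(N_EPOCHS, LR);
}
```

Running it prints one line per epoch in the same format as the gist. The first train loss is ln 2 ≈ 0.69315, because zero-initialized weights give both classes the same logit; with a step size this small the loss then decreases steadily on this convex toy problem.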