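The training loop below relies on several objects defined elsewhere in the gist (cnn, trainer, loss_function, accuracy_fn, ce_loss, reset_metrics, device, F, and the trn_*/val_* arrays). A minimal setup sketch is shown here for context only; the network architecture, optimizer, and hyperparameters are placeholder assumptions, not the gist's actual values.

import mxnet as mx
from mxnet import nd as F
from mxnet import autograd, gluon

# use a GPU if one is available, otherwise fall back to CPU
device = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

# placeholder CNN; the real gist defines its own architecture
cnn = gluon.nn.HybridSequential()
cnn.add(gluon.nn.Conv2D(32, kernel_size=3, activation='relu'),
        gluon.nn.MaxPool2D(pool_size=2),
        gluon.nn.Flatten(),
        gluon.nn.Dense(10))
cnn.initialize(mx.init.Xavier(), ctx=device)

loss_function = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(cnn.collect_params(), 'adam', {'learning_rate': 1e-3})

# running metrics, reset after every update so each call reflects a single batch
accuracy_fn = mx.metric.Accuracy()
ce_loss = mx.metric.CrossEntropy()

def reset_metrics():
    accuracy_fn.reset()
    ce_loss.reset()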
EPOCHS = 30
trn_loss = []
val_loss = []

for epoch in range(EPOCHS):
    # fresh iterators each epoch so the data is reshuffled
    train_iter = mx.io.NDArrayIter(trn_x, trn_y, 1000, shuffle=True)
    val_iter = mx.io.NDArrayIter(val_x, val_y, 1000, shuffle=True)

    for trn_batch, val_batch in zip(train_iter, val_iter):
        x = trn_batch.data[0].as_in_context(device)
        y = trn_batch.label[0].as_in_context(device)
        vx = val_batch.data[0].as_in_context(device)   # validation batch, not the training batch
        vy = val_batch.label[0].as_in_context(device)

        # forward pass recorded for autograd
        with autograd.record():
            y_pred = cnn(x)
            loss = loss_function(y_pred, y)

        # training metrics for this batch
        accuracy_fn.update(y, y_pred)
        ce_loss.update(y, F.softmax(y_pred))
        _, training_acc = accuracy_fn.get()
        _, training_loss = ce_loss.get()
        trn_loss.append(training_loss)
        reset_metrics()

        # backprop and parameter update, normalized by the actual batch size
        loss.backward()
        trainer.step(batch_size=x.shape[0])

        # validation metrics for this batch
        y_pred = cnn(vx)
        accuracy_fn.update(vy, y_pred)
        ce_loss.update(vy, F.softmax(y_pred))
        _, validation_acc = accuracy_fn.get()
        _, validation_loss = ce_loss.get()
        val_loss.append(validation_loss)
        reset_metrics()

    print("epoch: {} | trn_loss: {} | trn_acc: {} | val_loss: {}".format(
        epoch + 1,
        trn_loss[-1],
        training_acc,
        val_loss[-1]))
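Note that zip(train_iter, val_iter) stops at the shorter of the two iterators, so if the validation set is smaller than the training set, part of the training data is skipped each epoch. The per-epoch figures printed above reflect the last batch processed, since the metrics are reset after every update.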