Skip to content

Instantly share code, notes, and snippets.

@bkaankuguoglu
Created April 1, 2021 20:20
Show Gist options
  • Save bkaankuguoglu/6e511582ad658273565cd906f759a947 to your computer and use it in GitHub Desktop.
Save bkaankuguoglu/6e511582ad658273565cd906f759a947 to your computer and use it in GitHub Desktop.
class Optimization:
    """Bundle a model, a loss function and an optimizer for training.

    The three collaborators are stored as given. ``train_step`` runs one
    forward/backward/update cycle on a single batch.
    """

    def __init__(self, model, loss_fn, optimizer):
        """Store the training components.

        Args:
            model: Callable model exposing ``train()``; invoked as
                ``model(x)`` to produce predictions.
            loss_fn: Callable invoked as ``loss_fn(prediction, target)``;
                its result must support ``backward()`` and ``item()``.
            optimizer: Object exposing ``step()`` and ``zero_grad()``.
        """
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        # Loss histories start empty; nothing in this class appends to them.
        self.train_losses = []
        self.val_losses = []

    def train_step(self, x, y):
        """Run one optimisation step on a single batch.

        Args:
            x: Batch of inputs passed to the model.
            y: Targets corresponding to ``x``.

        Returns:
            The batch loss as a plain Python number (``loss.item()``).
        """
        # Sets model to train mode
        self.model.train()

        # Makes predictions
        yhat = self.model(x)

        # Computes loss. Prediction goes first, target second: this is the
        # (input, target) order PyTorch criteria expect. Symmetric losses
        # such as MSE are unaffected, but asymmetric ones (cross-entropy,
        # BCE, NLL, ...) fail or give wrong values with the order reversed.
        loss = self.loss_fn(yhat, y)

        # Computes gradients
        loss.backward()

        # Updates parameters and zeroes gradients so they do not accumulate
        # into the next step
        self.optimizer.step()
        self.optimizer.zero_grad()

        # Returns the loss
        return loss.item()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment