
@dvgodoy
Created April 27, 2019 17:12
def make_train_step(model, loss_fn, optimizer):
    # Builds function that performs a step in the train loop
    def train_step(x, y):
        # Sets model to TRAIN mode
        model.train()
        # Makes predictions
        yhat = model(x)
        # Computes loss (PyTorch losses take (input, target), i.e. (yhat, y))
        loss = loss_fn(yhat, y)
        # Computes gradients
        loss.backward()
        # Updates parameters and zeroes gradients
        optimizer.step()
        optimizer.zero_grad()
        # Returns the loss
        return loss.item()

    # Returns the function that will be called inside the train loop
    return train_step

# model, loss_fn, optimizer, n_epochs, x_train_tensor and y_train_tensor
# are assumed to be defined in the earlier steps of the tutorial.
# Creates the train_step function for our model, loss function and optimizer
train_step = make_train_step(model, loss_fn, optimizer)
losses = []

# For each epoch...
for epoch in range(n_epochs):
    # Performs one train step and returns the corresponding loss
    loss = train_step(x_train_tensor, y_train_tensor)
    losses.append(loss)

# Checks model's parameters
print(model.state_dict())
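For readers who want to run make_train_step outside the tutorial notebook, here is a minimal self-contained sketch. It assumes a synthetic linear-regression setup (data generated as y = 1 + 2x + small noise), a single nn.Linear(1, 1) model, nn.MSELoss and plain SGD; the seed, learning rate and epoch count are illustrative choices, not values taken from this gist.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def make_train_step(model, loss_fn, optimizer):
    # Builds a closure that performs one step of the train loop
    def train_step(x, y):
        model.train()            # sets model to TRAIN mode
        yhat = model(x)          # forward pass: predictions
        loss = loss_fn(yhat, y)  # PyTorch convention: (input, target)
        loss.backward()          # computes gradients
        optimizer.step()         # updates parameters
        optimizer.zero_grad()    # resets gradients for the next step
        return loss.item()
    return train_step

# Hypothetical synthetic data: y = 1 + 2x + small Gaussian noise
torch.manual_seed(42)
x_train_tensor = torch.rand(100, 1)
y_train_tensor = 1 + 2 * x_train_tensor + 0.1 * torch.randn(100, 1)

# Illustrative hyperparameters (assumptions, not from the gist)
lr = 1e-1
n_epochs = 1000

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss(reduction='mean')
# The optimizer is built from THIS model's parameters
optimizer = optim.SGD(model.parameters(), lr=lr)

train_step = make_train_step(model, loss_fn, optimizer)
losses = []
for epoch in range(n_epochs):
    losses.append(train_step(x_train_tensor, y_train_tensor))

# Learned weight and bias
print(model.state_dict())
```

With these settings the printed weight and bias should end up close to 2 and 1, the coefficients used to generate the data.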
@RustyToms

I'm going through your tutorial at https://towardsdatascience.com/understanding-pytorch-with-an-example-a-step-by-step-tutorial-81fc5f8c4e8e; thanks for that, it's very helpful!

I've been running all the code in Google Colab and ran into trouble with this one. You need to redefine the optimizer before passing it to make_train_step (in the line train_step = make_train_step(model, loss_fn, optimizer)); otherwise the optimizer is still linked to the previous example's parameters. At least, that is what I think is happening. This makes it work for me:

# Creates the train_step function for our model, loss function and optimizer
optimizer = optim.SGD(model.parameters(), lr=lr)
train_step = make_train_step(model, loss_fn, optimizer)
losses = []
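The behavior described in this comment can be reproduced with a small sketch. A torch.optim optimizer keeps references to the parameter tensors it was given at construction time, so if model is later rebound to a new instance, an optimizer built earlier still points at the old parameters and optimizer.step() never touches the new ones. The toy nn.Linear(1, 1) models, the random data and the names old_model, new_model, snapshot and changed below are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def snapshot(module):
    # Copies the current parameter values so they can be compared later
    return [p.detach().clone() for p in module.parameters()]

def changed(before, module):
    # True if any parameter differs from its snapshot
    return any(not torch.equal(b, p.detach()) for b, p in zip(before, module.parameters()))

torch.manual_seed(0)
x = torch.rand(10, 1)
y = 1 + 2 * x
loss_fn = nn.MSELoss()

# "Earlier cell": a model and an optimizer built from its parameters
old_model = nn.Linear(1, 1)
optimizer = optim.SGD(old_model.parameters(), lr=0.1)

# "Later cell": a new model, but the optimizer is NOT rebuilt
new_model = nn.Linear(1, 1)
old_before, new_before = snapshot(old_model), snapshot(new_model)
loss_fn(new_model(x), y).backward()  # gradients land on new_model's parameters
optimizer.step()                     # but the optimizer only holds old_model's (grad is None, so skipped)
optimizer.zero_grad()
stale_new_changed = changed(new_before, new_model)
stale_old_changed = changed(old_before, old_model)
print(stale_new_changed, stale_old_changed)  # False False

# The fix from the comment: rebuild the optimizer from the current model
new_model.zero_grad()
optimizer = optim.SGD(new_model.parameters(), lr=0.1)
new_before = snapshot(new_model)
loss_fn(new_model(x), y).backward()
optimizer.step()
optimizer.zero_grad()
fixed_new_changed = changed(new_before, new_model)
print(fixed_new_changed)  # True
```

Since no error is raised in the stale case, the loss returned by train_step would simply stay flat across epochs instead of failing loudly.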

The only other thing I had to add while going through all the code up to this point was this code block at the very beginning:

!pip install torchviz
import numpy as np

Well done, thanks!
