Created April 27, 2019 17:12
```python
def make_train_step(model, loss_fn, optimizer):
    # Builds function that performs a step in the train loop
    def train_step(x, y):
        # Sets model to TRAIN mode
        model.train()
        # Makes predictions
        yhat = model(x)
        # Computes loss (PyTorch loss functions take (input, target))
        loss = loss_fn(yhat, y)
        # Computes gradients
        loss.backward()
        # Updates parameters and zeroes gradients
        optimizer.step()
        optimizer.zero_grad()
        # Returns the loss
        return loss.item()

    # Returns the function that will be called inside the train loop
    return train_step

# Creates the train_step function for our model, loss function and optimizer
train_step = make_train_step(model, loss_fn, optimizer)
losses = []

# For each epoch...
for epoch in range(n_epochs):
    # Performs one train step and returns the corresponding loss
    loss = train_step(x_train_tensor, y_train_tensor)
    losses.append(loss)

# Checks model's parameters
print(model.state_dict())
```
I'm going through your tutorial at https://towardsdatascience.com/understanding-pytorch-with-an-example-a-step-by-step-tutorial-81fc5f8c4e8e. Thanks for that, it's very helpful!
I've been running all the code in Google Colab and ran into trouble with this one. You need to redefine the optimizer before you pass it to `make_train_step` on line 22; otherwise the optimizer is still linked to the previous example's parameters. At least that is what I think is happening. This makes it work for me:
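Something along these lines illustrates the problem and the fix. It is only a sketch: the toy data, the `nn.Linear(1, 1)` model, the SGD optimizer with `lr=0.1`, and names such as `old_model` and `stale_optimizer` are assumptions for illustration, not necessarily what the tutorial or this comment originally used.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(42)

# Toy data (assumed for illustration): y = 1 + 2x plus a little noise
x_train_tensor = torch.rand(100, 1)
y_train_tensor = 1 + 2 * x_train_tensor + 0.1 * torch.randn(100, 1)

def make_train_step(model, loss_fn, optimizer):
    def train_step(x, y):
        model.train()
        yhat = model(x)
        loss = loss_fn(yhat, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss.item()
    return train_step

loss_fn = nn.MSELoss(reduction='mean')

# An optimizer built for an earlier model (hypothetical leftover from a previous cell)
old_model = nn.Linear(1, 1)
stale_optimizer = optim.SGD(old_model.parameters(), lr=0.1)

# A new model created afterwards is unknown to that optimizer
model = nn.Linear(1, 1)
before = [p.detach().clone() for p in model.parameters()]

stale_step = make_train_step(model, loss_fn, stale_optimizer)
stale_step(x_train_tensor, y_train_tensor)

# The stale optimizer stepped old_model's parameters, so model did not move
stale_unchanged = all(
    torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)

# The stale optimizer never zeroed model's gradients, so clear them by hand
model.zero_grad()

# The fix: redefine the optimizer for the current model before building train_step
optimizer = optim.SGD(model.parameters(), lr=0.1)
train_step = make_train_step(model, loss_fn, optimizer)

losses = [train_step(x_train_tensor, y_train_tensor) for _ in range(1000)]

fresh_changed = not all(
    torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)
print(model.state_dict())
```

In a notebook this situation arises easily: re-running the cell that creates `model` without re-running the cell that creates `optimizer` leaves the optimizer pointing at the old parameter tensors.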
The only other thing I had to add while going through all the code up to this point was this code block at the very beginning:
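A setup block of this kind might look like the sketch below. It is an assumption about what such a block would contain (imports, device selection, synthetic data, and hyperparameters), not a copy of the original; the data-generating line `y = 1 + 2x + noise`, the 80-point training split, and the values of `lr` and `n_epochs` are all hypothetical.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Use the GPU when Colab provides one, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Synthetic data (assumed): a noisy line with intercept 1 and slope 2
np.random.seed(42)
x = np.random.rand(100, 1)
y = 1 + 2 * x + 0.1 * np.random.randn(100, 1)

# Shuffle the indices and keep the first 80 points for training
idx = np.arange(100)
np.random.shuffle(idx)
train_idx = idx[:80]

# Convert the NumPy arrays to float32 tensors on the chosen device
x_train_tensor = torch.from_numpy(x[train_idx]).float().to(device)
y_train_tensor = torch.from_numpy(y[train_idx]).float().to(device)

# Hyperparameters (assumed values)
lr = 0.1
n_epochs = 1000

# Model, loss function and optimizer that the training loop expects to exist
torch.manual_seed(42)
model = nn.Linear(1, 1).to(device)
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.SGD(model.parameters(), lr=lr)
```

With these names defined, the `make_train_step` code in the gist has everything it refers to: `model`, `loss_fn`, `optimizer`, `n_epochs`, `x_train_tensor` and `y_train_tensor`.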
Well done, thanks!