Last active
December 13, 2018 16:13
-
-
Save pedrohbtp/eff5ddcdbadeef5d01f072c3497994ce to your computer and use it in GitHub Desktop.
PyTorch training example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.optim as optim
import torch.nn as nn


def train(net, data_set, optimizer=None, criterion=None, lr=0.01):
    """Run one supervised training pass of *net* over *data_set*.

    Args:
        net: the ``nn.Module`` to train. It is switched to training mode.
        data_set: an iterable of ``(data, target)`` pairs, one pair per
            optimization step (for example a ``torch.utils.data.DataLoader``).
        optimizer: optimizer over ``net``'s parameters. Defaults to
            ``optim.SGD(net.parameters(), lr=lr)``.
        criterion: loss function called as ``criterion(output, target)``.
            Defaults to ``nn.MSELoss()``.
        lr: learning rate for the default SGD optimizer. It is ignored
            when *optimizer* is given.

    Returns:
        list of float: the loss of each step, in iteration order.
    """
    if optimizer is None:
        optimizer = optim.SGD(net.parameters(), lr=lr)
    if criterion is None:
        criterion = nn.MSELoss()
    # Make sure layers such as dropout/batch-norm behave as in training.
    net.train()
    losses = []
    for data, target in data_set:
        # Zero the gradient buffers; otherwise gradients accumulate across steps.
        optimizer.zero_grad()
        # Call the module rather than net.forward() so registered hooks run.
        output = net(data)
        # Compute the loss against the target that belongs to this batch.
        loss = criterion(output, target)
        # Back-propagate the loss to populate .grad on every parameter.
        loss.backward()
        # Update all the weights of the network.
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    # `Net` and `data_set` are placeholders: define your own nn.Module
    # subclass and an iterable of (data, target) batches before running.
    net = Net()
    optimizer = optim.SGD(net.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    train(net, data_set, optimizer, criterion)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment