Last active: December 8, 2017 10:00
-
-
Save rkwitt/11be4f3cdad66b98c1bd4ff48f6c3ac0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Simple example of using pytorch to build a linear regression model. | |
""" | |
import torch | |
from torch.autograd import Variable | |
from torch import optim | |
import sys | |
def build_model():
    """Create the regression model y = k*x.

    The model is a ``torch.nn.Sequential`` wrapping a single
    ``torch.nn.Linear(1, 1)`` layer built with ``bias=False``, so its only
    trainable parameter is the slope k.
    """
    slope_only = torch.nn.Linear(1, 1, bias=False)
    return torch.nn.Sequential(slope_only)
def train(model, loss, optimizer, x, y):
    """Run one optimization step on a single mini-batch and return its loss.

    Args:
        model: module that is fed the inputs reshaped to an (N, 1) tensor.
        loss: criterion, called as ``loss(prediction, target)``.
        optimizer: optimizer over ``model``'s parameters.
        x: 1-D tensor of N inputs.
        y: 1-D tensor of N targets.

    Returns:
        The batch loss as a Python float.
    """
    # Give the inputs a trailing feature dimension of size 1 for Linear(1, 1).
    inputs = x.reshape(-1, 1)
    # Match the target's shape to the (N, 1) prediction. An (N,) target
    # would broadcast against an (N, 1) prediction to (N, N) inside
    # MSELoss, which silently computes the wrong loss and gradient.
    targets = y.reshape(-1, 1)

    # Gradients accumulate across backward() calls, so clear them first.
    optimizer.zero_grad()
    # Call the modules rather than .forward() so that registered hooks run.
    prediction = model(inputs)
    error = loss(prediction, targets)
    error.backward()
    optimizer.step()
    # .item() extracts the scalar. Indexing a 0-dim tensor with [0] raises
    # on current PyTorch.
    return error.item()
def main(num_epochs=100, batch_size=10, lr=0.01):
    """Fit y = k*x to noisy synthetic data with mini-batch SGD.

    Generates 100 evenly spaced points on [-1, 1] with targets
    ``2*x + 0.33 * randn``, trains the model from ``build_model`` and prints
    the epoch index and the mean per-batch training loss after every epoch.

    Args:
        num_epochs: number of full passes over the data.
        batch_size: number of samples per SGD step; must be between 1 and
            the number of data points.
        lr: SGD learning rate.

    Raises:
        ValueError: if ``batch_size`` is outside ``[1, number of points]``.
    """
    # Seed before generating data and building the model, so that both the
    # noise and the weight initialisation are reproducible.
    torch.manual_seed(1234)

    # Synthetic data: true slope 2, plus noise scaled by 0.33.
    X = torch.linspace(-1, 1, 100)
    Y = 2 * X + torch.randn(X.size()) * 0.33

    if not 1 <= batch_size <= len(X):
        raise ValueError(
            "batch_size must be between 1 and {}, got {}".format(len(X), batch_size))

    model = build_model()
    # Mean-squared error averaged over the batch. ``reduction='mean'``
    # replaces the deprecated ``size_average=True``.
    loss = torch.nn.MSELoss(reduction='mean')
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Invariant across epochs. Trailing samples that do not fill a whole
    # batch are skipped.
    num_batches = len(X) // batch_size
    for epoch in range(num_epochs):
        epoch_error = 0.0
        for k in range(num_batches):
            start, end = k * batch_size, (k + 1) * batch_size
            epoch_error += train(model, loss, optimizer,
                                 X[start:end], Y[start:end])
        print(epoch, epoch_error / num_batches)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment