# Source: GitHub Gist by @rkwitt (last active December 8, 2017)
# Gist ID: 11be4f3cdad66b98c1bd4ff48f6c3ac0
"""
Simple example of using pytorch to build a linear regression model.
"""
import torch
from torch.autograd import Variable
from torch import optim
import sys
def build_model():
    """Create the regression model ``y = k * x`` with a single learnable slope.

    Returns:
        A ``torch.nn.Sequential`` holding one ``torch.nn.Linear(1, 1)`` layer
        without a bias term, so the only parameter is the 1x1 weight ``k``.
    """
    # No bias: the fitted line is forced through the origin.
    slope_layer = torch.nn.Linear(1, 1, bias=False)
    return torch.nn.Sequential(slope_layer)
def train(model, loss, optimizer, x, y):
    """Run one optimization step on a single mini-batch and return its loss.

    Args:
        model: module mapping an ``(N, 1)`` input to an ``(N, 1)`` output,
            e.g. the model returned by ``build_model``.
        loss: criterion called as ``loss(prediction, target)``,
            e.g. ``torch.nn.MSELoss``.
        optimizer: optimizer over ``model``'s parameters.
        x: tensor of N inputs (1-D, or already shaped ``(N, 1)``).
        y: tensor of N targets (1-D, or already shaped ``(N, 1)``).

    Returns:
        The batch loss, computed before the parameter update, as a Python float.
    """
    # Reshape both inputs and targets to column vectors (N, 1). The model
    # outputs (N, 1); leaving the target as (N,) would make MSELoss broadcast
    # to (N, N) and compare every prediction with every target.
    inputs = x.reshape(-1, 1)
    targets = y.reshape(-1, 1)
    # Gradients accumulate across backward() calls, so clear them first.
    optimizer.zero_grad()
    # Call the modules directly (not .forward()) so registered hooks run.
    prediction = model(inputs)
    error = loss(prediction, targets)
    error.backward()
    optimizer.step()
    # The loss is a 0-dim tensor: .item() extracts the scalar, whereas the
    # old ``error.data[0]`` raises IndexError on PyTorch >= 0.4.
    return error.item()
def main():
    """Fit ``y = k * x`` to noisy synthetic data, printing the loss per epoch.

    Generates 100 points evenly spaced on [-1, 1] with ``y = 2x + noise``,
    where the noise is standard normal scaled by 0.33. It then trains the
    model from ``build_model`` with mini-batch SGD (batch size 10, learning
    rate 0.01) for 100 epochs. After each epoch it prints the epoch index and
    the mean batch loss.
    """
    # Fixed seed so the data and the initial weight are reproducible.
    torch.manual_seed(1234)
    # Synthetic data: true slope 2 plus Gaussian noise.
    X = torch.linspace(-1, 1, 100)
    Y = 2 * X + torch.randn(X.size()) * 0.33
    model = build_model()
    # Mean-squared error averaged over the batch; ``reduction="mean"``
    # replaces the deprecated ``size_average=True``.
    loss = torch.nn.MSELoss(reduction="mean")
    # Plain stochastic gradient descent.
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    batch_size = 10
    num_epochs = 100
    # Depends only on the data size, so compute it once. Trailing samples
    # that do not fill a whole batch are skipped (none here: 100 % 10 == 0).
    num_batches = len(X) // batch_size
    for epoch in range(num_epochs):
        epoch_error = 0.0
        for k in range(num_batches):
            start, end = k * batch_size, (k + 1) * batch_size
            # Accumulate the per-batch losses for this epoch.
            epoch_error += train(model, loss, optimizer,
                                 X[start:end], Y[start:end])
        # print() call: the Python 2 print statement is a SyntaxError on Python 3.
        print(epoch, epoch_error / num_batches)
# Run the training demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# Discussion of this snippet takes place in the comments on the original GitHub Gist.