Skip to content

Instantly share code, notes, and snippets.

@mmitou
Created November 6, 2019 07:11
Show Gist options
  • Select an option

  • Save mmitou/59aa721f277dec705e51028da68eff1a to your computer and use it in GitHub Desktop.

Select an option

Save mmitou/59aa721f277dec705e51028da68eff1a to your computer and use it in GitHub Desktop.
シンプルな線型回帰をpytorchで書いてみた。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Synthetic training data scattered around the ground-truth line f(x) = 2x + 3.
dataset_size = 100  # number of (x, y) samples
lr = 0.01  # learning rate

# Inputs ~ N(0, 1); targets are 2x + 3 plus N(0, 1) noise. The inputs are
# drawn before the noise, so the global NumPy RNG is consumed in that order.
x_train = np.random.randn(dataset_size)
y_train = 2 * x_train + 3 + np.random.normal(0, 1, dataset_size)

# Reshape both to column vectors of shape (dataset_size, 1): one feature per sample.
x_train, y_train = x_train.reshape(-1, 1), y_train.reshape(-1, 1)

# To inspect the raw data before training, uncomment:
# plt.scatter(x_train, y_train)
# plt.show()
class Net(nn.Module):
    """Single-feature linear model ``y = w * x + b``.

    Wraps one ``nn.Linear(1, 1)`` layer stored as ``fc1``; that attribute
    name determines the parameter keys ``fc1.weight`` and ``fc1.bias``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)

    def forward(self, x):
        """Apply the linear layer to *x* (its last dimension must be 1)."""
        return self.fc1(x)


# The model instance that the script trains.
model = Net()
class MyOptimizer(optim.Optimizer):
    """Plain gradient descent: ``p <- p - lr * p.grad`` for every parameter.

    Args:
        params: Iterable of tensors, or of dicts defining parameter groups,
            as accepted by ``torch.optim.Optimizer``.
        lr: Step size (default 0.01). Must be non-negative.

    Raises:
        ValueError: If ``lr`` is negative.
    """

    def __init__(self, params, lr=0.01):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state (e.g. on unpickling or deepcopy)."""
        super().__setstate__(state)
        # ``param_groups`` is the attribute Optimizer defines; iterating a
        # non-existent ``param_group`` raised AttributeError on restore.
        for group in self.param_groups:
            # step() does not read this key; the default is retained for
            # backward compatibility of restored parameter groups.
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one gradient-descent update on every parameter group.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, per the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # The closure calls backward(), so grad must be enabled for it
            # even though the parameter update itself runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            step_size = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue  # parameter received no gradient
                # Keyword ``alpha`` form; the positional add_(Number, Tensor)
                # overload is deprecated.
                p.add_(p.grad, alpha=-step_size)
        return loss
class MyLoss(nn.Module):
    """Mean squared error between predictions and targets.

    Squares the element-wise difference and averages over all elements.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        residual = inputs - targets
        return residual.pow(2).mean()
# Full-batch gradient descent over the whole dataset.
num_epochs = 1000
criterion = MyLoss()
# Use the configured learning rate rather than relying on MyOptimizer's default.
optimizer = MyOptimizer(model.parameters(), lr=lr)

# The training data never changes, so convert it once instead of on every
# iteration. NumPy produces float64; the model's parameters are float32.
inputs = torch.from_numpy(x_train).float()
targets = torch.from_numpy(y_train).float()

for _ in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    print(loss.item())
    loss.backward()
    optimizer.step()
# Show the learned parameters (weight, then bias).
for param in model.parameters():
    print(param)

# Inference only: no_grad avoids building an autograd graph just to detach it.
with torch.no_grad():
    predicted = model(torch.from_numpy(x_train).float()).numpy()

plt.plot(x_train, y_train, 'ro', label='original')  # noisy training samples
plt.plot(x_train, predicted, label='fitted')  # model predictions
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment