Chainer - Complete Logistic Regression Model
# The goal of this Gist is to implement a simple logistic regression model
# that uses most of Chainer's core functionality.
# @author K.M.J. Jacobs
# @date 2018-02-24
# @website https://www.data-blogger.com
import chainer
from chainer import reporter as reporter_module
from chainer import iterators
from chainer import training
from chainer.training import extensions
from chainer.datasets import split_dataset
from chainer import optimizers
import chainer.initializers
import chainer.functions as F
from chainer import Chain
import numpy as np


class LogisticRegressionModel(Chain):

    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        with self.init_scope():
            self.w = chainer.Parameter(initializer=chainer.initializers.Normal())
            self.w.initialize([3, 1])

    def __call__(self, x, t):
        # Call the loss function
        return self.loss(x, t)

    def predict(self, x):
        # Predict given an input (a, b, 1)
        z = F.matmul(x, self.w)
        return 1. / (1. + F.exp(-z))

    def loss(self, x, t):
        # Compute the loss for a given input (a, b, 1) and target
        y = self.predict(x)
        loss = -t * F.log(y) - (1 - t) * F.log(1 - y)
        reporter_module.report({'loss': loss.data[0, 0]}, self)
        reporter_module.report({'w': self.w[0, 0]}, self)
        return loss
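
# The loss above is the standard binary cross-entropy for a sigmoid output:
#   L(y, t) = -t * log(y) - (1 - t) * log(1 - y)
# For example, a prediction y = 0.9 for target t = 1 gives L = -log(0.9) ~= 0.105,
# whereas the same prediction for t = 0 gives L = -log(0.1) ~= 2.303.
# (predict implements the logistic sigmoid 1 / (1 + exp(-z)) by hand; Chainer
# also provides this as F.sigmoid.)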


def converter(minibatch, device=None):
    # Split each minibatch of rows [a, b, 1, target] into inputs and targets
    inputs = []
    targets = []
    for item in minibatch:
        inputs.append(item[:3])
        targets.append(item[3])
    inputs = np.matrix(inputs)
    targets = np.array(targets)
    return inputs, targets
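
# Illustration (hypothetical values, not part of the original Gist): for a
# minibatch containing the single row [0.8, 0.2, 1.0, 1.0], the converter
# returns (matrix([[0.8, 0.2, 1.0]]), array([1.0])).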

# Set the seed for reproducibility
np.random.seed(0)

# The dataset consists of samples (a, b, 1) and the target is a function f such that f(a, b, 1) = a > b
# So for example: f(0.5, 0.6, 1) = 0. (False) and f(0.8, 0.2, 1) = 1. (True) since 0.8 > 0.2
# The constant 1 serves as a bias input so the model can learn a constant offset
N = 10000
data = np.random.random((N, 4))
data[:, 2] = 1.
data[:, 3] = data[:, 0] > data[:, 1]
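# Each row now has the form [a, b, 1., label], where label is 1. when a > b
# and 0. otherwise, e.g. (illustrative values) [0.55, 0.72, 1., 0.].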

# Split the data into a train and a test set such that there are 10 examples in the test set
data_test, data_train = split_dataset(data, 10)
# SerialIterator arguments: dataset, batch size (1), repeat (False) and shuffle (False)
train_iter = iterators.SerialIterator(data_train, 1, False, False)
test_iter = iterators.SerialIterator(data_test, 1, False, False)

# Set up the model
model = LogisticRegressionModel()

# Create the optimizer for the model
optimizer = optimizers.SGD()
optimizer.use_cleargrads(True)
optimizer.setup(model)
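# Note: optimizers.SGD() uses Chainer's default learning rate (lr=0.01), and
# use_cleargrads(True) makes the optimizer clear gradients between updates
# instead of filling them with zeros.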

# Set up the training loop (using the Evaluator, LogReport and PrintReport extensions) with the following properties:
# - Run for 10,000 iterations
# - Evaluate every 1,000 iterations
# - Write logs every 1,000 iterations
# - Print the losses and the epoch, iteration and elapsed_time
trainer = training.Trainer(training.StandardUpdater(train_iter, optimizer, converter), (10001, 'iteration'), out='result')
trainer.extend(extensions.Evaluator(test_iter, model, converter), trigger=(1000, 'iteration'))
trainer.extend(extensions.LogReport(trigger=(1000, 'iteration')))
trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'main/loss', 'validation/main/loss', 'main/w', 'validation/main/w', 'elapsed_time']))
trainer.run()
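
# After training, the model can be queried directly, e.g. (illustrative input,
# not part of the original Gist):
#   model.predict(np.matrix([[0.8, 0.2, 1.0]]))
# which should give a probability well above 0.5, since 0.8 > 0.2.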