Skip to content

Instantly share code, notes, and snippets.

@okapies
Created June 12, 2019 03:21
Show Gist options
  • Save okapies/ab7c8f413c3bb46c81b4b4e8ba0e7603 to your computer and use it in GitHub Desktop.
Save okapies/ab7c8f413c3bb46c81b4b4e8ba0e7603 to your computer and use it in GitHub Desktop.
A customized train_mnist example for measuring the performance of trainer extensions
#!/usr/bin/env python
import argparse
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
import numpy as np
# Network definition
class MLP(chainer.Chain):
    """Multi-layer perceptron: two ReLU hidden layers and a linear output.

    Args:
        n_units: Number of units in each of the two hidden layers.
        n_out: Number of units in the output layer.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # ``None`` as the input size leaves it to be inferred by the link.
            self.l1 = L.Linear(None, n_units)  # input -> hidden
            self.l2 = L.Linear(None, n_units)  # hidden -> hidden
            self.l3 = L.Linear(None, n_out)    # hidden -> output

    def forward(self, x):
        """Apply both hidden layers with ReLU, then the output layer.

        No activation is applied to the value returned by ``l3``.
        """
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)
def main():
    """Train an MLP classifier on a one-sample MNIST dataset.

    Parses the command-line options, prints the chosen configuration,
    builds an ``MLP`` wrapped in ``L.Classifier``, and runs a Chainer
    trainer for the requested number of epochs. The training set is cut
    down to the first MNIST training sample and ``LogReport`` is the only
    registered extension (the gist description says the script is meant
    for measuring extension performance).
    """
    device_help = ('Device specifier. Either ChainerX device '
                   'specifier or an integer. If non-negative integer, '
                   'CuPy arrays with specified device id are used. If '
                   'negative integer, NumPy arrays are used')
    # (flags, keyword arguments) pairs, registered in this order.
    option_specs = [
        (('--batchsize', '-b'),
         dict(type=int, default=100,
              help='Number of images in each mini-batch')),
        (('--epoch', '-e'),
         dict(type=int, default=20,
              help='Number of sweeps over the dataset to train')),
        (('--device', '-d'),
         dict(type=str, default='-1', help=device_help)),
        (('--out', '-o'),
         dict(default='result', help='Directory to output the result')),
        (('--unit', '-u'),
         dict(type=int, default=1000, help='Number of units')),
    ]
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    opts = parser.parse_args()

    device = chainer.get_device(opts.device)

    # Echo the effective configuration before training starts.
    banner = [
        ('Device: {}', device),
        ('# unit: {}', opts.unit),
        ('# Minibatch-size: {}', opts.batchsize),
        ('# epoch: {}', opts.epoch),
    ]
    for template, value in banner:
        print(template.format(value))
    print('')

    # Build the network to train. Classifier reports softmax cross entropy
    # loss and accuracy at every iteration.
    predictor = MLP(opts.unit, 10)
    model = L.Classifier(predictor)
    model.to_device(device)
    device.use()

    # Attach an Adam optimizer to the model.
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load MNIST, then keep only the first training sample so the dataset
    # holds exactly one (image, label) pair.
    mnist_train, _ = chainer.datasets.get_mnist()
    first_sample = mnist_train[0]
    single_sample = chainer.datasets.TupleDataset(
        np.stack([first_sample[0]]), np.stack([first_sample[1]]))
    train_iter = chainer.iterators.SerialIterator(
        single_sample, opts.batchsize)

    # Assemble the trainer; it stops after the requested number of epochs.
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=device)
    stop_trigger = (opts.epoch, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out=opts.out)

    # LogReport (with its default trigger) is the only extension registered.
    trainer.extend(extensions.LogReport())

    # Run the training loop.
    trainer.run()
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment