Created
June 12, 2019 03:21
-
-
Save okapies/ab7c8f413c3bb46c81b4b4e8ba0e7603 to your computer and use it in GitHub Desktop.
A customized train_mnist example for measuring the performance of a trainer extension
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
import chainer | |
import chainer.functions as F | |
import chainer.links as L | |
from chainer import training | |
from chainer.training import extensions | |
import numpy as np | |
# Network definition | |
class MLP(chainer.Chain):
    """Multi-layer perceptron: two ReLU-activated hidden layers and an output layer.

    Args:
        n_units: Output size of each of the two hidden linear layers.
        n_out: Output size of the final linear layer.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # The input size of every layer is given as None so that it is
            # inferred rather than fixed here.
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def forward(self, x):
        """Apply l1 and l2 with ReLU, then return the raw output of l3."""
        hidden = F.relu(self.l1(x))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)
def main():
    """Train the MLP classifier with only ``LogReport`` attached to the trainer.

    Command-line options choose the mini-batch size, number of epochs,
    device, output directory and hidden-layer width.  The training set is
    cut down to the first example of the MNIST training split.
    """
    arg_parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    arg_parser.add_argument(
        '--batchsize', '-b', type=int, default=100,
        help='Number of images in each mini-batch')
    arg_parser.add_argument(
        '--epoch', '-e', type=int, default=20,
        help='Number of sweeps over the dataset to train')
    arg_parser.add_argument(
        '--device', '-d', type=str, default='-1',
        help='Device specifier. Either ChainerX device '
        'specifier or an integer. If non-negative integer, '
        'CuPy arrays with specified device id are used. If '
        'negative integer, NumPy arrays are used')
    arg_parser.add_argument(
        '--out', '-o', default='result',
        help='Directory to output the result')
    arg_parser.add_argument(
        '--unit', '-u', type=int, default=1000,
        help='Number of units')
    args = arg_parser.parse_args()

    device = chainer.get_device(args.device)

    # Echo the effective configuration, followed by an empty line.
    for template, value in (
            ('Device: {}', device),
            ('# unit: {}', args.unit),
            ('# Minibatch-size: {}', args.batchsize),
            ('# epoch: {}', args.epoch)):
        print(template.format(value))
    print('')

    # Wrap the 10-output MLP in a Classifier and move it to the chosen device.
    predictor = MLP(args.unit, 10)
    model = L.Classifier(predictor)
    model.to_device(device)
    device.use()

    # Adam with its default hyperparameters drives the updates.
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Keep only the first MNIST training pair; each element is stacked into
    # an array with a leading axis of length one.
    mnist_train, _ = chainer.datasets.get_mnist()
    first_image, first_label = mnist_train[0]
    train = chainer.datasets.TupleDataset(
        np.stack([first_image]), np.stack([first_label]))
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Assemble the trainer; it stops after the requested number of epochs.
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=device)
    stop_trigger = (args.epoch, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # LogReport is the only extension registered on the trainer.
    trainer.extend(extensions.LogReport())

    trainer.run()
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment