import argparse
import numpy
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


class MLP(chainer.Chain):

    def __init__(self):
        super(MLP, self).__init__(
            fc1=L.Linear(None, 1024),
            fc2=L.Linear(1024, 10),
            # Unused links; presumably included so that ParameterStatistics
            # also picks up LSTM and BatchNormalization parameters.
            dummy1=L.LSTM(64, 64),
            dummy2=L.BatchNormalization(64)
        )

    def __call__(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchsize', type=int, default=100)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--ps-test', type=int, default=0)
    args = parser.parse_args()

    model = L.Classifier(MLP())
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'))

    if args.gpu >= 0:
        import cupy
        xp = cupy
    else:
        xp = numpy

    # Register the parameter statistics extension
    trigger = (1, 'iteration')
    if args.ps_test == 0:  # A. All links contained in the model
        ps = extensions.ParameterStatistics(model, trigger=trigger)
    elif args.ps_test == 1:  # B. A single link, with a prefix
        ps = extensions.ParameterStatistics(model.predictor.fc1,
                                            prefix='myprefix',
                                            trigger=trigger)
    elif args.ps_test == 2:  # C. Specify the statistic functions
        ps = extensions.ParameterStatistics(model,
                                            statistics={'min': xp.min},
                                            trigger=trigger)
    elif args.ps_test == 3:
        # D. Specify the statistic functions with late registration
        ps = extensions.ParameterStatistics(model,
                                            statistics=None,
                                            trigger=trigger)
        ps.register_statistics('mean', xp.mean)
        ps.register_statistics('max', xp.max)
    else:  # E. Custom statistic function, skipping gradients
        ps = extensions.ParameterStatistics(model,
                                            statistics=None,
                                            report_grads=False,
                                            trigger=trigger)
        ps.register_statistics('zeros', lambda x: xp.count_nonzero(x == 0))
    trainer.extend(ps, trigger=trigger)

    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    trainer.run()


if __name__ == '__main__':
    main()
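
The statistics computed by ParameterStatistics are reported through the same mechanism as the loss and accuracy, so they end up in the LogReport output. Below is a minimal sketch of how that log could be inspected after a run; it assumes Chainer's default output directory 'result' and default log file name 'log' (adjust if you override either), and the exact report key names may vary between Chainer versions.

import json

# Load the JSON log written by LogReport. 'result/log' is the default
# location (Trainer(out='result'), LogReport(log_name='log')).
with open('result/log') as f:
    log = json.load(f)

# Each entry is a dict of observation names to values. Besides the usual
# 'main/loss' and 'main/accuracy', ParameterStatistics adds per-parameter
# entries along the lines of 'predictor/fc1/W/data/mean'.
for key in sorted(log[-1]):
    print(key, log[-1][key])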