Skip to content

Instantly share code, notes, and snippets.

Last active November 6, 2017 09:40
Show Gist options
  • Save corochann/7455ca705e325893bb25851f64df4ea3 to your computer and use it in GitHub Desktop.
Save corochann/7455ca705e325893bb25851f64df4ea3 to your computer and use it in GitHub Desktop.
Early stopping for Chainer using a custom trainer stop trigger.
from chainer import reporter
from chainer.training import util
class EarlyStoppingTrigger(object):
    """Early stopping trigger.

    It observes the value reported under ``key`` and fires when the
    observed value satisfies ``stop_condition``, or when training reaches
    ``max_epoch``. The trigger is meant to be passed as the
    ``stop_trigger`` option of ``chainer.training.Trainer`` so that
    training stops early.

    Args:
        max_epoch (int or float): Max epoch for the training. Even if
            ``stop_condition`` has not been satisfied, the trigger fires
            once ``trainer.updater.epoch_detail`` reaches this value.
        key (str): Key of the observed value that is fed to
            ``stop_condition`` (e.g. ``'validation/main/loss'``).
        stop_condition (callable): Called as
            ``stop_condition(previous_value, current_value)``; a truthy
            result fires the trigger. Default is ``None``, in which case the
            internal ``_stop_condition`` method is used.
        eps (float): Threshold used by the internal ``_stop_condition``.
        trigger: Trigger that decides the comparison interval between the
            previous value and the current value. This must be a tuple in
            the form of ``<int>, 'epoch'`` or ``<int>, 'iteration'``; it is
            converted with ``chainer.training.util.get_trigger``.
    """

    def __init__(self, max_epoch, key, stop_condition=None, eps=0.01,
                 trigger=(1, 'epoch')):
        self.max_epoch = max_epoch
        self.eps = eps
        self._key = key
        self._current_value = None
        self._interval_trigger = util.get_trigger(trigger)
        # ``__call__`` reads ``self._summary`` on its first invocation, so it
        # must exist before then.
        self._init_summary()
        self.stop_condition = stop_condition or self._stop_condition

    def __call__(self, trainer):
        """Decide whether training should stop on this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. Its ``observation`` supplies the
                value under ``key`` and ``updater.epoch_detail`` is compared
                with ``max_epoch``.

        Returns:
            bool: ``True`` if training should stop on this iteration.
        """
        if self.max_epoch <= trainer.updater.epoch_detail:
            print('Reached to max_epoch.')
            return True

        key = self._key
        observation = trainer.observation
        if key in observation:
            self._summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = self._summary.compute_mean()
        if key not in stats:
            # Nothing was reported under ``key`` during this interval (e.g.
            # the evaluator runs less often than this trigger). Keep
            # accumulating instead of failing with KeyError.
            return False
        value = float(stats[key])  # copy to CPU
        # Start a fresh average so the next comparison only uses values
        # observed during the next interval.
        self._init_summary()

        if self._current_value is None:
            # First completed interval: nothing to compare against yet.
            self._current_value = value
            return False

        should_stop = bool(self.stop_condition(self._current_value, value))
        self._current_value = value
        if should_stop:
            print('Invoke EarlyStoppingTrigger...')
        return should_stop

    def _init_summary(self):
        """Replace the running summary with an empty ``DictSummary``."""
        self._summary = reporter.DictSummary()

    def _stop_condition(self, current_value, new_value):
        """Default stop condition for a value that should decrease.

        Returns ``True`` when ``new_value`` is lower than ``current_value`` by
        less than ``self.eps`` (including when it did not decrease at all).
        """
        return current_value - new_value < self.eps
from __future__ import print_function
import argparse
import matplotlib
import matplotlib.pyplot as plt
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from import extensions
from chainer import serializers
from early_stopping_trigger import EarlyStoppingTrigger
class MLP(chainer.Chain):
    """Two fully connected layers with a ReLU between them.

    Args:
        n_units (int): Number of units in the hidden layer.
        n_out (int): Number of output units.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # Only output sizes are given; each layer's input size is
            # inferred when omitted.
            self.l1 = L.Linear(n_units)
            self.l2 = L.Linear(n_out)

    def __call__(self, x):
        """Return the raw output of the second linear layer for ``x``.

        No activation is applied to the final layer.
        """
        hidden = self.l1(x)
        hidden = F.relu(hidden)
        return self.l2(hidden)
def main():
    """Train an MLP on MNIST with early stopping on validation loss.

    Command-line options set the mini-batch size, the maximum number of
    epochs, the GPU id (negative for CPU), the output directory, an optional
    snapshot to resume from, and the hidden-layer width. After training, the
    bare ``MLP`` (without the ``L.Classifier`` wrapper) is saved to
    ``<out>/mlp.model``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))

    model = MLP(args.unit, 10)
    classifier_model = L.Classifier(model)
    if args.gpu >= 0:
        # Make the specified GPU current, then copy the model to it.
        chainer.cuda.get_device_from_id(args.gpu).use()
        classifier_model.to_gpu()

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize, repeat=False, shuffle=False)

    optimizer = chainer.optimizers.MomentumSGD()
    # Bind the optimizer to the model; the updater optimizes (and computes
    # the loss through) the optimizer's target.
    optimizer.setup(classifier_model)
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

    # Stop at args.epoch at the latest, or earlier once
    # 'validation/main/loss' improves by less than eps between comparisons.
    early_stop = EarlyStoppingTrigger(
        args.epoch, key='validation/main/loss', eps=0.01)
    trainer = training.Trainer(updater, stop_trigger=early_stop, out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(
        test_iter, classifier_model, device=args.gpu))
    # Take a snapshot at each epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))
    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())
    # Print selected entries of the log to stdout
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    if extensions.PlotReport.available():
        # Plot graph for loss for each epoch
        trainer.extend(extensions.PlotReport(
            ['main/loss', 'validation/main/loss'],
            x_key='epoch', file_name='loss.png'))
        # Plot graph for accuracy for each epoch
        trainer.extend(extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'],
            x_key='epoch', file_name='accuracy.png'))
    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar(
        training_length=(args.epoch, 'epoch')))

    if args.resume:
        # Resume from a snapshot
        serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
    serializers.save_npz('{}/mlp.model'.format(args.out), model)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment