Early stopping for Chainer, implemented as a stop trigger for the Trainer.
early_stopping_trigger.py:
from chainer import reporter
from chainer.training import util


class EarlyStoppingTrigger(object):
"""Early stopping trigger | |
It observes the value specified by `key`, and invoke a trigger only when | |
observing value satisfies the `stop_condition`. | |
The trigger may be used to `stop_trigger` option of Trainer module for | |
early stopping the training. | |
Args: | |
max_epoch (int or float): Max epoch for the training, even if the value | |
is not reached to the condition specified by `stop_condition`, | |
finish the training if it reaches to `max_epoch` epoch. | |
key (str): Key of value to be observe for `stop_condition`. | |
stop_condition (callable): To check the previous value and current value | |
to decide early stop timing. Default value is `None`, in that case | |
internal `_stop_condition` method is used. | |
eps (float): It is used by the internal `_stop_condition`. | |
trigger: Trigger that decides the comparison interval between previous | |
best value and current value. This must be a tuple in the form of | |
``<int>, 'epoch'`` or ``<int>, 'iteration'`` which is passed to | |
:class:`~chainer.training.triggers.IntervalTrigger`. | |
""" | |
    def __init__(self, max_epoch, key, stop_condition=None, eps=0.01,
                 trigger=(1, 'epoch')):
        self.max_epoch = max_epoch
        self.eps = eps
        self._key = key
        self._current_value = None
        self._interval_trigger = util.get_trigger(trigger)
        self._init_summary()
        self.stop_condition = stop_condition or self._stop_condition
    def __call__(self, trainer):
        """Decides whether training should stop at this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. The ``observation`` of this
                trainer is used to determine if the trigger should fire.

        Returns:
            bool: ``True`` if training should stop at this iteration.
        """
        epoch_detail = trainer.updater.epoch_detail
        if self.max_epoch <= epoch_detail:
            print('Reached max_epoch.')
            return True

        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        value = float(stats[key])  # copy to CPU
        self._init_summary()

        if self._current_value is None:
            # First interval: record the value and continue training.
            self._current_value = value
            return False
        else:
            if self.stop_condition(self._current_value, value):
                # print('Previous value {}, Current value {}'
                #       .format(self._current_value, value))
                print('Invoking EarlyStoppingTrigger...')
                self._current_value = value
                return True
            else:
                self._current_value = value
                return False
    def _init_summary(self):
        self._summary = reporter.DictSummary()

    def _stop_condition(self, current_value, new_value):
        # Stop when the observed value (e.g. validation loss) has improved
        # by less than `eps` since the previous interval.
        return current_value - new_value < self.eps
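
The default `_stop_condition` stops training as soon as the observed value improves by less than `eps` per interval. Any callable taking the previous and current values can be passed instead. A minimal sketch that stops as soon as validation loss increases at all (`stop_on_increase` is a hypothetical helper, not part of the gist):

from early_stopping_trigger import EarlyStoppingTrigger

def stop_on_increase(previous_value, current_value):
    # The trigger calls this as (previous, current); returning True stops.
    return current_value > previous_value

early_stop = EarlyStoppingTrigger(
    max_epoch=20, key='validation/main/loss',
    stop_condition=stop_on_increase)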
An MNIST training script that uses the trigger as the Trainer's stop_trigger option:
from __future__ import print_function
import argparse

import matplotlib
matplotlib.use('Agg')

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from chainer import serializers

from early_stopping_trigger import EarlyStoppingTrigger
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # The input size of each layer is inferred when omitted.
            self.l1 = L.Linear(n_units)  # n_in -> n_units
            self.l2 = L.Linear(n_out)    # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        return self.l2(h1)
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')
    model = MLP(args.unit, 10)
    classifier_model = L.Classifier(model)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()  # Make a specified GPU current
        classifier_model.to_gpu()  # Copy the model to the GPU

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(classifier_model)

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

    # Use the early stopping trigger in place of the usual (epoch, 'epoch')
    # stop trigger: training ends when validation loss stops improving.
    early_stop = EarlyStoppingTrigger(args.epoch, key='validation/main/loss',
                                      eps=0.01)
    trainer = training.Trainer(updater, stop_trigger=early_stop, out=args.out)
    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, classifier_model,
                                        device=args.gpu))

    # Take a snapshot at each epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Print selected entries of the log to stdout
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    if extensions.PlotReport.available():
        # Plot loss and accuracy for each epoch
        trainer.extend(extensions.PlotReport(
            ['main/loss', 'validation/main/loss'],
            x_key='epoch', file_name='loss.png'))
        trainer.extend(extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'],
            x_key='epoch', file_name='accuracy.png'))

    # Print a progress bar to stdout. `training_length` must be given
    # explicitly because the stop trigger is not a plain interval trigger.
    trainer.extend(extensions.ProgressBar(
        training_length=(args.epoch, 'epoch')))
    if args.resume:
        # Resume from a snapshot
        serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()

    serializers.save_npz('{}/mlp.model'.format(args.out), model)


if __name__ == '__main__':
    main()
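
To sanity-check the saved weights, the trained MLP can be reloaded for inference. A minimal sketch, assuming the script above was saved as train_mnist.py (a hypothetical filename) and run with the default options (--unit 50, --out result):

import chainer
from chainer import serializers

from train_mnist import MLP  # hypothetical module name for the script above

model = MLP(50, 10)
serializers.load_npz('result/mlp.model', model)

_, test = chainer.datasets.get_mnist()
x, t = test[0]
with chainer.using_config('train', False):  # disable training-time behavior
    y = model(x[None, :])
print('predicted: {}, actual: {}'.format(int(y.data.argmax()), t))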