Skip to content

Instantly share code, notes, and snippets.

@tsuchm
Last active February 18, 2020 05:05
Show Gist options
  • Save tsuchm/23dfb7562a7145b7af4825c985d5aed7 to your computer and use it in GitHub Desktop.
Save tsuchm/23dfb7562a7145b7af4825c985d5aed7 to your computer and use it in GitHub Desktop.
import copy
from collections import defaultdict

from chainer import function, training
from chainer import reporter as reporter_module
from chainer.dataset import convert
from chainer.training.extensions.evaluator import _IteratorProgressBar
class PreciseEvaluator(training.extensions.Evaluator):
    """Evaluator that weights every mini batch by its number of examples.

    :class:`chainer.training.extensions.Evaluator` uses
    :func:`chainer.reporter.compute_mean` to calculate the mean
    accuracy and the mean loss over multiple mini batches.
    Because ``compute_mean()`` does not take the variation of the mini
    batch size into account, ``Evaluator`` may return an inaccurate loss
    when a small validation set is used and its size is slightly larger
    than a multiple of the mini batch size.

    For example, consider the case where the number of validation
    instances is 130 and the mini batch size is 128.  In this case the
    iterator produces two mini batches:

    * the 1st mini batch contains 128 examples, and
    * the 2nd mini batch contains 2 examples.

    Averaging the two per-batch values with equal weight makes each
    example of the 2nd mini batch count 64 (= 128 / 2) times as much as
    each example of the 1st mini batch.

    This class avoids the problem by weighting every reported value by
    the size of the mini batch it was computed on.
    """

    def evaluate(self):
        """Evaluate the target and return example-weighted mean values.

        Every value reported while a mini batch is evaluated is
        multiplied by ``len(batch)`` and accumulated; the accumulated
        sums are finally divided by the total number of examples seen.

        Returns:
            dict: Mapping from each reported observation name to its
            mean weighted by mini batch size.  The dict is empty when
            nothing was reported (e.g. the iterator yields no batch).
        """
        iterator = self._iterators['main']
        eval_func = self.eval_func or self._targets['main']

        if self.eval_hook:
            self.eval_hook(self)

        if hasattr(iterator, 'reset'):
            iterator.reset()
            it = iterator
        else:
            # Iterate over a shallow copy so that the state of the
            # original iterator is left untouched.
            it = copy.copy(iterator)

        pbar = None
        if self._progress_bar:
            pbar = _IteratorProgressBar(iterator=it)

        total = 0
        summary = defaultdict(int)

        for batch in it:
            batch_size = len(batch)
            observation = {}
            with reporter_module.report_scope(observation):
                in_arrays = convert._call_converter(
                    self.converter, batch, self.device)
                with function.no_backprop_mode():
                    if isinstance(in_arrays, tuple):
                        eval_func(*in_arrays)
                    elif isinstance(in_arrays, dict):
                        eval_func(**in_arrays)
                    else:
                        eval_func(in_arrays)

            total += batch_size
            # Weight each reported value by the number of examples it
            # was computed on, so that short batches are not over-counted.
            for key, value in observation.items():
                summary[key] += value * batch_size

            if pbar is not None:
                pbar.update()

        if pbar is not None:
            pbar.close()

        return {key: value / total for key, value in summary.items()}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment