Last active
February 18, 2020 05:05
-
-
Save tsuchm/23dfb7562a7145b7af4825c985d5aed7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
from collections import defaultdict

from chainer import function, training
from chainer import reporter as reporter_module
from chainer.dataset import convert
class PreciseEvaluator(training.extensions.Evaluator):
    """Evaluator that weights every mini batch by its number of samples.

    :class:`chainer.training.extensions.Evaluator` uses
    :func:`chainer.reporter.compute_mean` to calculate the mean accuracy
    and the mean loss over the mini batches. Because ``compute_mean()``
    does not take the variation of mini batch size into account,
    ``Evaluator`` may return an inaccurate loss when a small validation
    set is used and its size is slightly larger than a multiple of the
    mini batch size.

    For example, consider the case where the number of validation
    instances is 130 and the mini batch size is 128. The iterator then
    produces two mini batches:

    * the 1st mini batch contains 128 instances, and
    * the 2nd mini batch contains 2 instances.

    Averaging the two per-batch values with equal weight makes every
    instance of the 2nd mini batch count 64 times (128 / 2) more than an
    instance of the 1st mini batch.

    This class avoids the problem by computing a mean weighted by the
    size of each mini batch.
    """

    def evaluate(self):
        """Evaluate the ``'main'`` target over the ``'main'`` iterator.

        Every value reported while processing a mini batch is multiplied
        by ``len(batch)``; the accumulated sums are finally divided by
        the total number of samples seen.

        Returns:
            dict: Maps each reported observation name to its
            sample-weighted mean. Empty when the iterator yields no
            batch.
        """
        iterator = self._iterators['main']
        eval_func = self.eval_func or self._targets['main']

        if self.eval_hook:
            self.eval_hook(self)

        if hasattr(iterator, 'reset'):
            iterator.reset()
            it = iterator
        else:
            # No way to rewind: iterate over a shallow copy instead.
            it = copy.copy(iterator)

        if self._progress_bar:
            # Private helper of Chainer's own Evaluator; imported lazily so
            # that it is only required when the progress bar is enabled.
            from chainer.training.extensions.evaluator import (
                _IteratorProgressBar)
            pbar = _IteratorProgressBar(iterator=it)

        total = 0
        summary = defaultdict(lambda: 0)

        for batch in it:
            batch_size = len(batch)
            observation = {}
            with reporter_module.report_scope(observation):
                in_arrays = convert._call_converter(
                    self.converter, batch, self.device)
                with function.no_backprop_mode():
                    if isinstance(in_arrays, tuple):
                        eval_func(*in_arrays)
                    elif isinstance(in_arrays, dict):
                        eval_func(**in_arrays)
                    else:
                        eval_func(in_arrays)

            # Weight each reported value by the number of samples it covers.
            total += batch_size
            for key, value in observation.items():
                summary[key] += value * batch_size

            if self._progress_bar:
                pbar.update()

        if self._progress_bar:
            pbar.close()

        # ``summary`` is empty when no batch was processed, so the division
        # below is never evaluated with ``total == 0`` in that case.
        return {key: value / total for key, value in summary.items()}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment