Created
November 3, 2020 11:53
-
-
Save mberr/15e321ccacdeb517eba3443e080c02a7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import logging
from collections import defaultdict
from operator import itemgetter

import torch
from torch import nn

logger = logging.getLogger(__name__)
# pylint: disable=abstract-method
class ExtendedModule(nn.Module):
    """Extends nn.Module by a few utility methods."""

    @property
    def device(self) -> torch.device:
        """Return the model's device.

        :return:
            The single device on which all parameters and buffers reside.

        :raises ValueError:
            If the module has neither parameters nor buffers, or if they are
            spread across more than one device.
        """
        # Collect the set of distinct devices over all parameters and buffers.
        # Note: tensor.device (not the deprecated tensor.data.device).
        devices = {
            tensor.device
            for tensor in itertools.chain(self.parameters(), self.buffers())
        }
        if len(devices) == 0:
            raise ValueError('Could not infer device, since there are neither parameters nor buffers.')
        elif len(devices) > 1:
            # Include full parameter/buffer listing to help locate the stragglers.
            device_info = dict(
                parameters=dict(self.named_parameters()),
                buffers=dict(self.named_buffers()),
            )
            raise ValueError(f'Ambiguous device! Found: {devices}\n\n{device_info}')
        return next(iter(devices))

    def reset_parameters(self):
        """Reset the model's parameters.

        Calls ``reset_parameters()`` on every sub-module that provides it,
        visiting deeper modules *first* so that a parent's specialized
        initialization takes priority over the defaults of its components.
        Emits a warning if some parameters belong to no module with a
        ``reset_parameters`` method (and hence were likely left uninitialised).
        """
        # Track parameters (by id) that have not been covered by any
        # reset_parameters call yet.
        uninitialized_parameters = set(map(id, self.parameters()))
        # parameter id -> modules containing it, for blaming in debug output.
        parents = defaultdict(list)
        # Recursively visit all sub-modules, recording their nesting depth.
        task_list = []
        for name, module in self.named_modules():
            # skip self to avoid infinite recursion into this very method
            if module is self:
                continue
            # Track parents for blaming
            for p in module.parameters():
                parents[id(p)].append(module)
            # schedule reset_parameters if the module provides it
            if hasattr(module, 'reset_parameters'):
                # depth = number of dots in the dotted module name
                task_list.append((name.count('.'), module))
        # Initialize from bottom to top: deepest modules first, so that
        # specialized initializations override the defaults of their components.
        for module in map(itemgetter(1), sorted(task_list, reverse=True, key=itemgetter(0))):
            module.reset_parameters()
            uninitialized_parameters.difference_update(map(id, module.parameters()))
        # Emit warning if there were parameters which were not initialised by reset_parameters.
        if len(uninitialized_parameters) > 0:
            logger.warning('reset_parameters() not found for all modules containing parameters. %d parameters were likely not initialised.', len(uninitialized_parameters))
            # Additional debug information
            for i, p_id in enumerate(uninitialized_parameters, start=1):
                logger.debug('[%3d] Parents to blame: %s', i, parents.get(p_id))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment