@mberr
Created November 3, 2020 11:53
import itertools
import logging
from collections import defaultdict
from operator import itemgetter

import torch
from torch import nn

logger = logging.getLogger(__name__)

# pylint: disable=abstract-method
class ExtendedModule(nn.Module):
    """Extends nn.Module by a few utility methods."""

    @property
    def device(self) -> torch.device:
        """Return the model's device."""
        devices = {
            tensor.data.device
            for tensor in itertools.chain(self.parameters(), self.buffers())
        }
        if len(devices) == 0:
            raise ValueError('Could not infer device, since there are neither parameters nor buffers.')
        elif len(devices) > 1:
            device_info = dict(
                parameters=dict(self.named_parameters()),
                buffers=dict(self.named_buffers()),
            )
            raise ValueError(f'Ambiguous device! Found: {devices}\n\n{device_info}')
        return next(iter(devices))
    def reset_parameters(self):
        """Reset the model's parameters."""
        # Make sure that all modules with parameters do have a reset_parameters method.
        uninitialized_parameters = set(map(id, self.parameters()))
        parents = defaultdict(list)
        # Recursively visit all sub-modules
        task_list = []
        for name, module in self.named_modules():
            # skip self
            if module is self:
                continue
            # Track parents for blaming
            for p in module.parameters():
                parents[id(p)].append(module)
            # call reset_parameters if possible
            if hasattr(module, 'reset_parameters'):
                task_list.append((name.count('.'), module))
        # Initialize from bottom to top.
        # This ensures that specialized initializations take priority over the default ones of their components.
        for module in map(itemgetter(1), sorted(task_list, reverse=True, key=itemgetter(0))):
            module.reset_parameters()
            uninitialized_parameters.difference_update(map(id, module.parameters()))
        # Emit a warning if there were parameters which were not initialized by reset_parameters.
        if len(uninitialized_parameters) > 0:
            logger.warning(
                'reset_parameters() not found for all modules containing parameters. '
                '%d parameters were likely not initialized.',
                len(uninitialized_parameters),
            )
            # Additional debug information
            for i, p_id in enumerate(uninitialized_parameters, start=1):
                logger.debug('[%3d] Parents to blame: %s', i, parents.get(p_id))
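A minimal usage sketch (the subclass and layer names below are hypothetical, added here only for illustration): a model that derives from ExtendedModule gets the .device property and a recursive reset_parameters() that calls into every sub-module providing its own reset_parameters, deepest modules first.

class TwoLayerNet(ExtendedModule):
    """Toy model used only to demonstrate the utilities above."""

    def __init__(self, in_features: int = 8, out_features: int = 2):
        super().__init__()
        self.hidden = nn.Linear(in_features, 16)
        self.output = nn.Linear(16, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output(self.hidden(x).relu())


model = TwoLayerNet()
print(model.device)       # e.g. device(type='cpu'); raises ValueError if parameters span several devices
model.reset_parameters()  # re-initializes both nn.Linear sub-modules via their own reset_parameters()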