A PyTorch iterator over module parameters that allows updating the parameters themselves (and not only their data tensors).
def get_params(module, memo=None, pointers=None):
    """Return an iterator over PyTorch module parameters that allows updating
    the parameters themselves (and not only their data).
    ! Side effect: re-points shared parameters to the first yielded instance
    (i.e. you can update shared parameters and keep them shared).
    Yields:
        (Module, string, Parameter): tuple of the parameter's module, name and pointer
    """
    if memo is None:
        memo = set()
        pointers = {}
    for name, p in module._parameters.items():
        if p is None:
            continue  # unused slots (e.g. bias=None) are stored as None
        if p not in memo:
            memo.add(p)
            pointers[p] = (module, name)
            yield module, name, p
        else:
            prev_module, prev_name = pointers[p]
            module._parameters[name] = prev_module._parameters[prev_name]  # update shared parameter pointer
    for child_module in module.children():
        for m, n, p in get_params(child_module, memo, pointers):
            yield m, n, p
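A minimal usage sketch (assuming PyTorch is installed; the weight-tied model below is a hypothetical example, and the function is repeated so the snippet runs on its own): every Parameter of the model is replaced with a double-precision copy, and because the iterator is consumed lazily, the shared-parameter side effect re-points the tied weight to the freshly installed Parameter, so the tie survives the replacement.

```python
import torch
import torch.nn as nn

def get_params(module, memo=None, pointers=None):
    # Same iterator as above, repeated here so the sketch is self-contained.
    if memo is None:
        memo = set()
        pointers = {}
    for name, p in module._parameters.items():
        if p is None:
            continue  # unused slots (e.g. bias=None) are stored as None
        if p not in memo:
            memo.add(p)
            pointers[p] = (module, name)
            yield module, name, p
        else:
            prev_module, prev_name = pointers[p]
            module._parameters[name] = prev_module._parameters[prev_name]
    for child_module in module.children():
        for m, n, p in get_params(child_module, memo, pointers):
            yield m, n, p

# Hypothetical weight-tied model: the decoder reuses the embedding matrix.
embed = nn.Embedding(10, 4)
decoder = nn.Linear(4, 10, bias=False)
decoder.weight = embed.weight  # tie the weights
model = nn.Sequential(embed, decoder)

# Replace each Parameter with a float64 copy. Note: iterate the generator
# lazily (no list()) so the first occurrence is replaced *before* the shared
# occurrence is visited and re-pointed to it.
for m, name, p in get_params(model):
    m._parameters[name] = nn.Parameter(p.data.double())

# The tie is preserved and both modules now see the new float64 Parameter.
print(model[0].weight is model[1].weight)
print(model[1].weight.dtype)
```

Exhausting the iterator first (e.g. with `list(...)`) would break this: the re-pointing side effect would run against the old Parameter objects before any replacement happened, silently untying the weights.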