Last active
January 1, 2022 19:11
-
-
Save daskol/aa8e11cc0bf6adb889f4c152ab262ac1 to your computer and use it in GitHub Desktop.
JAX-like routines for module transformations in PyTorch (see jax.tree_util package).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Module tree_util implements routines for in-place transformations on PyTorch | |
| modules as trees. It provides a JAX-like API (see the jax.tree_util package). | |
| >>> from transformers import RobertaModel | |
| >>> model = RobertaModel.from_pretrained('roberta-base') | |
| >>> converted = map_module(root=model, | |
| ...                        func=lambda x, _: convert_linear(x, MyFancyLinear)) | |
| """ | |
| import re | |
| import torch as T | |
| from functools import wraps | |
| from re import Pattern | |
| from typing import Callable, Optional | |
| __all__ = ('convert_linear', 'map_module') | |  # names exported by a star-import of this module
def convert_linear(module: T.nn.Linear, ctor, **kwargs) -> T.nn.Module:
    """Build a replacement for a linear layer with the given constructor.

    If `module` is a `torch.nn.Linear`, `ctor` is called with keyword
    arguments describing that layer: `in_features`, `out_features`, `bias`
    (whether the source layer has a bias), and the `device` and `dtype` of
    the source weight, followed by any extra `kwargs`. The result of that
    call is returned. Parameter values are not transferred by this function.

    Any module that is not a `torch.nn.Linear` is returned as is.
    """
    if isinstance(module, T.nn.Linear):
        weight = module.weight
        has_bias = module.bias is not None
        return ctor(in_features=module.in_features,
                    out_features=module.out_features,
                    bias=has_bias,
                    device=weight.device,
                    dtype=weight.dtype,
                    **kwargs)
    return module
def map_module(root: T.nn.Module,
               func: Callable[[T.nn.Module, str], T.nn.Module],
               patt: Optional[str] = None) -> T.nn.Module:
    """Apply a function to every module of a tree whose path matches a pattern.

    The tree rooted at `root` is traversed children first. Each visited module
    has a path: the root is addressed as '/', and a descendant is addressed as
    '/' followed by the '/'-joined attribute names leading to it (for example
    '/encoder/layer/0'). When the path matches `patt`, `func(module, path)` is
    called and its result replaces the module in its parent (in place).

    Args:
        root: Root of the module tree to transform.
        func: Callable receiving a module and its path and returning the
            module to put in its place (possibly the same object).
        patt: Regular expression tested against each path with `re.match`
            semantics, i.e. anchored at the start of the path only. `None`
            (or an empty string) matches every path.

    Returns:
        The module produced for the root: `root` itself unless `func`
        returned a different module for it.

    Raises:
        ValueError: If `func` returns something that is not a
            `torch.nn.Module`.
    """
    @wraps(func)
    def func_safe(*args, **kwargs):
        # Validate every mapped result so that a faulty `func` fails loudly
        # instead of planting a non-module object into the tree.
        node = func(*args, **kwargs)
        if not isinstance(node, T.nn.Module):
            raise ValueError('Mapped result should be torch.nn.Module type.')
        return node

    # The traversal starts with an empty path; the helper reports it as '/'.
    return _map_module(root, func_safe, re.compile(patt or r'.*'), '')
def _map_module(root: T.nn.Module,
                func: Callable[[T.nn.Module, str], T.nn.Module],
                patt: Pattern, path: str) -> T.nn.Module:
    """Recursively transform the subtree rooted at `root`, children first.

    Args:
        root: Module whose subtree is processed.
        func: Callable receiving a module and its path and returning the
            module that should take its place.
        patt: Compiled pattern; `func` is applied only to modules whose path
            matches it (`Pattern.match`, anchored at the start).
        path: Path of `root` within the whole tree; the empty string denotes
            the tree root, which is reported to `patt` and `func` as '/'.

    Returns:
        `root` itself, or whatever `func` returned for it when its path
        matches `patt`.
    """
    # Snapshot the children: the loop body reassigns attributes of `root`,
    # so do not iterate over its live children view while doing that.
    for name, child in list(root.named_children()):
        node = _map_module(child, func, patt, f'{path}/{name}')
        # Compare by identity, not equality: a module class may override
        # `__eq__`, and a distinct-but-equal replacement must still be stored.
        if node is not child:
            setattr(root, name, node)

    # The tree root has an empty path and is addressed as '/'.
    label = path or '/'
    if patt.match(label):
        root = func(root, label)
    return root
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment