Skip to content

Instantly share code, notes, and snippets.

@daskol
Last active January 1, 2022 19:11
Show Gist options
  • Save daskol/aa8e11cc0bf6adb889f4c152ab262ac1 to your computer and use it in GitHub Desktop.
Save daskol/aa8e11cc0bf6adb889f4c152ab262ac1 to your computer and use it in GitHub Desktop.
JAX-like routines for module transformations in PyTorch (see jax.tree_util package).
"""Module tree_util implements routines for inplace transformations on PyTorch
modules as trees. It provides JAX-like API (see jax.tree_util package).
>>> from transformers import RobertaModel
>>> model = RobertaModel.from_pretrained('roberta-base')
>>> converted = map_module(root=model,
...                        func=lambda x, _: convert_linear(x, MyFancyLinear))
"""
import re
import torch as T
from functools import wraps
from re import Pattern
from typing import Callable, Optional
__all__ = ('convert_linear', 'map_module')
def convert_linear(module: T.nn.Module, ctor, **kwargs) -> T.nn.Module:
    """Replace a linear layer with a layer built by `ctor`.

    Modules that are not instances of torch.nn.Linear are returned intact,
    which makes the function usable as a mapper over a whole module tree.

    Args:
        module: Candidate module for conversion.
        ctor: Callable (usually a class) invoked with the keyword arguments
            `in_features`, `out_features`, `bias`, `device` and `dtype`
            taken from `module`, plus anything given in `kwargs`.
        **kwargs: Extra keyword arguments forwarded to `ctor`.

    Returns:
        `module` itself if it is not a linear layer, otherwise the object
        returned by `ctor`. If that object exposes `weight`/`bias` tensors of
        the same shape as the original layer, the original values are copied
        into them so that a converted pretrained model keeps its parameters.
    """
    if not isinstance(module, T.nn.Linear):
        return module

    converted = ctor(in_features=module.in_features,
                     out_features=module.out_features,
                     bias=module.bias is not None,
                     device=module.weight.device,
                     dtype=module.weight.dtype,
                     **kwargs)

    # The freshly constructed layer is randomly initialised. Carry the
    # trained parameters over when the layouts agree; layers with a different
    # parametrisation (other shapes, no such attributes) are left untouched.
    with T.no_grad():
        for name in ('weight', 'bias'):
            source = getattr(module, name, None)
            target = getattr(converted, name, None)
            if (isinstance(source, T.Tensor)
                    and isinstance(target, T.Tensor)
                    and source.shape == target.shape):
                target.copy_(source)
    return converted
def map_module(root: T.nn.Module,
               func: Callable[[T.nn.Module, str], T.nn.Module],
               patt: Optional[str] = None) -> T.nn.Module:
    """Apply a function to every module of a tree whose path matches a pattern.

    The tree is transformed in place. Children are processed before their
    parent and the root itself is processed last. The path of a module is
    made of the child names joined with '/' (e.g. '/encoder/layer/0'); the
    root has the path '/'.

    Args:
        root: Module to transform; matching descendants are replaced on it.
        func: Callable taking a module and its path and returning the module
            that should take its place (possibly the very same object).
        patt: Regular expression tested against each path with `match`
            semantics, i.e. anchored at the beginning of the path only.
            If None or empty, every module is visited.

    Returns:
        The result of `func` for the root if the root path matches the
        pattern, otherwise `root` itself.

    Raises:
        ValueError: If `func` returns something that is not a
            torch.nn.Module.
    """
    @wraps(func)
    def func_safe(*args, **kwargs):
        # Fail early: assigning a non-module as a child would silently turn
        # a registered submodule into something else.
        node = func(*args, **kwargs)
        if not isinstance(node, T.nn.Module):
            raise ValueError('Mapped result should be torch.nn.Module type.')
        return node

    return _map_module(root, func_safe, re.compile(patt or r'.*'), '')


def _map_module(root: T.nn.Module,
                func: Callable[[T.nn.Module, str], T.nn.Module],
                patt: Pattern, path: str) -> T.nn.Module:
    """Recursive worker of map_module; `path` is the path of `root`.

    The top-level call passes an empty `path`, which is reported as '/'.
    """
    # Snapshot the children first: they are reassigned on `root` while
    # looping, so do not iterate over the live generator.
    for name, child in list(root.named_children()):
        node = _map_module(child, func, patt, f'{path}/{name}')
        # Compare by identity: a module class may override __eq__, and what
        # matters here is whether a different object was returned.
        if node is not child:
            setattr(root, name, node)

    root_path = path or '/'
    if patt.match(root_path):
        root = func(root, root_path)
    return root
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment