Last active
March 16, 2022 18:10
-
-
Save jamesr66a/990a23073fd037adaa8f568569506e9e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.fx | |
class TestModuleWithShapeControlFlow(torch.nn.Module):
    """Toy module whose ``forward`` branches on tensor rank and shape.

    Ordinary symbolic tracing cannot get through it: ``x.dim()`` would be a
    proxy rather than an int, and comparing ``y.shape[0]`` against 3 would
    yield a proxy that then has to be used as a Python ``bool``.
    """

    def forward(self, x):
        rank = x.dim()
        if rank not in (3, 4):
            raise RuntimeError('Unsupported rank for x')
        # Same index as x[0, :, :] (rank 3) or x[0, :, :, :] (rank 4).
        y = x[(0,) + (slice(None),) * (rank - 1)]

        leading = y.shape[0]
        if leading == 3:
            return y + leading
        return torch.neg(y)


# Module instance that is traced below and used as the eager reference.
tmwscf = TestModuleWithShapeControlFlow()
import warnings | |
from typing import NamedTuple, Iterable | |
# Solution: meta-tensor tracing. We define a MetaTensorTracer and a MetaTensorProxy
# that take tensors on the `meta` device describing the inputs, and carry forward
# that metadata for intermediate values in the program. This allows control flow based
# on metadata expressions like `dim()` or `shape` to work during tracing, producing
# a trace that is specialized to those values.
class MetaTensorProxy(torch.fx.Proxy):
    """An FX ``Proxy`` that may carry a meta-device tensor describing its value.

    Once a meta tensor is installed (see ``install_meta_tensor``), the metadata
    queries ``dim()``, ``shape`` and ``dtype`` are answered concretely from it,
    so traced code can branch on them. Without one, each query falls back to
    ordinary symbolic behaviour: it is recorded in the graph and a proxy is
    returned.
    """

    def __init__(self, node, tracer):
        # A real instance attribute, so reading it never falls through to
        # Proxy.__getattr__ (which would hand back a symbolic Attribute proxy).
        self._meta_tensor = None
        super().__init__(node, tracer)

    def install_meta_tensor(self, meta_tensor):
        """Attach ``meta_tensor`` (a tensor on the ``meta`` device) as this value's metadata."""
        assert isinstance(meta_tensor, torch.Tensor) and meta_tensor.device == torch.device('meta')
        self._meta_tensor = meta_tensor

    def dim(self):
        """Return the rank as an int when known; otherwise record a ``dim()`` call."""
        if self._meta_tensor is None:
            # Proxy.__getattr__ returns an *uncalled* Attribute proxy; invoke it
            # so the graph records ``x.dim()`` as a method call, exactly like
            # plain symbolic tracing, instead of a bare attribute fetch.
            return super().__getattr__('dim')()
        return self._meta_tensor.dim()

    @property
    def shape(self):
        """Concrete ``torch.Size`` when known; otherwise a symbolic ``shape`` attribute."""
        if self._meta_tensor is None:
            return super().__getattr__('shape')
        return self._meta_tensor.shape

    @property
    def dtype(self):
        """Concrete ``torch.dtype`` when known; otherwise a symbolic ``dtype`` attribute."""
        if self._meta_tensor is None:
            return super().__getattr__('dtype')
        return self._meta_tensor.dtype
class MetaTensorTracer(torch.fx.Tracer):
    """FX tracer that propagates meta-device tensors through the traced program.

    ``trace`` pairs each placeholder, in order, with a meta tensor from
    ``concrete_metas``. Every subsequently recorded op is re-run on the meta
    tensors of its inputs, and a tensor result is attached to the op's
    ``MetaTensorProxy``. Traced code can therefore branch on the metadata the
    proxy exposes, producing a graph specialized to those values.

    Propagation is best-effort. When metadata cannot be computed for a value,
    a warning is issued and that value's proxy carries no meta tensor.
    """

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        """Record a node as ``torch.fx.Tracer.create_proxy`` does, then try to attach its meta tensor.

        Returns:
            A ``MetaTensorProxy`` for the new node.

        Raises:
            RuntimeError: if a custom ``proxy_factory_fn`` is supplied.
        """
        if proxy_factory_fn is not None:
            raise RuntimeError("Don't support custom proxy factory function for MetaTensorTracer")
        proxy = super().create_proxy(
            kind, target, args, kwargs, name, type_expr, lambda n: MetaTensorProxy(n, self))

        try:
            if kind == 'placeholder':
                # Graph inputs take their metadata, in order, from ``concrete_metas``.
                meta = next(self.concrete_meta_iter, None)
                if meta is None:
                    raise RuntimeError('no meta tensor was supplied for this placeholder')
                proxy.install_meta_tensor(meta)
            elif kind == 'get_attr':
                # Attribute fetched from the traced root: mirror the real
                # tensor's shape/dtype with an empty tensor on the meta device.
                value = self._fetch_root_attr(target)
                if isinstance(value, torch.Tensor):
                    proxy.install_meta_tensor(torch.empty_like(value, device='meta'))
            else:
                meta_out = self._run_on_meta(kind, target, args, kwargs)
                if isinstance(meta_out, torch.Tensor):
                    proxy.install_meta_tensor(meta_out)
        except Exception as e:
            # Best effort: keep tracing; this value simply has no metadata.
            warnings.warn(f"Could not compute shape for value {proxy}: {e}")
        return proxy

    def _fetch_root_attr(self, qualname):
        """Resolve a dotted ``get_attr`` target against ``self.root``."""
        value = self.root
        for atom in qualname.split('.'):
            value = getattr(value, atom)
        return value

    @staticmethod
    def _run_on_meta(kind, target, args, kwargs):
        """Evaluate a call_function / call_method node on the meta tensors of its inputs."""

        def to_meta(a):
            # Never pass a live proxy into the op: its magic methods would
            # record extra nodes and re-enter create_proxy. So every proxy input
            # must have a known meta tensor.
            if isinstance(a, torch.fx.Proxy):
                meta = a._meta_tensor if isinstance(a, MetaTensorProxy) else None
                if meta is None:
                    raise RuntimeError('an input value has no meta tensor')
                return meta
            return a

        if kind == 'call_function':
            fn = target
        elif kind == 'call_method':
            fn = getattr(torch.Tensor, target)
        elif kind == 'call_module':
            raise RuntimeError('Not yet implemented')
        else:
            raise RuntimeError(f'Unknown node kind {kind}')

        meta_args = torch.fx.node.map_aggregate(args if args else (), to_meta)
        meta_kwargs = torch.fx.node.map_aggregate(kwargs if kwargs else {}, to_meta)
        return fn(*meta_args, **meta_kwargs)

    def trace(self, root, concrete_args=None, concrete_metas=None):
        """Trace ``root``, pairing its placeholders in order with ``concrete_metas``.

        Args:
            root: module or callable to trace; forwarded to ``torch.fx.Tracer.trace``.
            concrete_args: forwarded unchanged to ``torch.fx.Tracer.trace``.
            concrete_metas: iterable of meta-device tensors, one per placeholder
                in placeholder order. ``None`` supplies no metadata: each
                placeholder then warns and carries no meta tensor.

        Returns:
            The graph produced by ``torch.fx.Tracer.trace``.
        """
        self.concrete_metas = concrete_metas
        # iter(None) would raise TypeError for the default argument.
        self.concrete_meta_iter = iter(concrete_metas if concrete_metas is not None else ())
        return super().trace(root, concrete_args)
# Trace with a (3, 4, 5) meta-device input. Shape-dependent control flow in
# ``forward`` is resolved during tracing, so the graph is specialized to it.
mtt = MetaTensorTracer()
traced = mtt.trace(tmwscf, concrete_metas=[torch.empty(3, 4, 5, device='meta')])
gm = torch.fx.GraphModule(mtt.root, traced)
print(gm.code)

# The specialized module must match eager execution on a real input of the
# shape it was specialized to. ``assert_close`` replaces the deprecated
# ``torch.testing.assert_allclose``.
x = torch.randn(3, 4, 5)
torch.testing.assert_close(gm(x), tmwscf(x))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment