Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Last active March 16, 2022 18:10
Show Gist options
  • Save jamesr66a/990a23073fd037adaa8f568569506e9e to your computer and use it in GitHub Desktop.
Save jamesr66a/990a23073fd037adaa8f568569506e9e to your computer and use it in GitHub Desktop.
import torch
import torch.fx
class TestModuleWithShapeControlFlow(torch.nn.Module):
    """Example module whose Python control flow depends on the input's rank and shape."""

    def forward(self, x):
        """Drop the leading dimension of a rank-3 or rank-4 `x`, then branch on the result's shape.

        Raises:
            RuntimeError: if `x` is neither rank 3 nor rank 4.
        """
        # Plain symbolic tracing cannot evaluate `x.dim()` to a concrete
        # number, so these rank checks break there.
        if x.dim() == 3:
            first = x[0, :, :]
        elif x.dim() == 4:
            first = x[0, :, :, :]
        else:
            raise RuntimeError('Unsupported rank for x')

        # Likewise, `first.shape[0] == 3` would be a proxy under plain symbolic
        # tracing and fail as soon as it is used as a bool.
        if first.shape[0] == 3:
            return first + first.shape[0]
        return torch.neg(first)
# Instance of the example module; it is the root passed to the tracer further
# down and the reference the traced module's output is compared against.
tmwscf = TestModuleWithShapeControlFlow()
import warnings
from typing import NamedTuple, Iterable
# Solution: meta-tensor tracing. We define a MetaTensorTracer and a MetaTensorProxy
# that take meta tensors (`torch.Tensor` instances on the `meta` device) describing
# the inputs, and carry that metadata forward for intermediate values in the program.
# This allows control flow based on expressions like `dim`, `shape`, or `dtype` to work
# during tracing, producing a trace that is specialized to those values.
class MetaTensorProxy(torch.fx.Proxy):
    """Proxy that can carry a meta-device tensor describing the value it stands for.

    Once a meta tensor is installed, `dim()` and `shape` answer with concrete
    values read from that tensor, so Python control flow on them works during
    tracing. Without one, both fall back to recording the access symbolically.
    """

    def __init__(self, node, tracer):
        # Assigned before delegating so the attribute always exists as a plain
        # instance attribute; normal lookup then finds it and never falls
        # through to Proxy.__getattr__, which would fabricate a graph attribute.
        self._meta_tensor = None
        super().__init__(node, tracer)

    def install_meta_tensor(self, meta_tensor):
        """Attach `meta_tensor` (a `torch.Tensor` on the `meta` device) to this proxy."""
        assert isinstance(meta_tensor, torch.Tensor) and meta_tensor.device == torch.device('meta')
        self._meta_tensor = meta_tensor

    def dim(self):
        """Return the concrete rank if metadata is installed, else a proxy for `dim()`."""
        if self._meta_tensor is None:
            # Proxy.__getattr__ only yields the lazy attribute proxy; it has to
            # be *called* so a `call_method` node for `dim` is recorded and a
            # proxy for the call's result (not the bound method) is returned.
            return super().__getattr__('dim')()
        return self._meta_tensor.dim()

    @property
    def shape(self):
        """Return the concrete shape if metadata is installed, else a proxy for `.shape`."""
        if self._meta_tensor is None:
            return super().__getattr__('shape')
        return self._meta_tensor.shape
class MetaTensorTracer(torch.fx.Tracer):
    """Tracer that runs every recorded op on meta tensors alongside the symbolic trace.

    Each proxy it creates is a `MetaTensorProxy`. Placeholders receive the
    caller-supplied meta tensors (see `trace`); for other nodes the target is
    executed on the meta tensors of its inputs and a tensor result is attached
    to the new proxy.
    """

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        """Create a `MetaTensorProxy` for the node and try to attach its meta tensor.

        Metadata propagation is best-effort: any failure while computing it is
        reported with `warnings.warn` and the proxy is returned without metadata.

        Raises:
            RuntimeError: if a custom `proxy_factory_fn` is supplied.
        """
        if proxy_factory_fn is not None:
            raise RuntimeError("Don't support custom proxy factory function for MetaTensorTracer")

        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, lambda n: MetaTensorProxy(n, self))

        def extract_meta(a):
            # Swap each proxy for the meta tensor it carries (None when it has
            # none); every other argument is passed through untouched. Read
            # `_meta_tensor` directly: probing a misspelled name with getattr
            # would go through Proxy.__getattr__ and always look "present".
            if isinstance(a, MetaTensorProxy):
                return a._meta_tensor
            return a

        try:
            meta_args = torch.fx.node.map_aggregate(args if args else (), extract_meta)
            meta_kwargs = torch.fx.node.map_aggregate(kwargs if kwargs else {}, extract_meta)

            if kind == 'call_function':
                meta_target = target
            elif kind == 'call_method':
                assert isinstance(args[0], torch.fx.Proxy)
                # Resolve the method by name on torch.Tensor; the receiver's
                # meta tensor is already first in `meta_args`.
                meta_target = getattr(torch.Tensor, target)
            elif kind == 'call_module':
                raise RuntimeError('Not yet implemented')
            elif kind == 'placeholder':
                # Inputs take the next caller-provided meta tensor, in order.
                proxy.install_meta_tensor(next(self.concrete_meta_iter))
                return proxy
            else:
                assert False, f'Unknown target {kind}'

            meta_out = meta_target(*meta_args, **meta_kwargs)
            if isinstance(meta_out, torch.Tensor):
                proxy.install_meta_tensor(meta_out)
        except Exception as e:
            # Deliberately broad: losing metadata for one value should not abort tracing.
            warnings.warn(f"Could not compute shape for value {proxy}: {e}")

        return proxy

    def trace(self, root, concrete_args=None, concrete_metas=None):
        """Trace `root`, handing out `concrete_metas` to placeholders in creation order.

        Args:
            root: forwarded to `torch.fx.Tracer.trace`.
            concrete_args: forwarded to `torch.fx.Tracer.trace`.
            concrete_metas: iterable of meta-device tensors, one per
                placeholder. `None` means no metadata is supplied; placeholders
                are then left without a meta tensor (with a warning) instead of
                failing on `iter(None)`.
        """
        self.concrete_metas = concrete_metas
        self.concrete_meta_iter = iter(() if concrete_metas is None else concrete_metas)
        return super().trace(root, concrete_args)
# Trace the example module, describing its single input with a 3x4x5 tensor on
# the `meta` device.
mtt = MetaTensorTracer()
traced = mtt.trace(tmwscf, concrete_metas=[torch.empty(3, 4, 5, device='meta')])
# Wrap the trace result together with the tracer's root in a GraphModule so it can be called.
gm = torch.fx.GraphModule(mtt.root, traced)
# Print the Python source generated for the traced graph.
print(gm.code)
# Compare the traced module against the original module on a random 3x4x5 input.
x = torch.randn(3, 4, 5)
torch.testing.assert_allclose(gm(x), tmwscf(x))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment