Skip to content

Instantly share code, notes, and snippets.

@kadeng
Created March 20, 2024 15:54
Show Gist options
  • Save kadeng/2544c6add99c8b55d9000ffe04e69974 to your computer and use it in GitHub Desktop.
Save kadeng/2544c6add99c8b55d9000ffe04e69974 to your computer and use it in GitHub Desktop.
Test method attempt for `WrapperCodeGen.statically_known_int_or_none`
def test_wrapper_codegen_statically_known_int_or_none(self) -> typing.List[CachingAutotuner]:
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.codegen.common import boolean_ops
from torch._inductor.codegen.wrapper import WrapperCodeGen
from torch._inductor.compile_fx import _shape_env_from_inputs
from torch._inductor.debug import DebugContext
from torch._inductor.graph import GraphLowering
from torch._inductor.virtualized import V
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
def fn_1(x):
# no constraint
return (x,)
def fn_2(x):
# constrain in two directions
if x.shape[0] > 5:
return (x,)
if x.shape[0] < 5:
return (x,)
# x.shape[0] == 5 at this point
return (x,)
def fn_3(x):
# equality constraint, which matches example shape
if x.size(0) == 5:
return (x,)
else:
return (x,)
torch._dynamo.reset()
_x = torch.randn([5, 3, 3])
torch._dynamo.maybe_mark_dynamic(_x, 0)
for i, fn in enumerate([fn_1, fn_2, fn_3]):
gm = torch.fx.symbolic_trace(fn) # The graph must be compatible with GraphLowering: No call_method ops
shape_env = _shape_env_from_inputs(_x)
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
FakeTensorProp(gm, mode=fake_mode).propagate(_x)
with V.set_fake_mode(fake_mode):
graph = GraphLowering(
gm,
shape_env=shape_env,
num_static_inputs=0,
)
with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
graph.run(_x)
input_layouts = [
inp.layout
for inp in graph.graph_inputs.values()
if hasattr(inp, "layout")
]
batch_dim = input_layouts[0].size[0]
if i==0:
# testing fn_1
assert (
WrapperCodeGen.statically_known_int_or_none(batch_dim) is None
), "Should not be statically known on first call"
else:
# testing fn_2 or fn_3
assert (
WrapperCodeGen.statically_known_int_or_none(batch_dim) == 5
), "Should be limited to exactly 5 on second and third call due to multiple constraints"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment