Test method attempt: exercises WrapperCodeGen.statically_known_int_or_none by lowering three symbolically traced functions through Inductor's GraphLowering and checking whether the dynamic batch dimension resolves to a concrete integer.
def test_wrapper_codegen_statically_known_int_or_none(self) -> None:
    from torch._inductor.codegen.wrapper import WrapperCodeGen
    from torch._inductor.compile_fx import _shape_env_from_inputs
    from torch._inductor.debug import DebugContext
    from torch._inductor.graph import GraphLowering
    from torch._inductor.virtualized import V
    from torch.fx.passes.fake_tensor_prop import FakeTensorProp

    def fn_1(x):
        # no constraint
        return (x,)

    def fn_2(x):
        # constrain in two directions
        if x.shape[0] > 5:
            return (x,)
        if x.shape[0] < 5:
            return (x,)
        # x.shape[0] == 5 at this point
        return (x,)

    def fn_3(x):
        # equality constraint, which matches the example shape
        if x.size(0) == 5:
            return (x,)
        else:
            return (x,)

    torch._dynamo.reset()
    _x = torch.randn([5, 3, 3])
    torch._dynamo.maybe_mark_dynamic(_x, 0)
    # Lower each traced function through Inductor and check whether the
    # batch dimension of the input layout is statically known.
    for i, fn in enumerate([fn_1, fn_2, fn_3]):
        # The graph must be compatible with GraphLowering: no call_method ops.
        gm = torch.fx.symbolic_trace(fn)
        # _shape_env_from_inputs expects a list of inputs.
        shape_env = _shape_env_from_inputs([_x])
        fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        FakeTensorProp(gm, mode=fake_mode).propagate(_x)
        with V.set_fake_mode(fake_mode):
            graph = GraphLowering(
                gm,
                shape_env=shape_env,
                num_static_inputs=0,
            )
            with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
                graph.run(_x)
                input_layouts = [
                    inp.layout
                    for inp in graph.graph_inputs.values()
                    if hasattr(inp, "layout")
                ]
                batch_dim = input_layouts[0].size[0]
                if i == 0:
                    # testing fn_1
                    assert (
                        WrapperCodeGen.statically_known_int_or_none(batch_dim) is None
                    ), "Should not be statically known for fn_1, which has no constraint"
                else:
                    # testing fn_2 or fn_3
                    assert (
                        WrapperCodeGen.statically_known_int_or_none(batch_dim) == 5
                    ), "Should be pinned to exactly 5 for fn_2 and fn_3 by their shape constraints"