import torch
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
try_import_cutlass()  # comment this out if cutlass is already installed
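# try_import_cutlass() makes the CUTLASS Python package bundled with PyTorch
# importable (roughly, it adds the in-tree cutlass library to sys.path); this
# description is an assumption based on torch._inductor internals.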
import cutlass
import ast
import textwrap
from cutlass import Tensor as FakeTensor
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass.backend.evt.frontend import PythonASTFrontend
# This controls whether the C++ GEMM declaration will be printed at each step.
# Set to `False` to omit this information.
print_module = False
# The Epilogue Visitor feature currently only works on SM80 and SM90
from cutlass.backend.utils.device import device_cc
if device_cc() not in [80, 90]:
    import sys
    sys.exit()
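# device_cc() reports the CUDA compute capability of the active device as an
# integer (e.g. 80 for A100, 90 for H100), so anything other than SM80/SM90
# exits early.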
# GEMM problem sizes and dtypes.
m = 16384
n = m
k = 512
type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16
torch.manual_seed(2023)
# Input values are drawn uniformly from [scope_min, scope_max] and rounded up.
scope_min = -4
scope_max = 4
# tensor_A is a symbolic (shape/dtype-only) CUTLASS Tensor, while B, C, and D
# are real CUDA tensors.
tensor_A = FakeTensor(shape=(m, k), element=type_A, layout_tag=cutlass.LayoutType.RowMajor)
tensor_B = torch.ceil(torch.empty(size=(k, n), dtype=type_B, device="cuda").uniform_(scope_min, scope_max))
tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
tensor_D = torch.zeros_like(tensor_C)
plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)
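# cutlass.op.Gemm builds a GEMM "plan" that can later be compiled and launched;
# element_accumulator=torch.float32 requests fp32 accumulation for the fp16
# operands.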
def trace(fn_src, example_tensors, **kwargs):
    # Minimal frontend: parse the epilogue source with `ast` and trace it into
    # an EVT (Epilogue Visitor Tree) functor via PythonASTFrontend.
    class EpilogueFunctor(PythonASTFrontend):
        def __init__(self, **kwargs):
            self.source = textwrap.dedent(fn_src)
            super().__init__(**kwargs)

        def parse(self, example_inputs):
            self.example_inputs = example_inputs
            self.ast = ast.parse(self.source)
            self.visit(self.ast)

    epilogue_functor = EpilogueFunctor(cc=90, **kwargs)
    epilogue_functor.trace(example_tensors)
    return epilogue_functor
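# The epilogue below is written in the EVT Python DSL: `accum` names the GEMM
# accumulator, and returning both D and E exposes the fused result and the raw
# accumulator as outputs. (Assumption: `tanh` is resolved by name by the EVT
# frontend, the same way `relu` is in the CUTLASS Python examples.)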
bias_code = """def fn(accum, bias):
    E = accum
    D = E + tanh(bias)
    return D, E
"""
bias = FakeTensor(shape=(m, 1), element=type_C, layout_tag=cutlass.LayoutType.RowMajor)
# Example tensors supply the tracer with a shape and dtype for every name used
# in fn; "accum" and "acc" are both listed, presumably to cover whichever
# spelling the frontend expects for the accumulator.
examples_tensors = {
    "accum": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    "acc": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    "bias": bias,
    "D": tensor_D,
    "E": tensor_D,
}
# Trace the epilogue, wrap it in a visitor, and emit the fused C++ callbacks.
epilogue_visitor = trace(bias_code, examples_tensors)
visitor = EpilogueFunctorVisitor(90, epilogue_visitor)
fusion_callbacks = FusionCallbacks(visitor.graph, 90, emit_CD=False)
print("".join(fusion_callbacks.emit()))
#breakpoint()
#plan.epilogue_visitor = epilogue_visitor
#print(plan.construct().rt_module.emit())
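# A sketch of how the traced epilogue might be attached to the plan and run,
# modeled on the CUTLASS Python EVT examples (the visitor_args keys and the
# plan.run signature are assumptions, not verified against this build):
#
#   plan.epilogue_visitor = epilogue_visitor
#   visitor_args = {"bias": bias, "D": tensor_D, "E": torch.zeros_like(tensor_D)}
#   plan.run(tensor_A, tensor_B, tensor_C, tensor_D,
#            visitor_args=visitor_args, print_module=print_module)
#
# Note: tensor_A and bias are FakeTensors here, so real CUDA tensors would be
# needed for an actual run.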