Skip to content

Instantly share code, notes, and snippets.

@mlazos
Created May 12, 2025 21:53
Show Gist options
  • Save mlazos/5e69abbb3955e9c383f8339e13cc0640 to your computer and use it in GitHub Desktop.
from typing import overload
import torch
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass

try_import_cutlass()  # comment this out if you have cutlass installed
import cutlass
import types
import ast
import textwrap
import inspect
from cutlass.epilogue import relu
from cutlass import Tensor as FakeTensor
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass.backend.evt.frontend import PythonASTFrontend

# This controls whether the C++ GEMM declaration will be printed at each step.
# Set to `False` to omit this information.
print_module = False

# The Epilogue Visitor feature currently only works for SM80 and SM90.
from cutlass.backend.utils.device import device_cc

if device_cc() not in [80, 90]:
    # Nothing to do on unsupported architectures; exit quietly.
    import sys

    sys.exit()
# Problem size: a large GEMM with M = N = 16384, K = 512.
m = 16384
n = m
k = 512

# Element types for GEMM operands A, B, C and output D (all fp16 here).
type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16

# Deterministic random inputs in [scope_min, scope_max), rounded up to
# integers so fp16 accumulation error is easier to reason about.
torch.manual_seed(2023)
scope_min = -4
scope_max = 4

# A is a symbolic (shape/dtype/layout only) tensor used for tracing, while
# B/C/D are real CUDA tensors. Fix: A's element type is type_A — the
# original said type_C, which is the same value today (fp16) but wrong in
# intent and would break silently if the operand dtypes ever diverge.
tensor_A = FakeTensor(shape=(m, k), element=type_A, layout_tag=cutlass.LayoutType.RowMajor)
tensor_B = torch.ceil(torch.empty(size=(k, n), dtype=type_B, device="cuda").uniform_(scope_min, scope_max))
tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
tensor_D = torch.zeros_like(tensor_C)

# GEMM plan: fp16 elements, row-major layouts, fp32 accumulation.
plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)
def trace(fn_src, example_tensors, **kwargs):
    """Trace a Python epilogue definition into a CUTLASS EVT epilogue functor.

    Args:
        fn_src: Source text of the epilogue function (e.g. ``bias_code``).
        example_tensors: Mapping from the names used in ``fn_src`` to example
            tensors that supply shapes/dtypes/layouts for tracing.
        **kwargs: Forwarded to the ``PythonASTFrontend`` constructor.

    Returns:
        The traced ``EpilogueFunctor`` instance.
    """

    class EpilogueFunctor(PythonASTFrontend):
        def __init__(self, **kwargs):
            # Dedent so the source parses even when the snippet was defined
            # inside an indented block.
            self.source = textwrap.dedent(fn_src)
            super().__init__(**kwargs)

        def parse(self, example_inputs):
            self.example_inputs = example_inputs
            self.ast = ast.parse(self.source)
            self.visit(self.ast)

    # NOTE(review): cc=90 (SM90) is hard-coded even though the device gate
    # above also allows SM80 — confirm this is intended for SM80 devices.
    epilogue_functor = EpilogueFunctor(cc=90, **kwargs)
    epilogue_functor.trace(example_tensors)
    return epilogue_functor
# Epilogue definition. This source is traced (ast.parse'd), not executed:
#   E = accum; D = E + tanh(bias); both D and E are graph outputs.
# The body must be valid Python, hence the indentation inside the string.
bias_code = """def fn(accum, bias):
    E = accum
    D = E + tanh(bias)
    return D, E
"""

# Per-row bias, broadcast along the N dimension.
bias = FakeTensor(shape=(m, 1), element=type_C, layout_tag=cutlass.LayoutType.RowMajor)

examples_tensors = {
    "accum": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    # NOTE(review): "acc" does not appear in bias_code — presumably an
    # unused leftover entry; confirm whether it can be dropped.
    "acc": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    "bias": bias,
    "D": tensor_D,
    "E": tensor_D,
}
# Trace the epilogue, wrap it in a visitor, and print the emitted C++
# fusion-callback code for SM90.
epilogue_visitor = trace(bias_code, examples_tensors)
visitor = EpilogueFunctorVisitor(90, epilogue_visitor)
# NOTE(review): emit_CD=False appears to skip emitting the C/D tensor
# plumbing — confirm against the FusionCallbacks API.
fusion_callbacks = FusionCallbacks(visitor.graph, 90, emit_CD=False)
print("".join(fusion_callbacks.emit()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment