Skip to content

Instantly share code, notes, and snippets.

@mlazos
Created April 12, 2025 23:36
Show Gist options
  • Save mlazos/752185f2ed5eae17b02d5e1759e0ee2d to your computer and use it in GitHub Desktop.
Save mlazos/752185f2ed5eae17b02d5e1759e0ee2d to your computer and use it in GitHub Desktop.
# mypy: allow-untyped-defs
from numpy import dtype
from torch._inductor.ir import ComputedBuffer, InputBuffer
from typing import Union
from ..cutlass_utils import try_import_cutlass
if try_import_cutlass():
import ast
import ctypes
import textwrap
from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found]
EpilogueFunctorVisitor,
)
from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found]
FusionCallbacks,
)
from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found]
CollectiveEpilogue,
)
from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found]
PythonASTFrontend,
)
from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found]
Tensor as CutlassTensor,
)
from cutlass.backend.epilogue import dtype2ctype # type: ignore[import-untyped, import-not-found]
from cutlass.backend.c_types import EmptyByte
from cutlass_library import (
DataType,
EpilogueScheduleType,
LayoutType,
TileDescription,
)
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import IndentedBuffer
from ..cutlass_utils import torch_dtype_to_cutlass_type
# Set of the values of cutlass's `dtype2ctype` table. `_get_arg_from_node`
# treats an argument type found in this set as the node's element dtype.
_CUTLASS_C_DTYPES = set(dtype2ctype.values())
def _is_dense_row_major(shape, stride):
    """Return True if ``stride`` is the dense row-major stride of ``shape``.

    Dimensions are walked from innermost to outermost; each stride must equal
    the product of the sizes of all dimensions inside it, which also forces
    the innermost non-unit dimension to have stride 1. Size-1 dimensions are
    skipped because their index is always 0, so their stride never affects
    addressing.
    """
    expected = 1
    for size, step in zip(reversed(shape), reversed(stride)):
        if size == 1:
            continue
        if step != expected:
            return False
        expected *= size
    return True


def create_example_tensors(
    read_names: list[str],
    write_names: list[str],
    buffer_renames: dict[str, str],
    name_to_buffer: dict[str, Union[ComputedBuffer, InputBuffer]],
):
    """Build one example ``CutlassTensor`` per buffer read or written.

    Args:
        read_names: Names of the buffers that are read.
        write_names: Names of the buffers that are written.
        buffer_renames: Optional replacement key for a buffer name (some
            argument names are special, e.g. ``acc`` is a required name).
        name_to_buffer: Lookup from buffer name to the buffer object; every
            name in ``read_names`` and ``write_names`` must be present.

    Returns:
        A dict mapping the (possibly renamed) buffer name to its example
        tensor, tagged ``RowMajor`` or ``ColumnMajor``.

    Raises:
        RuntimeError: If a buffer's strides are neither dense row-major nor
            dense column-major for its shape.
    """

    def cutlass_tensor_from_buffer(buffer: Union[ComputedBuffer, InputBuffer]):
        layout = buffer.get_layout()
        shape = tuple(int(x) for x in layout.size)
        stride = tuple(int(x) for x in layout.stride)

        # Column-major is row-major with the dimension order reversed.
        is_row_major = _is_dense_row_major(shape, stride)
        is_column_major = _is_dense_row_major(shape[::-1], stride[::-1])

        if not is_row_major and not is_column_major:
            raise RuntimeError(
                f"Cannot create example tensor for {buffer.get_name()} "
                f"with non-contiguous layout, recieved stride: {stride} and shape: {shape}"
            )

        # When both hold (e.g. a single dimension) row-major is preferred.
        return CutlassTensor(
            shape=shape,
            layout_tag=LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor,
            element=torch_dtype_to_cutlass_type(layout.dtype),
        )

    example_tensors = {}
    for name in read_names + write_names:
        # Need to rewrite some special args (e.g. acc is a required arg name)
        key = buffer_renames.get(name, name)
        example_tensors[key] = cutlass_tensor_from_buffer(name_to_buffer[name])
    return example_tensors
def trace(
    fn_src: str,
    example_tensors: dict[str, CutlassTensor],
    accum_type: DataType,
    output_type: DataType,
    tile_description: TileDescription,
    epilogue_schedule: EpilogueScheduleType,
    name_to_buffer: dict[str, Union[ComputedBuffer, InputBuffer]],
    **kwargs,
):
    """Trace epilogue source code and emit the corresponding collective epilogue.

    Args:
        fn_src: Source text of the epilogue function; handed to ``_trace``.
        example_tensors: Example tensors keyed by argument name; handed to
            ``_trace``.
        accum_type: Accumulator data type passed to ``CollectiveEpilogue``.
        output_type: Output data type passed to ``CollectiveEpilogue``.
        tile_description: Tile description passed to ``CollectiveEpilogue``.
        epilogue_schedule: Epilogue schedule passed to ``CollectiveEpilogue``.
        name_to_buffer: Lookup from buffer name to buffer, used when rendering
            the epilogue arguments.
        **kwargs: Forwarded unchanged to ``_trace``.

    Returns:
        Whatever ``CollectiveEpilogue.emit()`` returns.

    Raises:
        AssertionError: If the CUDA arch reported by ``cuda_env`` is below 90.
    """
    cuda_arch = int(cuda_env.get_cuda_arch())  # type: ignore[arg-type]
    assert cuda_arch >= 90, "Only SM90+ is supported for EVT"

    epilogue_functor = _trace(fn_src, example_tensors, **kwargs)
    visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor)
    fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False)
    collective_epilogue = CollectiveEpilogue(
        tile_description,
        epilogue_schedule,
        accum_type,
        output_type,
        fusion_callbacks,
    )

    # The rendered argument initializer is not part of the return value yet.
    # The call is kept (result discarded) so that argument types the renderer
    # cannot handle still raise here, exactly as they did before.
    _render_argument_type(epilogue_functor, name_to_buffer)

    return collective_epilogue.emit()
# Based off of
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
# This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
def _trace(fn_src, example_tensors, **kwargs):
    """Trace an epilogue given as source text rather than as a function object.

    A ``PythonASTFrontend`` subclass is defined on the fly whose ``parse``
    works on the dedented ``fn_src`` text. The subclass is instantiated with
    ``kwargs``, ``trace`` is invoked on it with ``example_tensors``, and the
    traced instance is returned.
    """

    class EpilogueFunctor(PythonASTFrontend):
        def __init__(self, **frontend_kwargs):
            # Same ordering as before: the source is stored first, then the
            # base-class initializer runs.
            self.source = textwrap.dedent(fn_src)
            super().__init__(**frontend_kwargs)

        def parse(self, example_inputs):
            self.example_inputs = example_inputs
            tree = ast.parse(self.source)
            self.ast = tree
            self.visit(tree)

    functor = EpilogueFunctor(**kwargs)
    functor.trace(example_tensors)
    return functor
def _render_argument_type(epilogue_functor, name_to_buffer):
    """Render ``epilogue_functor.epilogue_thread_type`` as a braced initializer string.

    The thread type is walked recursively: nested visitor types open an
    indented ``{ ... },`` group, and every other type is written on one line
    whose field values come from ``_get_arg_from_node`` applied to
    ``name_to_buffer[<entry name>]``.

    Returns:
        The text accumulated in an ``IndentedBuffer``.
    """
    thread_type = epilogue_functor.epilogue_thread_type

    # Fragile, but VisitorType is a class local to cutlass's visitor_factory,
    # so its qualified name is the only way to recognise it.
    visitor_qualname = "cutlass.backend.c_types.visitor_factory.<locals>.VisitorType"

    def is_nested_visitor_type(t):
        return ".".join((t.__module__, t.__qualname__)) == visitor_qualname

    out = IndentedBuffer()

    def emit_leaf(name, t):
        # c_byte subclasses are rendered as an empty initializer.
        if issubclass(t, ctypes.c_byte):
            out.writeline(f"{{}}, /* {name} */")
            return
        pieces = []
        for fname, ty in t._fields_:
            value = _get_arg_from_node(ty, name_to_buffer[name])
            pieces.append(f"/* {fname} */ {value!s}")
        joined = ", ".join(pieces)
        out.writeline(f"{{{joined}}}, /* {name} */")

    def emit_thread(name, t):
        if not is_nested_visitor_type(t):
            emit_leaf(name, t)
            return
        out.writeline(f"{{ /* {name} */")
        with out.indent():
            for child_name, child_t in t._fields_:
                emit_thread(child_name, child_t)
        out.writeline("},")

    # NOTE(review): these two are plain (non-f) strings, so two literal braces
    # are written on each line -- confirm that is the intended output.
    out.writeline("{{")
    with out.indent():
        emit_thread("thread", thread_type)
    out.writeline("}};")
    return out.getvalue()
def _get_arg_from_node(arg_ty, node):
    """Return the source text for one argument of type ``arg_ty`` taken from ``node``.

    Today an argument is one of: a stride tuple, a pointer to the node's
    memory, the node's element datatype, or an empty placeholder.

    Raises:
        NotImplementedError: If ``arg_ty`` matches none of the handled cases.
    """
    from ..cuda_template import CUTLASSTemplate

    # Stride tuple -> "{s0,s1,...}"
    if issubclass(arg_ty, tuple):
        strides = ",".join(str(int(s)) for s in node.get_layout().stride)
        return "{" + strides + "}"

    # Pointer to the node's memory
    if issubclass(arg_ty, ctypes.c_void_p):
        return f"{node.get_name()}.get()"

    # Assumption: this is the element dtype, which holds for all cutlass ir
    # nodes currently.
    if arg_ty in _CUTLASS_C_DTYPES:
        return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]

    if issubclass(arg_ty, EmptyByte):
        return "{}"

    raise NotImplementedError(f"Unsupported arg type: {arg_ty}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment