@mlazos · Last active May 29, 2025
topological visitor bug
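Two TorchInductor output-code dumps from test_evt_multi_output on the CUTLASS backend follow. In the first, the test computes y = z0 + z0 (both operands of the final add are the same node) and the generated epilogue appears malformed; in the second, y = z0 + z produces a correct Sm90TopologicalVisitor DAG and the test passes.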
cutlass_fused_add_mm_relu_1dd8740c = async_compile.cuda(r'''
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using namespace cute;
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
throw std::runtime_error(msg); \
} \
}
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
CUTLASS_HOST_DEVICE
T operator()(T val) const { return val; }
};
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
cute::Shape<_128, _128, _64>, cutlass::epilogue::collective::EpilogueTileAuto,
float, cutlass::half_t,
cutlass::epilogue::TmaWarpSpecializedCooperative
>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::ReLu, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
Compute0,
Accum>;
using Tmp0Descriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, cutlass::half_t
>;
using Tmp0 = cutlass::epilogue::fusion::Sm90AuxStore<
Tmp0Descriptor::Stages, typename Tmp0Descriptor::EpilogueTile, cutlass::half_t,
cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename Tmp0Descriptor::SmemLayoutAtom,
typename Tmp0Descriptor::CopyOpR2S
>;
using EVTTmp0 = cutlass::epilogue::fusion::Sm90EVT<
Tmp0,
EVTCompute0>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, cutlass::half_t, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<
Compute1,
EVTTmp0>;
using ElementD = cutlass::half_t;
using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
EVTCompute1
>::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
cutlass::gemm::StreamKScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
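// A minimal two-phase usage sketch (assumed caller, not part of the generated code):
//   size_t ws_bytes = 0;
//   cutlass_fused_add_mm_relu_1dd8740c(X, W, Y, buf1, buf2, M, N, K, B,
//       lda, ldb, ldc, ldd, swizzle, &ws_bytes, nullptr, stream);   // query only
//   // ...allocate ws_bytes of device memory as `workspace`...
//   cutlass_fused_add_mm_relu_1dd8740c(X, W, Y, buf1, buf2, M, N, K, B,
//       lda, ldb, ldc, ldd, swizzle, nullptr, workspace, stream);   // run GEMM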
extern "C" {
PT_EXPORT int cutlass_fused_add_mm_relu_1dd8740c(const __half* X, const __half* W, const __half* Y, __half* buf1, __half* buf2, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
try {
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
if (hw_info.sm_count == 0) {
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments;
// Initialize GemmUniversal3xInstance arguments.
arguments = {
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
{
static_cast<coord_t>(M),
static_cast<coord_t>(N),
static_cast<coord_t>(K),
static_cast<coord_t>(B)
}, // ProblemShape problem_shape
{
(cutlass::half_t*)(X), // ElementA const* ptr_A
{
lda /* stride_x0 */,
cute::Int<1>{} /* stride_x1 */,
0 /* batch_stride_x */
}, // StrideA dA
(cutlass::half_t*)(W), // ElementB const* ptr_B
{
cute::Int<1>{} /* stride_w1 */,
ldb /* stride_w0 */,
0 /* batch_stride_w */
}, // StrideB dB
}, // MainloopArguments mainloop
// see https://tinyurl.com/4rk89z48
{
{ /* thread */
{ /* tmp_0 */
{ /* compute_0 */
{}, /* accum */
{}, /* compute_0 */
},
{/* ptr_aux */ (cutlass::half_t*) buf1, /* dAux */ {512, _1{}, _0{}}}, /* tmp_0 */
},
{}, /* compute_1 */
}
, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
nullptr, // ElementC const* ptr_C
{
cute::Int<1>{} /* stride_bias0 */,
cute::Int<1>{} /* stride_bias1 */,
0 /* batch_stride_bias */
}, // StrideC dC
(cutlass::half_t*)(buf2), // ElementD const* ptr_D
{
ldd /* stride_y0 */,
cute::Int<1>{} /* stride_y1 */,
0 /* batch_stride_y */
}, // StrideD dD
}, // EpilogueArguments epilogue,
hw_info
};
arguments.scheduler.max_swizzle_size = swizzle;
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type gemm_op;
if (workspace_size) {
*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
{
auto status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
}
#endif
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
{
// Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1
// we don't need a print statement, it's happening inside the function.
gemm_op.maximum_active_blocks();
}
#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace, stream);
CUTLASS_CHECK(status);
}
{
auto status = gemm_op(stream);
CUTLASS_CHECK(status);
}
}
catch (std::exception& e) {
std::cerr << "Runtime error: " << e.what() << std::endl;
return -1;
}
catch (...) {
return -1;
}
return 0;
}
}
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
''', 'so', aot_compile=False)
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1024, 512), (512, 1))
    assert_size_stride(arg1_1, (512, 512), (512, 1))
    assert_size_stride(arg2_1, (1024, 512), (512, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf1 = empty_strided_cuda((1024, 512), (512, 1), torch.float16)
        buf2 = empty_strided_cuda((1024, 512), (512, 1), torch.float16)
        stream0 = get_raw_stream(0)
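        # Argument mapping, per the extern "C" signature above: X=arg0_1,
        # W=arg1_1, Y=buf2 (ptr_C is nullptr in the kernel, so Y appears
        # unused), buf1 receives the relu aux store, buf2 is GEMM output D;
        # M=1024, N=512, K=512, B=1; workspace query is skipped (both
        # workspace_size and workspace are None).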
        cutlass_fused_add_mm_relu_1dd8740c.cutlass_fused_add_mm_relu_1dd8740c(c_void_p(arg0_1.data_ptr()), c_void_p(arg1_1.data_ptr()), c_void_p(buf2.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()), 1024, 512, 512, 1, 512, 512, 0, 512, 1, None, None, c_void_p(stream0))
    del arg0_1
    del arg1_1
    return (buf2, buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16)
    arg1_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float16)
    arg2_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
[DEBUG]:Output code written to: /tmp/torchinductor_mlazos/ms/cmsxzzeygcm7dipqepz5y2lhgy6ccolpfz2x7mfemsksgtiy6mpq.py
E
frames [('total', 1)]
stats [('calls_captured', 4)]
inductor [('cuda_epilogue_fusion_counter', 2), ('fxgraph_cache_miss', 1)]
aot_autograd [('total', 1), ('autograd_cache_miss', 1), ('not_ok', 1)]
graph_break []
aten_mm_info [('aten.mm_1024_512_512', 1)]
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
def test_evt_multi_output(self, op):
    class TestModel(torch.nn.Module):
        def forward(self, a, b, extra_args):
            acc = a @ b
            z0 = acc.relu()
            z = z0 + extra_args[0]
            y = z0 + z0
            return y, z0

    M = 1024
    N = 512
    a = torch.ones(M, N).cuda().half()
    b = torch.ones(N, N).cuda().half()
    extra_args = gen_args(op, (M, N))
    model = TestModel().cuda()
    result = torch.compile(model)(a, b, extra_args)
    ref_result = model(a, b, extra_args)
    self.assertEqual(
        torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2
    )
    torch.testing.assert_close(result, ref_result)
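This is the failing variant: y = z0 + z0 feeds the same node into both inputs of the final add, which the EVT codegen above collapsed into a binary plus with a single child.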
$ TORCH_LOGS="output_code" python test/inductor/test_cutlass_backend.py -k test_evt_multi_output_add
[DEBUG]:Output code:
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor_mlazos/hd/chd2bgm73l6drb74vbk2utbdyi5me675nixcvnpg5lgqqhsyi7cv.py
# Topologically Sorted Source Nodes: [acc, z0, z, y], Original ATen: [aten.mm, aten.relu, aten.add]
# Source node to ATen node mapping:
# acc => mm
# y => add_1
# z => add
# z0 => relu
# Graph fragment:
# %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg0_1, %arg1_1), kwargs = {})
# %relu : [num_users=3] = call_function[target=torch.ops.aten.relu.default](args = (%mm,), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %arg2_1), kwargs = {})
# %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %add), kwargs = {})
cutlass_fused_add_mm_relu_b08c11bd = async_compile.cuda(r'''
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using namespace cute;
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
throw std::runtime_error(msg); \
} \
}
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
CUTLASS_HOST_DEVICE
T operator()(T val) const { return val; }
};
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
cute::Shape<_128, _128, _64>, cutlass::epilogue::collective::EpilogueTileAuto,
float, cutlass::half_t,
cutlass::epilogue::TmaWarpSpecializedCooperative
>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Arg21Descriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, cutlass::half_t>;
using Arg21 = cutlass::epilogue::fusion::Sm90AuxLoad<
Arg21Descriptor::Stages, typename Arg21Descriptor::EpilogueTile, cutlass::half_t,
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename Arg21Descriptor::SmemLayoutAtom, typename Arg21Descriptor::CopyOpS2R
>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::ReLu, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
Compute0,
Accum>;
using Tmp0Descriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, cutlass::half_t
>;
using Tmp0 = cutlass::epilogue::fusion::Sm90AuxStore<
Tmp0Descriptor::Stages, typename Tmp0Descriptor::EpilogueTile, cutlass::half_t,
cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename Tmp0Descriptor::SmemLayoutAtom,
typename Tmp0Descriptor::CopyOpR2S
>;
using EVTTmp0 = cutlass::epilogue::fusion::Sm90EVT<
Tmp0,
EVTCompute0>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, cutlass::half_t, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using DagCompute2 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
float,
cute::tuple<
cute::seq<>,
cute::seq<>,
cute::seq<1, 0>,
cute::seq<1, 2>
>,
Arg21,
EVTTmp0,
Compute1,
Compute2
>;
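// Sm90TopologicalVisitor wiring (one cute::seq of input-node indices per op,
// in topological order): node 0 = Arg21 (aux load of arg2_1), node 1 = EVTTmp0
// (relu(accum), stored to buf1), node 2 = Compute1 = plus(node 1, node 0),
// i.e. z = z0 + arg2_1, and node 3 = Compute2 = plus(node 1, node 2),
// i.e. y = z0 + z, matching the graph fragment above.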
using ElementD = cutlass::half_t;
using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
DagCompute2
>::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
cutlass::gemm::StreamKScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int cutlass_fused_add_mm_relu_b08c11bd(const __half* X, const __half* W, const __half* arg2_1, const __half* Y, __half* buf1, __half* buf2, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
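// Unlike the failing kernel above, this signature also takes arg2_1: the
// extra tensor consumed by the Arg21 aux-load node (z = z0 + arg2_1).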
try {
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
if (hw_info.sm_count == 0) {
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments;
// Initialize GemmUniversal3xInstance arguments.
arguments = {
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
{
static_cast<coord_t>(M),
static_cast<coord_t>(N),
static_cast<coord_t>(K),
static_cast<coord_t>(B)
}, // ProblemShape problem_shape
{
(cutlass::half_t*)(X), // ElementA const* ptr_A
{
lda /* stride_x0 */,
cute::Int<1>{} /* stride_x1 */,
0 /* batch_stride_x */
}, // StrideA dA
(cutlass::half_t*)(W), // ElementB const* ptr_B
{
cute::Int<1>{} /* stride_w1 */,
ldb /* stride_w0 */,
0 /* batch_stride_w */
}, // StrideB dB
}, // MainloopArguments mainloop
// see https://tinyurl.com/4rk89z48
{
{ /* thread */
{/* ptr_aux */ (cutlass::half_t*) arg2_1, /* null_default */ cutlass::half_t(0), /* dAux */ {512, _1{}, _0{}}}, /* arg2_1 */
{ /* tmp_0 */
{ /* compute_0 */
{}, /* accum */
{}, /* compute_0 */
},
{/* ptr_aux */ (cutlass::half_t*) buf1, /* dAux */ {512, _1{}, _0{}}}, /* tmp_0 */
},
{}, /* compute_1 */
{}, /* compute_2 */
}
, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
nullptr, // ElementC const* ptr_C
{
cute::Int<1>{} /* stride_bias0 */,
cute::Int<1>{} /* stride_bias1 */,
0 /* batch_stride_bias */
}, // StrideC dC
(cutlass::half_t*)(buf2), // ElementD const* ptr_D
{
ldd /* stride_y0 */,
cute::Int<1>{} /* stride_y1 */,
0 /* batch_stride_y */
}, // StrideD dD
}, // EpilogueArguments epilogue,
hw_info
};
arguments.scheduler.max_swizzle_size = swizzle;
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type gemm_op;
if (workspace_size) {
*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
{
auto status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
}
#endif
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
{
// Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1
// we don't need a print statement, it's happening inside the function.
gemm_op.maximum_active_blocks();
}
#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace, stream);
CUTLASS_CHECK(status);
}
{
auto status = gemm_op(stream);
CUTLASS_CHECK(status);
}
}
catch (std::exception& e) {
std::cerr << "Runtime error: " << e.what() << std::endl;
return -1;
}
catch (...) {
return -1;
}
return 0;
}
}
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
''', 'so', aot_compile=False)
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1024, 512), (512, 1))
    assert_size_stride(arg1_1, (512, 512), (512, 1))
    assert_size_stride(arg2_1, (1024, 512), (512, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf1 = empty_strided_cuda((1024, 512), (512, 1), torch.float16)
        buf2 = empty_strided_cuda((1024, 512), (512, 1), torch.float16)
        stream0 = get_raw_stream(0)
        cutlass_fused_add_mm_relu_b08c11bd.cutlass_fused_add_mm_relu_b08c11bd(c_void_p(arg0_1.data_ptr()), c_void_p(arg1_1.data_ptr()), c_void_p(arg2_1.data_ptr()), c_void_p(buf2.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()), 1024, 512, 512, 1, 512, 512, 0, 512, 1, None, None, c_void_p(stream0))
    del arg0_1
    del arg1_1
    del arg2_1
    return (buf2, buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16)
    arg1_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float16)
    arg2_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
[DEBUG]:Output code written to: /tmp/torchinductor_mlazos/ak/cakiohsxeutme3m3pv5zhmwelrwwdj7onpshavvnsedj4mjuqb4b.py
[INFO]:Output code written to: /tmp/torchinductor_mlazos/ak/cakiohsxeutme3m3pv5zhmwelrwwdj7onpshavvnsedj4mjuqb4b.py
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 4), ('unique_graphs', 1)]
inductor [('cuda_epilogue_fusion_counter', 2), ('fxgraph_cache_miss', 1)]
aot_autograd [('total', 1), ('autograd_cache_miss', 1), ('autograd_cache_saved', 1), ('ok', 1)]
graph_break []
aten_mm_info [('aten.mm_1024_512_512', 1)]
.
----------------------------------------------------------------------
Ran 1 test in 55.937s
OK
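The passing run corresponds to this variant, which differs only in computing y = z0 + z: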
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
def test_evt_multi_output(self, op):
    class TestModel(torch.nn.Module):
        def forward(self, a, b, extra_args):
            acc = a @ b
            z0 = acc.relu()
            z = z0 + extra_args[0]
            y = z0 + z
            return y, z0

    M = 1024
    N = 512
    a = torch.ones(M, N).cuda().half()
    b = torch.ones(N, N).cuda().half()
    extra_args = gen_args(op, (M, N))
    model = TestModel().cuda()
    result = torch.compile(model)(a, b, extra_args)
    ref_result = model(a, b, extra_args)
    self.assertEqual(
        torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2
    )
    torch.testing.assert_close(result, ref_result)