Last active
May 29, 2025 19:57
-
-
Save mlazos/f8baaeabb0defd2e13715a7cfeef30c6 to your computer and use it in GitHub Desktop.
topological visitor bug
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cutlass_fused_add_mm_relu_1dd8740c = async_compile.cuda(r''' | |
#include <exception> | |
#include <iostream> | |
#include <memory> | |
#include <random> | |
#include <vector> | |
#include "cute/tensor.hpp" | |
#include "cutlass/cutlass.h" | |
#include "cutlass/numeric_types.h" | |
#include "cutlass/tensor_ref.h" | |
#include "cutlass/util/host_tensor.h" | |
#include "cutlass/util/reference/host/tensor_fill.h" | |
#include "cutlass/util/reference/device/tensor_fill.h" | |
#include "cutlass/util/device_memory.h" | |
#include "cutlass/gemm/gemm.h" | |
#include "cutlass/gemm/device/gemm_universal.h" | |
#include "cutlass/gemm/device/gemm_universal_adapter.h" | |
#include "cutlass/gemm/kernel/gemm_universal.hpp" | |
#include "cutlass/gemm/device/gemm_sparse.h" | |
#include "cutlass/gemm/collective/collective_builder.hpp" | |
#include "cutlass/epilogue/collective/collective_builder.hpp" | |
#include "cutlass/epilogue/collective/default_epilogue.hpp" | |
#include "cutlass/epilogue/thread/linear_combination.h" | |
#include "cutlass/epilogue/thread/activation.h" | |
#include "cutlass/gemm/dispatch_policy.hpp" | |
#include "cutlass/gemm/kernel/tile_scheduler.hpp" | |
#include "cutlass/tensor_ref.h" | |
#include "cutlass/util/distribution.h" | |
#include "cutlass/util/packed_stride.hpp" | |
#include "cutlass/util/tensor_view_io.h" | |
// We compile all models with -fvisibility=hidden. Any symbols that need to be | |
// exposed in the final shared library must be declared with PT_EXPORT to make | |
// them visible. | |
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) | |
#define PT_EXPORT __attribute__((__visibility__("default"))) | |
#else | |
#ifdef _WIN32 | |
#define PT_EXPORT __declspec(dllexport) | |
#else | |
#define PT_EXPORT | |
#endif | |
#endif | |
using namespace cute; | |
#define CUTLASS_CHECK(status) \ | |
{ \ | |
cutlass::Status error = status; \ | |
if (error != cutlass::Status::kSuccess) { \ | |
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \ | |
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \ | |
throw std::runtime_error(msg); \ | |
} \ | |
} | |
// Used as pass-through functor in EVT just for type casting / rounding | |
template <typename T> | |
struct identity_op { | |
CUTLASS_HOST_DEVICE | |
T operator()(T val) const { return val; } | |
}; | |
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< | |
cute::Shape<_128, _128, _64>, cutlass::epilogue::collective::EpilogueTileAuto, | |
float, cutlass::half_t, | |
cutlass::epilogue::TmaWarpSpecializedCooperative | |
>; | |
using Accum = cutlass::epilogue::fusion::Sm90AccFetch; | |
using Compute0 = cutlass::epilogue::fusion::Sm90Compute< | |
cutlass::epilogue::thread::ReLu, float, float, | |
cutlass::FloatRoundStyle::round_to_nearest | |
>; | |
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT< | |
Compute0, | |
Accum>; | |
using Tmp0Descriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< | |
EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, cutlass::half_t | |
>; | |
using Tmp0 = cutlass::epilogue::fusion::Sm90AuxStore< | |
Tmp0Descriptor::Stages, typename Tmp0Descriptor::EpilogueTile, cutlass::half_t, | |
cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename Tmp0Descriptor::SmemLayoutAtom, | |
typename Tmp0Descriptor::CopyOpR2S | |
>; | |
using EVTTmp0 = cutlass::epilogue::fusion::Sm90EVT< | |
Tmp0, | |
EVTCompute0>; | |
using Compute1 = cutlass::epilogue::fusion::Sm90Compute< | |
cutlass::plus, cutlass::half_t, float, | |
cutlass::FloatRoundStyle::round_to_nearest | |
>; | |
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT< | |
Compute1, | |
EVTTmp0>; | |
using ElementD = cutlass::half_t; | |
using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>; | |
using ElementC = void; | |
using StrideC = StrideD; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue = | |
typename cutlass::epilogue::collective::CollectiveBuilder< | |
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, | |
cute::Shape<cute::_128, cute::_128, cute::_64>, | |
cute::Shape<cute::_1, cute::_2, cute::_1>, | |
cutlass::epilogue::collective::EpilogueTileAuto, | |
float, float, | |
void, cutlass::layout::ColumnMajor, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
cutlass::epilogue::TmaWarpSpecializedCooperative, | |
EVTCompute1 | |
>::CollectiveOp; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop = | |
typename cutlass::gemm::collective::CollectiveBuilder< | |
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
float, | |
cute::Shape<cute::_128, cute::_128, cute::_64>, | |
cute::Shape<cute::_1, cute::_2, cute::_1>, | |
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>, | |
cutlass::gemm::KernelTmaWarpSpecializedCooperative | |
>::CollectiveOp; | |
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal< | |
cute::Shape<int,int,int,int>, | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop, | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue, | |
cutlass::gemm::StreamKScheduler>; | |
// Define named type | |
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma : | |
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { }; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>; | |
// When workspace_size is not a nullptr, populates requested workspace_size and returns. | |
// Otherwise, computes the Gemm kernel using the given workspace ptr. | |
extern "C" { | |
PT_EXPORT int cutlass_fused_add_mm_relu_1dd8740c(const __half* X, const __half* W, const __half* Y, __half* buf1, __half* buf2, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) { | |
try { | |
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator; | |
using coord_t = cutlass::gemm::GemmCoord::Index; | |
static cutlass::KernelHardwareInfo hw_info; | |
if (hw_info.sm_count == 0) { | |
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); | |
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); | |
} | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments; | |
// Initialize GemmUniversal3xInstance arguments. | |
arguments = { | |
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode | |
{ | |
static_cast<coord_t>(M), | |
static_cast<coord_t>(N), | |
static_cast<coord_t>(K), | |
static_cast<coord_t>(B) | |
}, // ProblemShape problem_shape | |
{ | |
(cutlass::half_t*)(X), // ElementA const* ptr_A | |
{ | |
lda /* stride_x0 */, | |
cute::Int<1>{} /* stride_x1 */, | |
0 /* batch_stride_x */ | |
}, // StrideA dA | |
(cutlass::half_t*)(W), // ElementB const* ptr_B | |
{ | |
cute::Int<1>{} /* stride_w1 */, | |
ldb /* stride_w0 */, | |
0 /* batch_stride_w */ | |
}, // StrideB dB | |
}, // MainloopArguments mainloop | |
// see https://tinyurl.com/4rk89z48 | |
{ | |
{ /* thread */ | |
{ /* tmp_0 */ | |
{ /* compute_0 */ | |
{}, /* accum */ | |
{}, /* compute_0 */ | |
}, | |
{/* ptr_aux */ (cutlass::half_t*) buf1, /* dAux */ {512, _1{}, _0{}}}, /* tmp_0 */ | |
}, | |
{}, /* compute_1 */ | |
} | |
, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) | |
nullptr, // ElementC const* ptr_C | |
{ | |
cute::Int<1>{} /* stride_bias0 */, | |
cute::Int<1>{} /* stride_bias1 */, | |
0 /* batch_stride_bias */ | |
}, // StrideC dC | |
(cutlass::half_t*)(buf2), // ElementD const* ptr_D | |
{ | |
ldd /* stride_y0 */, | |
cute::Int<1>{} /* stride_y1 */, | |
0 /* batch_stride_y */ | |
}, // StrideD dD | |
}, // EpilogueArguments epilogue, | |
hw_info | |
}; | |
arguments.scheduler.max_swizzle_size = swizzle; | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type gemm_op; | |
if (workspace_size) { | |
*workspace_size = gemm_op.get_workspace_size(arguments); | |
return 0; | |
} | |
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers | |
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS | |
{ | |
auto status = gemm_op.can_implement(arguments); | |
CUTLASS_CHECK(status); | |
} | |
#endif | |
#ifdef CUTLASS_DEBUG_TRACE_LEVEL | |
#if CUTLASS_DEBUG_TRACE_LEVEL == 1 | |
{ | |
// Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 | |
// we don't need a print statement, it's happening inside the function. | |
gemm_op.maximum_active_blocks(); | |
} | |
#endif | |
#endif | |
{ | |
auto status = gemm_op.initialize(arguments, workspace, stream); | |
CUTLASS_CHECK(status); | |
} | |
{ | |
auto status = gemm_op(stream); | |
CUTLASS_CHECK(status); | |
} | |
} | |
catch (std::exception& e) { | |
std::cerr << "Runtime error: " << e.what() << std::endl; | |
return -1; | |
} | |
catch (...) { | |
return -1; | |
} | |
return 0; | |
} | |
} | |
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma | |
''', 'so', aot_compile=False) | |
async_compile.wait(globals()) | |
del async_compile | |
def call(args): | |
arg0_1, arg1_1, arg2_1 = args | |
args.clear() | |
assert_size_stride(arg0_1, (1024, 512), (512, 1)) | |
assert_size_stride(arg1_1, (512, 512), (512, 1)) | |
assert_size_stride(arg2_1, (1024, 512), (512, 1)) | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf1 = empty_strided_cuda((1024, 512), (512, 1), torch.float16) | |
buf2 = empty_strided_cuda((1024, 512), (512, 1), torch.float16) | |
stream0 = get_raw_stream(0) | |
cutlass_fused_add_mm_relu_1dd8740c.cutlass_fused_add_mm_relu_1dd8740c(c_void_p(arg0_1.data_ptr()), c_void_p(arg1_1.data_ptr()), c_void_p(buf2.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()), 1024, 512, 512, 1, 512, 512, 0, 512, 1, None, None, c_void_p(stream0)) | |
del arg0_1 | |
del arg1_1 | |
return (buf2, buf1, ) | |
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
arg0_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16) | |
arg1_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float16) | |
arg2_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16) | |
fn = lambda: call([arg0_1, arg1_1, arg2_1]) | |
return print_performance(fn, times=times, repeat=repeat) | |
if __name__ == "__main__": | |
from torch._inductor.wrapper_benchmark import compiled_module_main | |
compiled_module_main('None', benchmark_compiled_module) | |
[DEBUG]:Output code written to: /tmp/torchinductor_mlazos/ms/cmsxzzeygcm7dipqepz5y2lhgy6ccolpfz2x7mfemsksgtiy6mpq.py | |
Eframes [('total', 1)] | |
stats [('calls_captured', 4)] | |
inductor [('cuda_epilogue_fusion_counter', 2), ('fxgraph_cache_miss', 1)] | |
aot_autograd [('total', 1), ('autograd_cache_miss', 1), ('not_ok', 1)] | |
graph_break [] | |
aten_mm_info [('aten.mm_1024_512_512', 1)] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@unittest.skipIf(not SM90OrLater, "need sm_90") | |
@use_evt_config | |
@evt_all_ops | |
def test_evt_multi_output(self, op): | |
class TestModel(torch.nn.Module): | |
def forward(self, a, b, extra_args): | |
acc = a @ b | |
z0 = acc.relu() | |
z = z0 + extra_args[0] | |
y = z0 + z0 | |
return y, z0 | |
M = 1024 | |
N = 512 | |
a = torch.ones(M, N).cuda().half() | |
b = torch.ones(N, N).cuda().half() | |
extra_args = gen_args(op, (M, N)) | |
model = TestModel().cuda() | |
result = torch.compile(model)(a, b, extra_args) | |
ref_result = model(a, b, extra_args) | |
self.assertEqual( | |
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2 | |
) | |
torch.testing.assert_close(result, ref_result) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
`--> TORCH_LOGS="output_code" python test/inductor/test_cutlass_backend.py -k test_evt_multi_output_add | |
[DEBUG]:Output code: | |
# AOT ID: ['0_inference'] | |
from ctypes import c_void_p, c_long, c_int | |
import torch | |
import math | |
import random | |
import os | |
import tempfile | |
from math import inf, nan | |
from cmath import nanj | |
from torch._inductor.hooks import run_intermediate_hooks | |
from torch._inductor.utils import maybe_profile | |
from torch._inductor.codegen.memory_planning import _align as align | |
from torch import device, empty_strided | |
from torch._inductor.async_compile import AsyncCompile | |
from torch._inductor.select_algorithm import extern_kernels | |
from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
aten = torch.ops.aten | |
inductor_ops = torch.ops.inductor | |
_quantized = torch.ops._quantized | |
assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
assert_alignment = torch._C._dynamo.guards.assert_alignment | |
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
async_compile = AsyncCompile() | |
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
# kernel path: /tmp/torchinductor_mlazos/hd/chd2bgm73l6drb74vbk2utbdyi5me675nixcvnpg5lgqqhsyi7cv.py | |
# Topologically Sorted Source Nodes: [acc, z0, z, y], Original ATen: [aten.mm, aten.relu, aten.add] | |
# Source node to ATen node mapping: | |
# acc => mm | |
# y => add_1 | |
# z => add | |
# z0 => relu | |
# Graph fragment: | |
# %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg0_1, %arg1_1), kwargs = {}) | |
# %relu : [num_users=3] = call_function[target=torch.ops.aten.relu.default](args = (%mm,), kwargs = {}) | |
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %arg2_1), kwargs = {}) | |
# %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %add), kwargs = {}) | |
cutlass_fused_add_mm_relu_b08c11bd = async_compile.cuda(r''' | |
#include <exception> | |
#include <iostream> | |
#include <memory> | |
#include <random> | |
#include <vector> | |
#include "cute/tensor.hpp" | |
#include "cutlass/cutlass.h" | |
#include "cutlass/numeric_types.h" | |
#include "cutlass/tensor_ref.h" | |
#include "cutlass/util/host_tensor.h" | |
#include "cutlass/util/reference/host/tensor_fill.h" | |
#include "cutlass/util/reference/device/tensor_fill.h" | |
#include "cutlass/util/device_memory.h" | |
#include "cutlass/gemm/gemm.h" | |
#include "cutlass/gemm/device/gemm_universal.h" | |
#include "cutlass/gemm/device/gemm_universal_adapter.h" | |
#include "cutlass/gemm/kernel/gemm_universal.hpp" | |
#include "cutlass/gemm/device/gemm_sparse.h" | |
#include "cutlass/gemm/collective/collective_builder.hpp" | |
#include "cutlass/epilogue/collective/collective_builder.hpp" | |
#include "cutlass/epilogue/collective/default_epilogue.hpp" | |
#include "cutlass/epilogue/thread/linear_combination.h" | |
#include "cutlass/epilogue/thread/activation.h" | |
#include "cutlass/gemm/dispatch_policy.hpp" | |
#include "cutlass/gemm/kernel/tile_scheduler.hpp" | |
#include "cutlass/tensor_ref.h" | |
#include "cutlass/util/distribution.h" | |
#include "cutlass/util/packed_stride.hpp" | |
#include "cutlass/util/tensor_view_io.h" | |
// We compile all models with -fvisibility=hidden. Any symbols that need to be | |
// exposed in the final shared library must be declared with PT_EXPORT to make | |
// them visible. | |
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) | |
#define PT_EXPORT __attribute__((__visibility__("default"))) | |
#else | |
#ifdef _WIN32 | |
#define PT_EXPORT __declspec(dllexport) | |
#else | |
#define PT_EXPORT | |
#endif | |
#endif | |
using namespace cute; | |
#define CUTLASS_CHECK(status) \ | |
{ \ | |
cutlass::Status error = status; \ | |
if (error != cutlass::Status::kSuccess) { \ | |
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \ | |
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \ | |
throw std::runtime_error(msg); \ | |
} \ | |
} | |
// Used as pass-through functor in EVT just for type casting / rounding | |
template <typename T> | |
struct identity_op { | |
CUTLASS_HOST_DEVICE | |
T operator()(T val) const { return val; } | |
}; | |
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< | |
cute::Shape<_128, _128, _64>, cutlass::epilogue::collective::EpilogueTileAuto, | |
float, cutlass::half_t, | |
cutlass::epilogue::TmaWarpSpecializedCooperative | |
>; | |
using Accum = cutlass::epilogue::fusion::Sm90AccFetch; | |
using Arg21Descriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, cutlass::half_t>; | |
using Arg21 = cutlass::epilogue::fusion::Sm90AuxLoad< | |
Arg21Descriptor::Stages, typename Arg21Descriptor::EpilogueTile, cutlass::half_t, | |
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename Arg21Descriptor::SmemLayoutAtom, typename Arg21Descriptor::CopyOpS2R | |
>; | |
using Compute0 = cutlass::epilogue::fusion::Sm90Compute< | |
cutlass::epilogue::thread::ReLu, float, float, | |
cutlass::FloatRoundStyle::round_to_nearest | |
>; | |
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT< | |
Compute0, | |
Accum>; | |
using Tmp0Descriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< | |
EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, cutlass::half_t | |
>; | |
using Tmp0 = cutlass::epilogue::fusion::Sm90AuxStore< | |
Tmp0Descriptor::Stages, typename Tmp0Descriptor::EpilogueTile, cutlass::half_t, | |
cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename Tmp0Descriptor::SmemLayoutAtom, | |
typename Tmp0Descriptor::CopyOpR2S | |
>; | |
using EVTTmp0 = cutlass::epilogue::fusion::Sm90EVT< | |
Tmp0, | |
EVTCompute0>; | |
using Compute1 = cutlass::epilogue::fusion::Sm90Compute< | |
cutlass::plus, float, float, | |
cutlass::FloatRoundStyle::round_to_nearest | |
>; | |
using Compute2 = cutlass::epilogue::fusion::Sm90Compute< | |
cutlass::plus, cutlass::half_t, float, | |
cutlass::FloatRoundStyle::round_to_nearest | |
>; | |
using DagCompute2 = cutlass::epilogue::fusion::Sm90TopologicalVisitor< | |
float, | |
cute::tuple< | |
cute::seq<>, | |
cute::seq<>, | |
cute::seq<1, 0>, | |
cute::seq<1, 2>, | |
>, | |
Arg21, | |
EVTTmp0, | |
Compute1, | |
Compute2 | |
>; | |
using ElementD = cutlass::half_t; | |
using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>; | |
using ElementC = void; | |
using StrideC = StrideD; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue = | |
typename cutlass::epilogue::collective::CollectiveBuilder< | |
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, | |
cute::Shape<cute::_128, cute::_128, cute::_64>, | |
cute::Shape<cute::_1, cute::_2, cute::_1>, | |
cutlass::epilogue::collective::EpilogueTileAuto, | |
float, float, | |
void, cutlass::layout::ColumnMajor, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
cutlass::epilogue::TmaWarpSpecializedCooperative, | |
DagCompute2 | |
>::CollectiveOp; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop = | |
typename cutlass::gemm::collective::CollectiveBuilder< | |
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
float, | |
cute::Shape<cute::_128, cute::_128, cute::_64>, | |
cute::Shape<cute::_1, cute::_2, cute::_1>, | |
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>, | |
cutlass::gemm::KernelTmaWarpSpecializedCooperative | |
>::CollectiveOp; | |
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal< | |
cute::Shape<int,int,int,int>, | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop, | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue, | |
cutlass::gemm::StreamKScheduler>; | |
// Define named type | |
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma : | |
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { }; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>; | |
// When workspace_size is not a nullptr, populates requested workspace_size and returns. | |
// Otherwise, computes the Gemm kernel using the given workspace ptr. | |
extern "C" { | |
PT_EXPORT int cutlass_fused_add_mm_relu_b08c11bd(const __half* X, const __half* W, const __half* arg2_1, const __half* Y, __half* buf1, __half* buf2, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) { | |
try { | |
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator; | |
using coord_t = cutlass::gemm::GemmCoord::Index; | |
static cutlass::KernelHardwareInfo hw_info; | |
if (hw_info.sm_count == 0) { | |
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); | |
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); | |
} | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments; | |
// Initialize GemmUniversal3xInstance arguments. | |
arguments = { | |
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode | |
{ | |
static_cast<coord_t>(M), | |
static_cast<coord_t>(N), | |
static_cast<coord_t>(K), | |
static_cast<coord_t>(B) | |
}, // ProblemShape problem_shape | |
{ | |
(cutlass::half_t*)(X), // ElementA const* ptr_A | |
{ | |
lda /* stride_x0 */, | |
cute::Int<1>{} /* stride_x1 */, | |
0 /* batch_stride_x */ | |
}, // StrideA dA | |
(cutlass::half_t*)(W), // ElementB const* ptr_B | |
{ | |
cute::Int<1>{} /* stride_w1 */, | |
ldb /* stride_w0 */, | |
0 /* batch_stride_w */ | |
}, // StrideB dB | |
}, // MainloopArguments mainloop | |
// see https://tinyurl.com/4rk89z48 | |
{ | |
{ /* thread */ | |
{/* ptr_aux */ (cutlass::half_t*) arg2_1, /* null_default */ cutlass::half_t(0), /* dAux */ {512, _1{}, _0{}}}, /* arg2_1 */ | |
{ /* tmp_0 */ | |
{ /* compute_0 */ | |
{}, /* accum */ | |
{}, /* compute_0 */ | |
}, | |
{/* ptr_aux */ (cutlass::half_t*) buf1, /* dAux */ {512, _1{}, _0{}}}, /* tmp_0 */ | |
}, | |
{}, /* compute_1 */ | |
{}, /* compute_2 */ | |
} | |
, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) | |
nullptr, // ElementC const* ptr_C | |
{ | |
cute::Int<1>{} /* stride_bias0 */, | |
cute::Int<1>{} /* stride_bias1 */, | |
0 /* batch_stride_bias */ | |
}, // StrideC dC | |
(cutlass::half_t*)(buf2), // ElementD const* ptr_D | |
{ | |
ldd /* stride_y0 */, | |
cute::Int<1>{} /* stride_y1 */, | |
0 /* batch_stride_y */ | |
}, // StrideD dD | |
}, // EpilogueArguments epilogue, | |
hw_info | |
}; | |
arguments.scheduler.max_swizzle_size = swizzle; | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type gemm_op; | |
if (workspace_size) { | |
*workspace_size = gemm_op.get_workspace_size(arguments); | |
return 0; | |
} | |
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers | |
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS | |
{ | |
auto status = gemm_op.can_implement(arguments); | |
CUTLASS_CHECK(status); | |
} | |
#endif | |
#ifdef CUTLASS_DEBUG_TRACE_LEVEL | |
#if CUTLASS_DEBUG_TRACE_LEVEL == 1 | |
{ | |
// Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 | |
// we don't need a print statement, it's happening inside the function. | |
gemm_op.maximum_active_blocks(); | |
} | |
#endif | |
#endif | |
{ | |
auto status = gemm_op.initialize(arguments, workspace, stream); | |
CUTLASS_CHECK(status); | |
} | |
{ | |
auto status = gemm_op(stream); | |
CUTLASS_CHECK(status); | |
} | |
} | |
catch (std::exception& e) { | |
std::cerr << "Runtime error: " << e.what() << std::endl; | |
return -1; | |
} | |
catch (...) { | |
return -1; | |
} | |
return 0; | |
} | |
} | |
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma | |
''', 'so', aot_compile=False) | |
async_compile.wait(globals()) | |
del async_compile | |
def call(args): | |
arg0_1, arg1_1, arg2_1 = args | |
args.clear() | |
assert_size_stride(arg0_1, (1024, 512), (512, 1)) | |
assert_size_stride(arg1_1, (512, 512), (512, 1)) | |
assert_size_stride(arg2_1, (1024, 512), (512, 1)) | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf1 = empty_strided_cuda((1024, 512), (512, 1), torch.float16) | |
buf2 = empty_strided_cuda((1024, 512), (512, 1), torch.float16) | |
stream0 = get_raw_stream(0) | |
cutlass_fused_add_mm_relu_b08c11bd.cutlass_fused_add_mm_relu_b08c11bd(c_void_p(arg0_1.data_ptr()), c_void_p(arg1_1.data_ptr()), c_void_p(arg2_1.data_ptr()), c_void_p(buf2.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()), 1024, 512, 512, 1, 512, 512, 0, 512, 1, None, None, c_void_p(stream0)) | |
del arg0_1 | |
del arg1_1 | |
del arg2_1 | |
return (buf2, buf1, ) | |
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
arg0_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16) | |
arg1_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float16) | |
arg2_1 = rand_strided((1024, 512), (512, 1), device='cuda:0', dtype=torch.float16) | |
fn = lambda: call([arg0_1, arg1_1, arg2_1]) | |
return print_performance(fn, times=times, repeat=repeat) | |
if __name__ == "__main__": | |
from torch._inductor.wrapper_benchmark import compiled_module_main | |
compiled_module_main('None', benchmark_compiled_module) | |
[DEBUG]:Output code written to: /tmp/torchinductor_mlazos/ak/cakiohsxeutme3m3pv5zhmwelrwwdj7onpshavvnsedj4mjuqb4b.py | |
[INFO]:Output code written to: /tmp/torchinductor_mlazos/ak/cakiohsxeutme3m3pv5zhmwelrwwdj7onpshavvnsedj4mjuqb4b.py | |
frames [('total', 1), ('ok', 1)] | |
stats [('calls_captured', 4), ('unique_graphs', 1)] | |
inductor [('cuda_epilogue_fusion_counter', 2), ('fxgraph_cache_miss', 1)] | |
aot_autograd [('total', 1), ('autograd_cache_miss', 1), ('autograd_cache_saved', 1), ('ok', 1)] | |
graph_break [] | |
aten_mm_info [('aten.mm_1024_512_512', 1)] | |
. | |
---------------------------------------------------------------------- | |
Ran 1 test in 55.937s | |
OK |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@unittest.skipIf(not SM90OrLater, "need sm_90") | |
@use_evt_config | |
@evt_all_ops | |
def test_evt_multi_output(self, op): | |
class TestModel(torch.nn.Module): | |
def forward(self, a, b, extra_args): | |
acc = a @ b | |
z0 = acc.relu() | |
z = z0 + extra_args[0] | |
y = z0 + z | |
return y, z0 | |
M = 1024 | |
N = 512 | |
a = torch.ones(M, N).cuda().half() | |
b = torch.ones(N, N).cuda().half() | |
extra_args = gen_args(op, (M, N)) | |
model = TestModel().cuda() | |
result = torch.compile(model)(a, b, extra_args) | |
ref_result = model(a, b, extra_args) | |
self.assertEqual( | |
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2 | |
) | |
torch.testing.assert_close(result, ref_result) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment