#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <cuda_runtime.h>  // cudaStream_t (normally also pulled in via the CUTLASS util headers)
#include <ATen/ATen.h>     // at::Half in the entry-point signature (assumption: the real build supplies this via its PyTorch include paths)
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__  // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using namespace cute;
#define CUTLASS_CHECK(status)                                              \
  {                                                                        \
    cutlass::Status error = status;                                        \
    if (error != cutlass::Status::kSuccess) {                              \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +   \
                 cutlassGetStatusString(error) + " at: " +                 \
                 std::to_string(__LINE__);                                 \
      throw std::runtime_error(msg);                                      \
    }                                                                      \
  }
// Used as a pass-through functor in EVT, just for type casting / rounding.
template <typename T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T val) const { return val; }
};
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cute::Shape<cute::_128, cute::_128, cute::_64>,
    cute::Shape<cute::_1, cute::_2, cute::_1>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,
    void, cutlass::layout::ColumnMajor, 8,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    cutlass::epilogue::TmaWarpSpecializedCooperative,
    cutlass::epilogue::fusion::LinearCombination<
      cutlass::half_t,
      float,
      void,
      float
    >
  >::CollectiveOp;
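// Note on the epilogue built above: ElementC is void, so no source tensor C is
// read (the beta term drops out), and the fused
// LinearCombination<half_t, float, void, float> computes D = alpha * acc in
// fp32 before rounding to fp16 on the store.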
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    float,
    cute::Shape<cute::_128, cute::_128, cute::_64>,
    cute::Shape<cute::_1, cute::_2, cute::_1>,
    cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>,
    cutlass::gemm::KernelTmaWarpSpecializedCooperative
  >::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
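// Decoding the configuration name (hedged reading, following the CUTLASS
// library naming scheme): SM90 (Hopper) tensor-op GEMM; fp16 A and B with fp32
// accumulation; void C (no source tensor) and fp16 D; 128x128x64 threadblock
// tile with a 1x2x1 cluster (the trailing _0 likely denotes automatic stage
// count); "ttn" encodes row-major A, row-major B, column-major C
// (t = row-major, n = column-major); align8 is the 8-element operand
// alignment; scheduling is stream-K with a TMA warp-specialized cooperative
// mainloop and epilogue. The StageCountAutoCarveout above sizes the mainloop
// pipeline after reserving shared memory for the epilogue's SharedStorage.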
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
  cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
  cutlass::gemm::StreamKScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
  public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, this function only populates the
// requested workspace size and returns. Otherwise, it runs the GEMM using the
// given workspace pointer.
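// Hedged usage sketch (not emitted by the generator): how a caller drives the
// two-phase protocol of the entry point below. `generated_gemm` stands in for
// whatever name replaces Placeholder.KERNEL_NAME; the shapes and leading
// dimensions are illustrative row-major defaults.
#if 0
static int run_example(const at::Half* X, const at::Half* W, at::Half* Y,
                       int M, int N, int K, cudaStream_t stream) {
  size_t ws_bytes = 0;
  // Phase 1: non-null workspace_size => only report the required workspace.
  generated_gemm(X, W, Y, M, N, K, /*B=*/1, /*lda=*/K, /*ldb=*/N, /*ldc=*/N,
                 /*ldd=*/N, /*swizzle=*/1, &ws_bytes, /*workspace=*/nullptr,
                 stream);
  uint8_t* ws = nullptr;
  if (ws_bytes > 0) cudaMalloc(reinterpret_cast<void**>(&ws), ws_bytes);
  // Phase 2: null workspace_size => initialize and launch the GEMM.
  int rc = generated_gemm(X, W, Y, M, N, K, 1, K, N, N, N, 1,
                          /*workspace_size=*/nullptr, ws, stream);
  if (ws) cudaFree(ws);
  return rc;
}
#endif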
extern "C" { | |
PT_EXPORT int Placeholder.KERNEL_NAME(const at::Half* X, const at::Half* W, at::Half* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) { | |
  try {
    using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator;
    using coord_t = cutlass::gemm::GemmCoord::Index;
    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
    }
    cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments;
    // Initialize GemmUniversal3xInstance arguments.
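    // Stride convention (CUTLASS 3.x): StrideA orders the modes (M, K, batch)
    // and StrideB orders them (N, K, batch), so a row-major M x K A uses
    // {lda, 1, 0} and a row-major K x N B uses {1, ldb, 0}. The batch strides
    // are 0 here, i.e. a single logical batch.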
    arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
      {
        static_cast<coord_t>(M),
        static_cast<coord_t>(N),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      },  // ProblemShape problem_shape
      {
        (cutlass::half_t const*)(X),  // ElementA const* ptr_A
        {
          lda /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        },  // StrideA dA
        (cutlass::half_t const*)(W),  // ElementB const* ptr_B
        {
          cute::Int<1>{} /* stride_w1 */,
          ldb /* stride_w0 */,
          0 /* batch_stride_w */
        },  // StrideB dB
      },  // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},  // thread: {alpha, beta} as FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
        nullptr,  // ElementC const* ptr_C (no source tensor)
        {
          cute::Int<1>{} /* stride_bias0 */,
          cute::Int<1>{} /* stride_bias1 */,
          0 /* batch_stride_bias */
        },  // StrideC dC
        (cutlass::half_t*)(Y),  // ElementD* ptr_D
        {
          ldd /* stride_y0 */,
          cute::Int<1>{} /* stride_y1 */,
          0 /* batch_stride_y */
        },  // StrideD dD
      },  // EpilogueArguments epilogue
      hw_info
    };
    arguments.scheduler.max_swizzle_size = swizzle;
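    // max_swizzle_size sets the tile-swizzling factor of the StreamKScheduler
    // selected above; larger swizzle sizes can improve L2 reuse across CTAs.
    // The value comes in from the caller (presumably chosen by autotuning).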
    cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }
    // Validate the arguments only after the workspace-size query, since
    // querying the workspace size doesn't require valid data pointers.
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#endif
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Report the maximum number of active blocks per SM for the kernel when
      // CUTLASS_DEBUG_TRACE_LEVEL == 1; maximum_active_blocks() prints the
      // result internally, so no explicit print statement is needed here.
      gemm_op.maximum_active_blocks();
    }
#endif
#endif
    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    return -1;
  }
  return 0;
}
}  // extern "C"
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
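// Hedged build sketch (not part of the generated file): compiling a kernel
// like this standalone typically needs the CUTLASS main and util include
// paths and the sm_90a target required by Hopper TMA warp-specialized
// kernels, roughly:
//   nvcc -std=c++17 -arch=sm_90a --expt-relaxed-constexpr \
//        -I$CUTLASS_DIR/include -I$CUTLASS_DIR/tools/util/include \
//        -Xcompiler -fvisibility=hidden -shared -o kernel.so kernel.cu
// Exact flags depend on the PyTorch and CUTLASS versions in use.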