// Source: GitHub gist mlazos/06931e54fcf0849a3497d27deb9a8ee8 by @mlazos
// (last active May 15, 2025). The gist contains the same generated CUTLASS
// kernel source repeated three times.
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
// Every model is compiled with -fvisibility=hidden, so any symbol that has to
// be visible in the final shared library must be marked with PT_EXPORT.
#if defined(__GNUC__) // GNU-compatible compilers (g++ and clang)
#define PT_EXPORT __attribute__((__visibility__("default")))
#elif defined(_WIN32)
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
using namespace cute;
// CUTLASS_CHECK(status): evaluates `status` (a cutlass::Status expression)
// exactly once. If it is not cutlass::Status::kSuccess, throws
// std::runtime_error with a message naming this file, the CUTLASS status
// string and the line of the expansion.
// The body is wrapped in do { } while (0) so `CUTLASS_CHECK(x);` is a single
// statement (safe inside an unbraced if/else). The local has a distinctive
// name so that an argument spelled `error` cannot self-initialize.
#define CUTLASS_CHECK(status) \
  do { \
    cutlass::Status cutlass_check_status_ = (status); \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) { \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
          cutlassGetStatusString(cutlass_check_status_) + " at: " + std::to_string(__LINE__); \
      throw std::runtime_error(msg); \
    } \
  } while (0)
// Pass-through functor for EVT (epilogue visitor tree) nodes: returns its
// argument unchanged, so its only job is to cast / round a value to T.
template <class T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T x) const {
    return x;
  }
};
// Epilogue collective for this SM90 GEMM, produced by CUTLASS's epilogue
// CollectiveBuilder. Template arguments, in order:
//   arch / op class         : Sm90, tensor-core MMA (OpClassTensorOp)
//   CTA tile (M, N, K)      : 128 x 128 x 64
//   cluster shape (M, N, K) : 1 x 2 x 1
//   epilogue tile           : chosen automatically (EpilogueTileAuto)
//   accumulator / compute   : float / float
//   C (source) operand      : void, ColumnMajor tag, alignment 8
//                             (NOTE(review): with a void C no source tensor is
//                             expected to be read, so the layout tag presumably
//                             has no effect — confirm)
//   D (output) operand      : half_t, RowMajor, alignment 8
//   schedule                : TmaWarpSpecializedCooperative, matching the
//                             KernelTmaWarpSpecializedCooperative mainloop schedule
//   fusion                  : LinearCombination<output half_t, compute float,
//                             source void, scalar float>
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::epilogue::fusion::LinearCombination<
cutlass::half_t,
float,
void,
float
>
>::CollectiveOp;
// Mainloop collective for this SM90 GEMM, produced by CUTLASS's gemm
// CollectiveBuilder. Template arguments, in order:
//   arch / op class         : Sm90, tensor-core MMA (OpClassTensorOp)
//   A operand               : half_t, RowMajor, alignment 8
//   B operand               : half_t, RowMajor, alignment 8 (consistent with the
//                             (1, ldb) stride the entry point passes for W)
//   accumulator             : float
//   CTA tile (M, N, K)      : 128 x 128 x 64
//   cluster shape (M, N, K) : 1 x 2 x 1
//   stage count             : chosen automatically after carving out
//                             sizeof(epilogue SharedStorage) bytes, so the
//                             mainloop pipeline and epilogue share shared memory
//   kernel schedule         : KernelTmaWarpSpecializedCooperative
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
// Kernel type: CUTLASS 3.x GemmUniversal over a rank-4 problem shape
// (M, N, K, batch). It combines the mainloop and epilogue collectives above
// with the stream-K tile scheduler (StreamKScheduler).
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
cutlass::gemm::StreamKScheduler>;
// Define named type
// A distinct struct deriving from the alias above gives the kernel its own
// named type. NOTE(review): presumably so the configuration name appears in
// symbols / profiler output — confirm.
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
// Host-side device adapter. The entry point uses its Arguments type and its
// get_workspace_size, can_implement, initialize and operator()(stream) members.
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
//
// Computes Y = X @ W with half_t inputs, a float accumulator and half_t output,
// using the SM90 stream-K kernel configured above.
// Operand layouts, as encoded by the strides passed below:
//   X: M x K, row-major, leading dimension lda
//   W: K x N, row-major, leading dimension ldb
//   Y: M x N, row-major, leading dimension ldd
// ldc is accepted for signature compatibility but unused: the epilogue has no
// C operand (ElementC is void and ptr_C is nullptr).
// swizzle is forwarded to the tile scheduler as max_swizzle_size.
// NOTE(review): every batch stride is 0, so with B > 1 all batches would read
// the same X/W and write the same Y. Presumably callers pass B == 1 — confirm.
//
// Returns 0 on success (including the workspace-size query) and -1 on any
// failure, which is reported on stderr. No exception escapes this extern "C"
// entry point.
extern "C" {
PT_EXPORT int Placeholder.KERNEL_NAME(const at::Half* X, const at::Half* W, at::Half* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
  try {
    using DeviceGemm = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type;
    using ElementComputeEpilogue = DeviceGemm::ElementAccumulator;
    using coord_t = cutlass::gemm::GemmCoord::Index;

    // The SM count is queried once and cached across calls. A failed query
    // (result 0) is retried on the next call.
    // NOTE(review): this static is written without synchronization, so
    // concurrent first calls race, and one cached value is shared by every
    // device — confirm this is acceptable for the callers.
    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
    }

    // Initialize GemmUniversal3xInstance arguments.
    DeviceGemm::Arguments arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
      {
        static_cast<coord_t>(M),
        static_cast<coord_t>(N),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      }, // ProblemShape problem_shape
      {
        reinterpret_cast<const cutlass::half_t*>(X), // ElementA const* ptr_A
        {
          lda /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        }, // StrideA dA
        reinterpret_cast<const cutlass::half_t*>(W), // ElementB const* ptr_B
        {
          cute::Int<1>{} /* stride_w1 */,
          ldb /* stride_w0 */,
          0 /* batch_stride_w */
        }, // StrideB dB
      }, // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
        nullptr, // ElementC const* ptr_C
        {
          cute::Int<1>{} /* stride_bias0 */,
          cute::Int<1>{} /* stride_bias1 */,
          0 /* batch_stride_bias */
        }, // StrideC dC
        reinterpret_cast<cutlass::half_t*>(Y), // ElementD* ptr_D
        {
          ldd /* stride_y0 */,
          cute::Int<1>{} /* stride_y1 */,
          0 /* batch_stride_y */
        }, // StrideD dD
      }, // EpilogueArguments epilogue
      hw_info
    };
    arguments.scheduler.max_swizzle_size = swizzle;

    DeviceGemm gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }

    // Pointer checks happen only after the workspace-size query, since
    // querying the workspace size doesn't require valid data pointers.
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
    {
      // An operand may only be null when it has no elements.
      // With zero batch strides, X is M x K, W is K x N and Y is M x N.
      const auto require_data = [](const void* ptr, bool has_elements, const char* name) {
        if (ptr == nullptr && has_elements) {
          throw std::runtime_error(std::string("tensor ") + name + " is null but its size is not 0");
        }
      };
      require_data(X, B > 0 && M > 0 && K > 0, "X");
      require_data(W, B > 0 && K > 0 && N > 0, "W");
      require_data(Y, B > 0 && M > 0 && N > 0, "Y");

      // initialize() receives this workspace; reject a null one when the
      // kernel reports a non-zero workspace requirement.
      const size_t required_workspace = gemm_op.get_workspace_size(arguments);
      if (workspace == nullptr && required_workspace > 0) {
        throw std::runtime_error("workspace is null but the kernel requires " +
                                 std::to_string(required_workspace) + " bytes");
      }
    }
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#endif

#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1
      // we don't need a print statement, it's happening inside the function.
      gemm_op.maximum_active_blocks();
    }
#endif

    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (const std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    // Still report something: a silent -1 is very hard to diagnose.
    std::cerr << "Runtime error: unknown exception" << std::endl;
    return -1;
  }
  return 0;
}
}
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
// Every model is compiled with -fvisibility=hidden, so any symbol that has to
// be visible in the final shared library must be marked with PT_EXPORT.
#if defined(__GNUC__) // GNU-compatible compilers (g++ and clang)
#define PT_EXPORT __attribute__((__visibility__("default")))
#elif defined(_WIN32)
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
using namespace cute;
// CUTLASS_CHECK(status): evaluates `status` (a cutlass::Status expression)
// exactly once. If it is not cutlass::Status::kSuccess, throws
// std::runtime_error with a message naming this file, the CUTLASS status
// string and the line of the expansion.
// The body is wrapped in do { } while (0) so `CUTLASS_CHECK(x);` is a single
// statement (safe inside an unbraced if/else). The local has a distinctive
// name so that an argument spelled `error` cannot self-initialize.
#define CUTLASS_CHECK(status) \
  do { \
    cutlass::Status cutlass_check_status_ = (status); \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) { \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
          cutlassGetStatusString(cutlass_check_status_) + " at: " + std::to_string(__LINE__); \
      throw std::runtime_error(msg); \
    } \
  } while (0)
// Pass-through functor for EVT (epilogue visitor tree) nodes: returns its
// argument unchanged, so its only job is to cast / round a value to T.
template <class T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T x) const {
    return x;
  }
};
// Epilogue collective for this SM90 GEMM, produced by CUTLASS's epilogue
// CollectiveBuilder. Template arguments, in order:
//   arch / op class         : Sm90, tensor-core MMA (OpClassTensorOp)
//   CTA tile (M, N, K)      : 128 x 128 x 64
//   cluster shape (M, N, K) : 1 x 2 x 1
//   epilogue tile           : chosen automatically (EpilogueTileAuto)
//   accumulator / compute   : float / float
//   C (source) operand      : void, ColumnMajor tag, alignment 8
//                             (NOTE(review): with a void C no source tensor is
//                             expected to be read, so the layout tag presumably
//                             has no effect — confirm)
//   D (output) operand      : half_t, RowMajor, alignment 8
//   schedule                : TmaWarpSpecializedCooperative, matching the
//                             KernelTmaWarpSpecializedCooperative mainloop schedule
//   fusion                  : LinearCombination<output half_t, compute float,
//                             source void, scalar float>
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::epilogue::fusion::LinearCombination<
cutlass::half_t,
float,
void,
float
>
>::CollectiveOp;
// Mainloop collective for this SM90 GEMM, produced by CUTLASS's gemm
// CollectiveBuilder. Template arguments, in order:
//   arch / op class         : Sm90, tensor-core MMA (OpClassTensorOp)
//   A operand               : half_t, RowMajor, alignment 8
//   B operand               : half_t, RowMajor, alignment 8 (consistent with the
//                             (1, ldb) stride the entry point passes for W)
//   accumulator             : float
//   CTA tile (M, N, K)      : 128 x 128 x 64
//   cluster shape (M, N, K) : 1 x 2 x 1
//   stage count             : chosen automatically after carving out
//                             sizeof(epilogue SharedStorage) bytes, so the
//                             mainloop pipeline and epilogue share shared memory
//   kernel schedule         : KernelTmaWarpSpecializedCooperative
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
// Kernel type: CUTLASS 3.x GemmUniversal over a rank-4 problem shape
// (M, N, K, batch). It combines the mainloop and epilogue collectives above
// with the stream-K tile scheduler (StreamKScheduler).
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
cutlass::gemm::StreamKScheduler>;
// Define named type
// A distinct struct deriving from the alias above gives the kernel its own
// named type. NOTE(review): presumably so the configuration name appears in
// symbols / profiler output — confirm.
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
// Host-side device adapter. The entry point uses its Arguments type and its
// get_workspace_size, can_implement, initialize and operator()(stream) members.
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
//
// Computes Y = X @ W with half_t inputs, a float accumulator and half_t output,
// using the SM90 stream-K kernel configured above.
// Operand layouts, as encoded by the strides passed below:
//   X: M x K, row-major, leading dimension lda
//   W: K x N, row-major, leading dimension ldb
//   Y: M x N, row-major, leading dimension ldd
// ldc is accepted for signature compatibility but unused: the epilogue has no
// C operand (ElementC is void and ptr_C is nullptr).
// swizzle is forwarded to the tile scheduler as max_swizzle_size.
// NOTE(review): every batch stride is 0, so with B > 1 all batches would read
// the same X/W and write the same Y. Presumably callers pass B == 1 — confirm.
//
// Returns 0 on success (including the workspace-size query) and -1 on any
// failure, which is reported on stderr. No exception escapes this extern "C"
// entry point.
extern "C" {
PT_EXPORT int Placeholder.KERNEL_NAME(const at::Half* X, const at::Half* W, at::Half* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
  try {
    using DeviceGemm = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type;
    using ElementComputeEpilogue = DeviceGemm::ElementAccumulator;
    using coord_t = cutlass::gemm::GemmCoord::Index;

    // The SM count is queried once and cached across calls. A failed query
    // (result 0) is retried on the next call.
    // NOTE(review): this static is written without synchronization, so
    // concurrent first calls race, and one cached value is shared by every
    // device — confirm this is acceptable for the callers.
    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
    }

    // Initialize GemmUniversal3xInstance arguments.
    DeviceGemm::Arguments arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
      {
        static_cast<coord_t>(M),
        static_cast<coord_t>(N),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      }, // ProblemShape problem_shape
      {
        reinterpret_cast<const cutlass::half_t*>(X), // ElementA const* ptr_A
        {
          lda /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        }, // StrideA dA
        reinterpret_cast<const cutlass::half_t*>(W), // ElementB const* ptr_B
        {
          cute::Int<1>{} /* stride_w1 */,
          ldb /* stride_w0 */,
          0 /* batch_stride_w */
        }, // StrideB dB
      }, // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
        nullptr, // ElementC const* ptr_C
        {
          cute::Int<1>{} /* stride_bias0 */,
          cute::Int<1>{} /* stride_bias1 */,
          0 /* batch_stride_bias */
        }, // StrideC dC
        reinterpret_cast<cutlass::half_t*>(Y), // ElementD* ptr_D
        {
          ldd /* stride_y0 */,
          cute::Int<1>{} /* stride_y1 */,
          0 /* batch_stride_y */
        }, // StrideD dD
      }, // EpilogueArguments epilogue
      hw_info
    };
    arguments.scheduler.max_swizzle_size = swizzle;

    DeviceGemm gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }

    // Pointer checks happen only after the workspace-size query, since
    // querying the workspace size doesn't require valid data pointers.
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
    {
      // An operand may only be null when it has no elements.
      // With zero batch strides, X is M x K, W is K x N and Y is M x N.
      const auto require_data = [](const void* ptr, bool has_elements, const char* name) {
        if (ptr == nullptr && has_elements) {
          throw std::runtime_error(std::string("tensor ") + name + " is null but its size is not 0");
        }
      };
      require_data(X, B > 0 && M > 0 && K > 0, "X");
      require_data(W, B > 0 && K > 0 && N > 0, "W");
      require_data(Y, B > 0 && M > 0 && N > 0, "Y");

      // initialize() receives this workspace; reject a null one when the
      // kernel reports a non-zero workspace requirement.
      const size_t required_workspace = gemm_op.get_workspace_size(arguments);
      if (workspace == nullptr && required_workspace > 0) {
        throw std::runtime_error("workspace is null but the kernel requires " +
                                 std::to_string(required_workspace) + " bytes");
      }
    }
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#endif

#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1
      // we don't need a print statement, it's happening inside the function.
      gemm_op.maximum_active_blocks();
    }
#endif

    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (const std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    // Still report something: a silent -1 is very hard to diagnose.
    std::cerr << "Runtime error: unknown exception" << std::endl;
    return -1;
  }
  return 0;
}
}
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
// Every model is compiled with -fvisibility=hidden, so any symbol that has to
// be visible in the final shared library must be marked with PT_EXPORT.
#if defined(__GNUC__) // GNU-compatible compilers (g++ and clang)
#define PT_EXPORT __attribute__((__visibility__("default")))
#elif defined(_WIN32)
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
using namespace cute;
// CUTLASS_CHECK(status): evaluates `status` (a cutlass::Status expression)
// exactly once. If it is not cutlass::Status::kSuccess, throws
// std::runtime_error with a message naming this file, the CUTLASS status
// string and the line of the expansion.
// The body is wrapped in do { } while (0) so `CUTLASS_CHECK(x);` is a single
// statement (safe inside an unbraced if/else). The local has a distinctive
// name so that an argument spelled `error` cannot self-initialize.
#define CUTLASS_CHECK(status) \
  do { \
    cutlass::Status cutlass_check_status_ = (status); \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) { \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
          cutlassGetStatusString(cutlass_check_status_) + " at: " + std::to_string(__LINE__); \
      throw std::runtime_error(msg); \
    } \
  } while (0)
// Pass-through functor for EVT (epilogue visitor tree) nodes: returns its
// argument unchanged, so its only job is to cast / round a value to T.
template <class T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T x) const {
    return x;
  }
};
// Epilogue collective for this SM90 GEMM, produced by CUTLASS's epilogue
// CollectiveBuilder. Template arguments, in order:
//   arch / op class         : Sm90, tensor-core MMA (OpClassTensorOp)
//   CTA tile (M, N, K)      : 128 x 128 x 64
//   cluster shape (M, N, K) : 1 x 2 x 1
//   epilogue tile           : chosen automatically (EpilogueTileAuto)
//   accumulator / compute   : float / float
//   C (source) operand      : void, ColumnMajor tag, alignment 8
//                             (NOTE(review): with a void C no source tensor is
//                             expected to be read, so the layout tag presumably
//                             has no effect — confirm)
//   D (output) operand      : half_t, RowMajor, alignment 8
//   schedule                : TmaWarpSpecializedCooperative, matching the
//                             KernelTmaWarpSpecializedCooperative mainloop schedule
//   fusion                  : LinearCombination<output half_t, compute float,
//                             source void, scalar float>
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::epilogue::fusion::LinearCombination<
cutlass::half_t,
float,
void,
float
>
>::CollectiveOp;
// Mainloop collective for this SM90 GEMM, produced by CUTLASS's gemm
// CollectiveBuilder. Template arguments, in order:
//   arch / op class         : Sm90, tensor-core MMA (OpClassTensorOp)
//   A operand               : half_t, RowMajor, alignment 8
//   B operand               : half_t, RowMajor, alignment 8 (consistent with the
//                             (1, ldb) stride the entry point passes for W)
//   accumulator             : float
//   CTA tile (M, N, K)      : 128 x 128 x 64
//   cluster shape (M, N, K) : 1 x 2 x 1
//   stage count             : chosen automatically after carving out
//                             sizeof(epilogue SharedStorage) bytes, so the
//                             mainloop pipeline and epilogue share shared memory
//   kernel schedule         : KernelTmaWarpSpecializedCooperative
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
cute::Shape<cute::_128, cute::_128, cute::_64>,
cute::Shape<cute::_1, cute::_2, cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
// Kernel type: CUTLASS 3.x GemmUniversal over a rank-4 problem shape
// (M, N, K, batch). It combines the mainloop and epilogue collectives above
// with the stream-K tile scheduler (StreamKScheduler).
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
cutlass::gemm::StreamKScheduler>;
// Define named type
// A distinct struct deriving from the alias above gives the kernel its own
// named type. NOTE(review): presumably so the configuration name appears in
// symbols / profiler output — confirm.
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
// Host-side device adapter. The entry point uses its Arguments type and its
// get_workspace_size, can_implement, initialize and operator()(stream) members.
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
//
// Computes Y = X @ W with half_t inputs, a float accumulator and half_t output,
// using the SM90 stream-K kernel configured above.
// Operand layouts, as encoded by the strides passed below:
//   X: M x K, row-major, leading dimension lda
//   W: K x N, row-major, leading dimension ldb
//   Y: M x N, row-major, leading dimension ldd
// ldc is accepted for signature compatibility but unused: the epilogue has no
// C operand (ElementC is void and ptr_C is nullptr).
// swizzle is forwarded to the tile scheduler as max_swizzle_size.
// NOTE(review): every batch stride is 0, so with B > 1 all batches would read
// the same X/W and write the same Y. Presumably callers pass B == 1 — confirm.
//
// Returns 0 on success (including the workspace-size query) and -1 on any
// failure, which is reported on stderr. No exception escapes this extern "C"
// entry point.
extern "C" {
PT_EXPORT int Placeholder.KERNEL_NAME(const at::Half* X, const at::Half* W, at::Half* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
  try {
    using DeviceGemm = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type;
    using ElementComputeEpilogue = DeviceGemm::ElementAccumulator;
    using coord_t = cutlass::gemm::GemmCoord::Index;

    // The SM count is queried once and cached across calls. A failed query
    // (result 0) is retried on the next call.
    // NOTE(review): this static is written without synchronization, so
    // concurrent first calls race, and one cached value is shared by every
    // device — confirm this is acceptable for the callers.
    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
    }

    // Initialize GemmUniversal3xInstance arguments.
    DeviceGemm::Arguments arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
      {
        static_cast<coord_t>(M),
        static_cast<coord_t>(N),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      }, // ProblemShape problem_shape
      {
        reinterpret_cast<const cutlass::half_t*>(X), // ElementA const* ptr_A
        {
          lda /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        }, // StrideA dA
        reinterpret_cast<const cutlass::half_t*>(W), // ElementB const* ptr_B
        {
          cute::Int<1>{} /* stride_w1 */,
          ldb /* stride_w0 */,
          0 /* batch_stride_w */
        }, // StrideB dB
      }, // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
        nullptr, // ElementC const* ptr_C
        {
          cute::Int<1>{} /* stride_bias0 */,
          cute::Int<1>{} /* stride_bias1 */,
          0 /* batch_stride_bias */
        }, // StrideC dC
        reinterpret_cast<cutlass::half_t*>(Y), // ElementD* ptr_D
        {
          ldd /* stride_y0 */,
          cute::Int<1>{} /* stride_y1 */,
          0 /* batch_stride_y */
        }, // StrideD dD
      }, // EpilogueArguments epilogue
      hw_info
    };
    arguments.scheduler.max_swizzle_size = swizzle;

    DeviceGemm gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }

    // Pointer checks happen only after the workspace-size query, since
    // querying the workspace size doesn't require valid data pointers.
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
    {
      // An operand may only be null when it has no elements.
      // With zero batch strides, X is M x K, W is K x N and Y is M x N.
      const auto require_data = [](const void* ptr, bool has_elements, const char* name) {
        if (ptr == nullptr && has_elements) {
          throw std::runtime_error(std::string("tensor ") + name + " is null but its size is not 0");
        }
      };
      require_data(X, B > 0 && M > 0 && K > 0, "X");
      require_data(W, B > 0 && K > 0 && N > 0, "W");
      require_data(Y, B > 0 && M > 0 && N > 0, "Y");

      // initialize() receives this workspace; reject a null one when the
      // kernel reports a non-zero workspace requirement.
      const size_t required_workspace = gemm_op.get_workspace_size(arguments);
      if (workspace == nullptr && required_workspace > 0) {
        throw std::runtime_error("workspace is null but the kernel requires " +
                                 std::to_string(required_workspace) + " bytes");
      }
    }
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#endif

#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1
      // we don't need a print statement, it's happening inside the function.
      gemm_op.maximum_active_blocks();
    }
#endif

    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (const std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    // Still report something: a silent -1 is very hard to diagnose.
    std::cerr << "Runtime error: unknown exception" << std::endl;
    return -1;
  }
  return 0;
}
}
// configuration name: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
// End of gist mlazos/06931e54fcf0849a3497d27deb9a8ee8 (GitHub sign-in footer omitted).