Created December 20, 2023 17:53
Cutlass Error repro cases
#!/bin/bash
# Change the environment variables to point to CUTLASS and the CUDA Toolkit, then run this,
# passing any of the standalone repro_N.cu files as argument. It will compile and run the
# example.
set -x
export REPRO_CUTLASS_PATH=/home/klondenberg/github/pytorch/pytorch/third_party/cutlass
export REPRO_CUDA_PATH=/home/klondenberg/local/cuda121
$REPRO_CUDA_PATH/bin/nvcc -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -w -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O1 -std=c++17 --expt-relaxed-constexpr -Xcompiler=-fPIC --use_fast_math -Xcompiler=-fno-strict-aliasing -Xcompiler -fvisibility=hidden -Xcompiler=-Wconversion -I${REPRO_CUTLASS_PATH}/include -I${REPRO_CUTLASS_PATH}/tools/library/include -I${REPRO_CUTLASS_PATH}/tools/library/src -I${REPRO_CUTLASS_PATH}/tools/util/include -L${REPRO_CUDA_PATH}/lib64 -L${REPRO_CUDA_PATH}/lib64/stubs -lcuda -lcudart -DGENERATE_STANDALONE_RUNNER -DNDEBUG -DCUTLASS_DEBUG_TRACE_LEVEL=1 -o "${@}.exe" "$@"
"./${@}.exe"
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using bfloat16 = nv_bfloat16;
using namespace cute;
#define CUTLASS_CHECK(status)                                                 \
  {                                                                           \
    cutlass::Status error = status;                                           \
    if (error != cutlass::Status::kSuccess) {                                 \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +      \
          cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
      throw std::runtime_error(msg);                                          \
    }                                                                         \
  }
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T val) const { return val; }
};

using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecializedCooperative;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
              cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
              "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = float;
using ElementD = cutlass::half_t;
using ElementC = cutlass::half_t;
using TileShapeMNK = cute::Shape<cute::_128, cute::_256, cute::_64>;
using ClusterShapeMNK = cute::Shape<cute::_2, cute::_1, cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    TileShapeMNK,
    EpilogueTileType,
    ElementC,
    ElementD,
    EpilogueScheduleType
>;
using ADDMM_EVT =  // alpha * acc + beta * C
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
                                               ElementD, ElementAcc, RoundStyle>,  // beta * C + (alpha * acc)
        cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>,  // beta
        cutlass::epilogue::fusion::Sm90SrcFetch,                     // C
        cutlass::epilogue::fusion::Sm90EVT<
            cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
                                                   ElementAcc, RoundStyle>,  // alpha * acc
            cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>,  // alpha
            cutlass::epilogue::fusion::Sm90AccFetch                      // acc
        >>;
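// [Editor's sketch, not part of the original repro] Per output element, the EVT tree
// above evaluates D = cast<ElementD>(beta * C + alpha * acc). A plain host reference
// of that scalar epilogue, assuming ElementAcc = float and ElementC/ElementD = half_t
// (the function name is illustrative only):
static inline cutlass::half_t addmm_epilogue_reference(
    float acc, cutlass::half_t c, float alpha = 1.0f, float beta = 1.0f) {
  // homogeneous_multiply_add(beta, C, alpha * acc), rounded to half on the final cast
  return static_cast<cutlass::half_t>(beta * static_cast<float>(c) + alpha * acc);
}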
using EVT_expr_1 = ADDMM_EVT;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue_functor =
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,
        EVT_expr_1>;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
    typename cutlass::epilogue::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        TileShapeMNK,
        ClusterShapeMNK,
        EpilogueTileType,
        float, float,
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,
        EpilogueScheduleType,
        cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue_functor
    >::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
    typename cutlass::gemm::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        float,
        cute::Shape<cute::_128, cute::_256, cute::_64>,
        cute::Shape<cute::_2, cute::_1, cute::_1>,
        cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage)>,
        cutlass::gemm::KernelTmaWarpSpecializedCooperative
    >::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,
    cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
    cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
    cutlass::gemm::StreamKScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
    public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, populates the requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int cuda_cutlass_gemm_1(const half* Bias, const half* X, const half* W, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
  try {
    int64_t B = 1;
    int64_t M = 1024L;
    int64_t K = 256L;
    int64_t N = 109760L;
    using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator;
    using coord_t = cutlass::gemm::GemmCoord::Index;
    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      // @TODO kadeng: Add support for multi-GPU machines with heterogeneous SM counts;
      // for now we just pick the SM count of the first GPU.
      hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
    }
    cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments;
    // Initialize GemmUniversal3xInstance arguments.
    arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
      {
        static_cast<coord_t>(N),
        static_cast<coord_t>(M),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      },  // ProblemShape problem_shape
      {
        (cutlass::half_t*)(W),  // ElementA const* ptr_A
        {
          256L /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        },  // StrideA dA
        (cutlass::half_t*)(X),  // ElementB const* ptr_B
        {
          cute::Int<1>{} /* stride_w1 */,
          1024L /* stride_w0 */,
          0 /* batch_stride_w */
        },  // StrideB dB
      },  // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {
          { // ADDMM Arguments: ternary op : beta * C + (alpha * acc)
            {{static_cast<ElementAcc>(1.000000)}},  // leaf op+args : beta
            {},                                     // leaf op+args : C
            { // binary op : alpha * acc
              {{static_cast<ElementAcc>(1.000000)}},  // leaf op+args : alpha
              {},                                     // leaf op+args : acc
              {}                                      // binary args : multiplies
            },  // end binary op
            {}  // ternary args : multiply_add
          }  // end ternary op
        },  // thread, typename FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
        (cutlass::half_t*)(Bias),  // ElementC const* ptr_C
        {
          cute::Int<1>{} /* stride_bias0 */,
          0L /* stride_bias1 */,
          0 /* batch_stride_bias */
        },  // StrideC dC
        (cutlass::half_t*)(Y),  // ElementD* ptr_D
        {
          cute::Int<1>{} /* stride_y0 */,
          109760L /* stride_y1 */,
          0 /* batch_stride_y */
        },  // StrideD dD
      },  // EpilogueArguments epilogue
      hw_info
    };
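    // [Editor's note] The problem shape above is passed as (N, M, K, B) rather than
    // (M, N, K, B), with W in the A slot and X in the B slot: the generated code
    // appears to swap the operands and the M/N extents so that the row-major output
    // Y(M, N) can be produced as the column-major D(N, M) -- consistent with
    // StrideD = {1, 109760, 0} above, whose leading extent is N = 109760.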
    cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }
    // Check for null pointers after the workspace-size query, since querying the
    // workspace size doesn't require valid data pointers.
    {
      if (!X) {
        int64_t X_size = 262144L;
        if (X_size > 0) {
          throw std::runtime_error("input X is null but size is not 0!");
        }
      }
    }
    {
      if (!W) {
        int64_t W_size = 28098560L;
        if (W_size > 0) {
          throw std::runtime_error("input W is null but size is not 0!");
        }
      }
    }
    {
      if (!Bias) {
        int64_t Bias_size = 112394240L;
        if (Bias_size > 0) {
          throw std::runtime_error("input Bias is null but size is not 0!");
        }
      }
    }
    {
      if (!Y) {
        int64_t Y_size = 112394240L;
        if (Y_size > 0) {
          throw std::runtime_error("input Y is null but size is not 0!");
        }
      }
    }
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Log the maximum number of active blocks per SM for the kernel when
      // CUTLASS_DEBUG_TRACE_LEVEL == 1; the function prints internally, so no
      // explicit print statement is needed here.
      gemm_op.maximum_active_blocks();
    }
#endif
#endif
    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    return -1;
  }
  return 0;
}
}
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
    cutlass::DeviceAllocation<Element>& block,
    uint64_t seed, float max = 1.0, float min = -1.0) {
  if (block.size() <= 0) return false;
  Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
  cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, scope_max, scope_min, 0);
  return true;
}
extern "C" int run_standalone(uint64_t seed, int repetitions) {
  std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
  size_t workspace_size = 0;
  size_t* workspace_size_ptr = &workspace_size;
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;  // must not be void
  using ElementD = cutlass::half_t;
  cutlass::DeviceAllocation<ElementA> X_data(262144);
  initialize_block(X_data, seed++);
  cutlass::DeviceAllocation<ElementB> W_data(28098560);
  initialize_block(W_data, seed++);
  cutlass::DeviceAllocation<ElementC> Bias_data(109760);
  initialize_block(Bias_data, seed++);
  cutlass::DeviceAllocation<ElementD> Y_data(112394240);
  cutlass::DeviceAllocation<uint8_t> workspace_data;
  // Call once with workspace_size_ptr set to get the workspace size.
  // Arguments follow the kernel signature order: (Bias, X, W, Y, ...).
  std::cout << "Calling once to get workspace size" << std::endl;
  cuda_cutlass_gemm_1((const half*)Bias_data.get(), (const half*)X_data.get(), (const half*)W_data.get(), (half*)Y_data.get(), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);
  // Allocate workspace if necessary
  if (workspace_size > 0) {
    workspace_data.reset(workspace_size);
    std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
  }
  std::cout << "Calling kernel cuda_cutlass_gemm_1" << std::endl;
  workspace_size_ptr = nullptr;
  for (int i = 0; i < repetitions; i++) {
    cuda_cutlass_gemm_1((const half*)Bias_data.get(), (const half*)X_data.get(), (const half*)W_data.get(), (half*)Y_data.get(), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);
  }
  cudaError_t result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Device synchronize failed with error "
              << cudaGetErrorString(result) << std::endl;
    return result;
  }
  return 0;
}
int main(int argc, char** argv) {
  // warmup
  run_standalone(1, 2);
  // repeat
  return run_standalone(2, 10);
}
#endif
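The standalone runners above only check that the kernel launches and synchronizes; they never validate the output, even though the reference headers gemm_complex.h and tensor_compare.h are included. A naive host-side spot check can be bolted onto any of the repro files. The sketch below is a hypothetical helper (verify_addmm_host is not part of the gist) and assumes packed row-major X (M x K) and W (K x N) with a length-N bias, so the indexing must be adapted to the strides a given repro actually uses:

// verify_addmm.h -- hypothetical helper, not part of the original gist.
// Naive host reference for Y = X @ W + Bias, spot-checking a few entries.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cmath>
#include <cstdio>
#include <vector>

inline bool verify_addmm_host(const half* dX, const half* dW, const half* dBias,
                              const half* dY, int M, int N, int K,
                              int samples = 16, float tol = 5e-2f) {
  std::vector<half> X(size_t(M) * K), W(size_t(K) * N), Bias(N), Y(size_t(M) * N);
  cudaMemcpy(X.data(), dX, X.size() * sizeof(half), cudaMemcpyDeviceToHost);
  cudaMemcpy(W.data(), dW, W.size() * sizeof(half), cudaMemcpyDeviceToHost);
  cudaMemcpy(Bias.data(), dBias, Bias.size() * sizeof(half), cudaMemcpyDeviceToHost);
  cudaMemcpy(Y.data(), dY, Y.size() * sizeof(half), cudaMemcpyDeviceToHost);
  for (int s = 0; s < samples; ++s) {
    // Deterministic pseudo-random sample positions
    int m = (s * 7919) % M, n = (s * 104729) % N;
    float acc = 0.0f;
    for (int k = 0; k < K; ++k)
      acc += __half2float(X[size_t(m) * K + k]) * __half2float(W[size_t(k) * N + n]);
    float ref = acc + __half2float(Bias[n]);
    float got = __half2float(Y[size_t(m) * N + n]);
    if (std::fabs(got - ref) > tol * (1.0f + std::fabs(ref))) {
      std::printf("Mismatch at (%d, %d): got %f, expected %f\n", m, n, got, ref);
      return false;
    }
  }
  return true;
}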
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using bfloat16 = nv_bfloat16;
using namespace cute;
#define CUTLASS_CHECK(status)                                                 \
  {                                                                           \
    cutlass::Status error = status;                                           \
    if (error != cutlass::Status::kSuccess) {                                 \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +      \
          cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
      throw std::runtime_error(msg);                                          \
    }                                                                         \
  }
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T val) const { return val; }
};
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
              cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
              "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = float;
using ElementD = cutlass::half_t;
using ElementC = cutlass::half_t;
using TileShapeMNK = cute::Shape<cute::_64, cute::_32, cute::_32>;
using ClusterShapeMNK = cute::Shape<cute::_1, cute::_1, cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    TileShapeMNK,
    EpilogueTileType,
    ElementC,
    ElementD,
    EpilogueScheduleType
>;
using ADDMM_EVT =  // alpha * acc + beta * C
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
                                               ElementD, ElementAcc, RoundStyle>,  // beta * C + (alpha * acc)
        cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>,  // beta
        cutlass::epilogue::fusion::Sm90SrcFetch,                     // C
        cutlass::epilogue::fusion::Sm90EVT<
            cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
                                                   ElementAcc, RoundStyle>,  // alpha * acc
            cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>,  // alpha
            cutlass::epilogue::fusion::Sm90AccFetch                      // acc
        >>;
using EVT_expr_1 = ADDMM_EVT;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor =
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,
        EVT_expr_1>;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue =
    typename cutlass::epilogue::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        TileShapeMNK,
        ClusterShapeMNK,
        EpilogueTileType,
        float, float,
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,
        EpilogueScheduleType,
        cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor
    >::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_mainloop =
    typename cutlass::gemm::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,
        float,
        cute::Shape<cute::_64, cute::_32, cute::_32>,
        cute::Shape<cute::_1, cute::_1, cute::_1>,
        cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue::SharedStorage)>,
        cutlass::gemm::KernelTmaWarpSpecializedPingpong
    >::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_mainloop,
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue,
    cutlass::gemm::PersistentScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma :
    public cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma>;
// When workspace_size is not a nullptr, populates the requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int cuda_cutlass_gemm_0(const half* Bias, const half* X, const half* W, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
  try {
    int64_t B = 1;
    int64_t M = 1024L;
    int64_t K = 5952L;
    int64_t N = 1024L;
    using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type::ElementAccumulator;
    using coord_t = cutlass::gemm::GemmCoord::Index;
    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      // @TODO kadeng: Add support for multi-GPU machines with heterogeneous SM counts;
      // for now we just pick the SM count of the first GPU.
      hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
    }
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type::Arguments arguments;
    // Initialize GemmUniversal3xInstance arguments.
    arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
      {
        static_cast<coord_t>(N),
        static_cast<coord_t>(M),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      },  // ProblemShape problem_shape
      {
        (cutlass::half_t*)(W),  // ElementA const* ptr_A
        {
          5952L /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        },  // StrideA dA
        (cutlass::half_t*)(X),  // ElementB const* ptr_B
        {
          5952L /* stride_w1 */,
          cute::Int<1>{} /* stride_w0 */,
          0 /* batch_stride_w */
        },  // StrideB dB
      },  // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {
          { // ADDMM Arguments: ternary op : beta * C + (alpha * acc)
            {{static_cast<ElementAcc>(1.000000)}},  // leaf op+args : beta
            {},                                     // leaf op+args : C
            { // binary op : alpha * acc
              {{static_cast<ElementAcc>(1.000000)}},  // leaf op+args : alpha
              {},                                     // leaf op+args : acc
              {}                                      // binary args : multiplies
            },  // end binary op
            {}  // ternary args : multiply_add
          }  // end ternary op
        },  // thread, typename FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
        (cutlass::half_t*)(Bias),  // ElementC const* ptr_C
        {
          cute::Int<1>{} /* stride_bias0 */,
          0L /* stride_bias1 */,
          0 /* batch_stride_bias */
        },  // StrideC dC
        (cutlass::half_t*)(Y),  // ElementD* ptr_D
        {
          cute::Int<1>{} /* stride_y0 */,
          1024L /* stride_y1 */,
          0 /* batch_stride_y */
        },  // StrideD dD
      },  // EpilogueArguments epilogue
      hw_info
    };
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }
    // Check for null pointers after the workspace-size query, since querying the
    // workspace size doesn't require valid data pointers.
    {
      if (!X) {
        int64_t X_size = 6094848L;
        if (X_size > 0) {
          throw std::runtime_error("input X is null but size is not 0!");
        }
      }
    }
    {
      if (!W) {
        int64_t W_size = 6094848L;
        if (W_size > 0) {
          throw std::runtime_error("input W is null but size is not 0!");
        }
      }
    }
    {
      if (!Bias) {
        int64_t Bias_size = 1048576L;
        if (Bias_size > 0) {
          throw std::runtime_error("input Bias is null but size is not 0!");
        }
      }
    }
    {
      if (!Y) {
        int64_t Y_size = 1048576L;
        if (Y_size > 0) {
          throw std::runtime_error("input Y is null but size is not 0!");
        }
      }
    }
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Log the maximum number of active blocks per SM for the kernel when
      // CUTLASS_DEBUG_TRACE_LEVEL == 1; the function prints internally, so no
      // explicit print statement is needed here.
      gemm_op.maximum_active_blocks();
    }
#endif
#endif
    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    return -1;
  }
  return 0;
}
}
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
    cutlass::DeviceAllocation<Element>& block,
    uint64_t seed, float max = 1.0, float min = -1.0) {
  if (block.size() <= 0) return false;
  Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
  cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, scope_max, scope_min, 0);
  return true;
}
extern "C" int run_standalone(uint64_t seed, int repetitions) {
  std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
  size_t workspace_size = 0;
  size_t* workspace_size_ptr = &workspace_size;
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;  // must not be void
  using ElementD = cutlass::half_t;
  cutlass::DeviceAllocation<ElementA> X_data(6094848);
  initialize_block(X_data, seed++);
  cutlass::DeviceAllocation<ElementB> W_data(6094848);
  initialize_block(W_data, seed++);
  cutlass::DeviceAllocation<ElementC> Bias_data(1024);
  initialize_block(Bias_data, seed++);
  cutlass::DeviceAllocation<ElementD> Y_data(1048576);
  cutlass::DeviceAllocation<uint8_t> workspace_data;
  // Call once with workspace_size_ptr set to get the workspace size.
  // Arguments follow the kernel signature order: (Bias, X, W, Y, ...).
  std::cout << "Calling once to get workspace size" << std::endl;
  cuda_cutlass_gemm_0((const half*)Bias_data.get(), (const half*)X_data.get(), (const half*)W_data.get(), (half*)Y_data.get(), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);
  // Allocate workspace if necessary
  if (workspace_size > 0) {
    workspace_data.reset(workspace_size);
    std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
  }
  std::cout << "Calling kernel cuda_cutlass_gemm_0" << std::endl;
  workspace_size_ptr = nullptr;
  for (int i = 0; i < repetitions; i++) {
    cuda_cutlass_gemm_0((const half*)Bias_data.get(), (const half*)X_data.get(), (const half*)W_data.get(), (half*)Y_data.get(), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);
  }
  cudaError_t result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Device synchronize failed with error "
              << cudaGetErrorString(result) << std::endl;
    return result;
  }
  return 0;
}
int main(int argc, char** argv) {
  // warmup
  run_standalone(1, 2);
  // repeat
  return run_standalone(2, 10);
}
#endif
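All of these repros hard-code their stride tuples (for example {5952L, cute::Int<1>{}, 0} above). The same values can be derived from the problem extents with the packed-stride helper from cutlass/util/packed_stride.hpp, which every repro already includes. A minimal sketch, assuming the packed, effectively non-batched case; example_strides is illustrative only:

#include "cutlass/util/packed_stride.hpp"
#include "cute/tensor.hpp"

// Stride type in the (major, cute::_1, batch) shape used by the repros above.
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;

void example_strides(int M, int K, int L) {
  // Fills the non-unit modes from the shape: for (M, K, L) this yields a packed
  // row-major A stride of {K, _1, M*K}, matching e.g. {5952L, _1{}, ...} above
  // when K = 5952 (the batch stride is unused here since L = 1).
  StrideA dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
  cute::print(dA);
}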
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using bfloat16 = nv_bfloat16;
using namespace cute;
#define CUTLASS_CHECK(status)                                                 \
  {                                                                           \
    cutlass::Status error = status;                                           \
    if (error != cutlass::Status::kSuccess) {                                 \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +      \
          cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
      throw std::runtime_error(msg);                                          \
    }                                                                         \
  }
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T val) const { return val; }
};
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
              cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
              "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = float;
using ElementD = cutlass::half_t;
using ElementC = cutlass::half_t;
using TileShapeMNK = cute::Shape<cute::_64, cute::_32, cute::_32>;
using ClusterShapeMNK = cute::Shape<cute::_2, cute::_1, cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    TileShapeMNK,
    EpilogueTileType,
    ElementC,
    ElementD,
    EpilogueScheduleType
>;
using ADDMM_EVT =  // alpha * acc + beta * C
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
                                               ElementD, ElementAcc, RoundStyle>,  // beta * C + (alpha * acc)
        cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>,  // beta
        cutlass::epilogue::fusion::Sm90SrcFetch,                     // C
        cutlass::epilogue::fusion::Sm90EVT<
            cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
                                                   ElementAcc, RoundStyle>,  // alpha * acc
            cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>,  // alpha
            cutlass::epilogue::fusion::Sm90AccFetch                      // acc
        >>;
using EVT_expr_1 = ADDMM_EVT;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor =
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,
        EVT_expr_1>;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue =
    typename cutlass::epilogue::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        TileShapeMNK,
        ClusterShapeMNK,
        EpilogueTileType,
        float, float,
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,
        EpilogueScheduleType,
        cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor
    >::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop =
    typename cutlass::gemm::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        float,
        cute::Shape<cute::_64, cute::_32, cute::_32>,
        cute::Shape<cute::_2, cute::_1, cute::_1>,
        cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue::SharedStorage)>,
        cutlass::gemm::KernelTmaWarpSpecializedPingpong
    >::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop,
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue,
    cutlass::gemm::PersistentScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma :
    public cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma>;
// When workspace_size is not a nullptr, populates the requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int cuda_cutlass_gemm_1(const half* Bias, const half* X, const half* W, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
  try {
    int64_t B = 1;
    int64_t M = 1024L;
    int64_t K = 256L;
    int64_t N = 109760L;
    using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type::ElementAccumulator;
    using coord_t = cutlass::gemm::GemmCoord::Index;
    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      // @TODO kadeng: Add support for multi-GPU machines with heterogeneous SM counts;
      // for now we just pick the SM count of the first GPU.
      hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
    }
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type::Arguments arguments;
    // Initialize GemmUniversal3xInstance arguments.
    arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
      {
        static_cast<coord_t>(N),
        static_cast<coord_t>(M),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      },  // ProblemShape problem_shape
      {
        (cutlass::half_t*)(W),  // ElementA const* ptr_A
        {
          256L /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        },  // StrideA dA
        (cutlass::half_t*)(X),  // ElementB const* ptr_B
        {
          cute::Int<1>{} /* stride_w1 */,
          1024L /* stride_w0 */,
          0 /* batch_stride_w */
        },  // StrideB dB
      },  // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {
          { // ADDMM Arguments: ternary op : beta * C + (alpha * acc)
            {{static_cast<ElementAcc>(1.000000)}},  // leaf op+args : beta
            {},                                     // leaf op+args : C
            { // binary op : alpha * acc
              {{static_cast<ElementAcc>(1.000000)}},  // leaf op+args : alpha
              {},                                     // leaf op+args : acc
              {}                                      // binary args : multiplies
            },  // end binary op
            {}  // ternary args : multiply_add
          }  // end ternary op
        },  // thread, typename FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
        (cutlass::half_t*)(Bias),  // ElementC const* ptr_C
        {
          cute::Int<1>{} /* stride_bias0 */,
          0L /* stride_bias1 */,
          0 /* batch_stride_bias */
        },  // StrideC dC
        (cutlass::half_t*)(Y),  // ElementD* ptr_D
        {
          cute::Int<1>{} /* stride_y0 */,
          109760L /* stride_y1 */,
          0 /* batch_stride_y */
        },  // StrideD dD
      },  // EpilogueArguments epilogue
      hw_info
    };
    cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }
    // Check for null pointers after the workspace-size query, since querying the
    // workspace size doesn't require valid data pointers.
    {
      if (!X) {
        int64_t X_size = 262144L;
        if (X_size > 0) {
          throw std::runtime_error("input X is null but size is not 0!");
        }
      }
    }
    {
      if (!W) {
        int64_t W_size = 28098560L;
        if (W_size > 0) {
          throw std::runtime_error("input W is null but size is not 0!");
        }
      }
    }
    {
      if (!Bias) {
        int64_t Bias_size = 112394240L;
        if (Bias_size > 0) {
          throw std::runtime_error("input Bias is null but size is not 0!");
        }
      }
    }
    {
      if (!Y) {
        int64_t Y_size = 112394240L;
        if (Y_size > 0) {
          throw std::runtime_error("input Y is null but size is not 0!");
        }
      }
    }
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Log the maximum number of active blocks per SM for the kernel when
      // CUTLASS_DEBUG_TRACE_LEVEL == 1; the function prints internally, so no
      // explicit print statement is needed here.
      gemm_op.maximum_active_blocks();
    }
#endif
#endif
    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    return -1;
  }
  return 0;
}
}
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
    cutlass::DeviceAllocation<Element>& block,
    uint64_t seed, float max = 1.0, float min = -1.0) {
  if (block.size() <= 0) return false;
  Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
  cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, scope_max, scope_min, 0);
  return true;
}
extern "C" int run_standalone(uint64_t seed, int repetitions) {
  std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
  size_t workspace_size = 0;
  size_t* workspace_size_ptr = &workspace_size;
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;  // must not be void
  using ElementD = cutlass::half_t;
  cutlass::DeviceAllocation<ElementA> X_data(262144);
  initialize_block(X_data, seed++);
  cutlass::DeviceAllocation<ElementB> W_data(28098560);
  initialize_block(W_data, seed++);
  cutlass::DeviceAllocation<ElementC> Bias_data(109760);
  initialize_block(Bias_data, seed++);
  cutlass::DeviceAllocation<ElementD> Y_data(112394240);
  cutlass::DeviceAllocation<uint8_t> workspace_data;
  // Call once with workspace_size_ptr set to get the workspace size.
  // Arguments follow the kernel signature order: (Bias, X, W, Y, ...).
  std::cout << "Calling once to get workspace size" << std::endl;
  cuda_cutlass_gemm_1((const half*)Bias_data.get(), (const half*)X_data.get(), (const half*)W_data.get(), (half*)Y_data.get(), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);
  // Allocate workspace if necessary
  if (workspace_size > 0) {
    workspace_data.reset(workspace_size);
    std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
  }
  std::cout << "Calling kernel cuda_cutlass_gemm_1" << std::endl;
  workspace_size_ptr = nullptr;
  for (int i = 0; i < repetitions; i++) {
    cuda_cutlass_gemm_1((const half*)Bias_data.get(), (const half*)X_data.get(), (const half*)W_data.get(), (half*)Y_data.get(), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);
  }
  cudaError_t result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Device synchronize failed with error "
              << cudaGetErrorString(result) << std::endl;
    return result;
  }
  return 0;
}
int main(int argc, char** argv) {
  // warmup
  run_standalone(1, 2);
  // repeat
  return run_standalone(2, 10);
}
#endif
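Each repro exports the same two-phase entry point: a first call with a non-null workspace_size pointer only reports the workspace requirement (data pointers are not dereferenced in that mode), and a second call with workspace_size == nullptr initializes and runs the GEMM. A minimal caller sketch; call_with_workspace is a hypothetical helper and error handling is mostly elided:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// Declaration as exported by the repro files above.
extern "C" int cuda_cutlass_gemm_1(const half* Bias, const half* X, const half* W, half* Y,
                                   size_t* workspace_size, uint8_t* workspace, cudaStream_t stream);

int call_with_workspace(const half* Bias, const half* X, const half* W, half* Y,
                        cudaStream_t stream) {
  // Phase 1: pass a non-null workspace_size pointer; the kernel only reports
  // its workspace requirement and returns.
  size_t workspace_bytes = 0;
  if (cuda_cutlass_gemm_1(Bias, X, W, Y, &workspace_bytes, nullptr, stream) != 0) return -1;
  uint8_t* workspace = nullptr;
  if (workspace_bytes > 0) cudaMalloc(&workspace, workspace_bytes);
  // Phase 2: workspace_size == nullptr -> actually initialize and run the GEMM.
  int rc = cuda_cutlass_gemm_1(Bias, X, W, Y, nullptr, workspace, stream);
  if (workspace) cudaFree(workspace);
  return rc;
}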
#!/bin/bash
# Change the environment variables to point to CUTLASS and the CUDA Toolkit, then run this,
# passing any of the standalone repro_N.cu files as argument. It will compile and run the
# example.
# This creates a debug build and runs it through compute-sanitizer.
set -x
export REPRO_CUTLASS_PATH=/home/klondenberg/github/pytorch/pytorch/third_party/cutlass
export REPRO_CUDA_PATH=/home/klondenberg/local/cuda121
$REPRO_CUDA_PATH/bin/nvcc -g -G -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -w -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O1 -std=c++17 --expt-relaxed-constexpr -Xcompiler=-fPIC --use_fast_math -Xcompiler=-fno-strict-aliasing -Xcompiler -fvisibility=hidden -Xcompiler=-Wconversion -I${REPRO_CUTLASS_PATH}/include -I${REPRO_CUTLASS_PATH}/tools/library/include -I${REPRO_CUTLASS_PATH}/tools/library/src -I${REPRO_CUTLASS_PATH}/tools/util/include -L${REPRO_CUDA_PATH}/lib64 -L${REPRO_CUDA_PATH}/lib64/stubs -lcuda -lcudart -DGENERATE_STANDALONE_RUNNER -DNDEBUG -DCUTLASS_DEBUG_TRACE_LEVEL=1 -o "${@}.debug.exe" "$@"
"${REPRO_CUDA_PATH}/bin/compute-sanitizer" "./${@}.debug.exe"