Created
December 7, 2023 12:20
-
-
Save kadeng/6df8a529dcc2d50c96cbb50fe97c96c0 to your computer and use it in GitHub Desktop.
Cutlass performance regression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Environment:
* Linux x64, NVIDIA H100 GPU
* CUDA 12.1
* CUTLASS v3.3.0 (tagged release) and CUTLASS v3.2.2 (tagged release)
Command (example):
nvcc -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -w -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O1 -std=c++17 --expt-relaxed-constexpr -Xcompiler=-fPIC --use-fast-math -Xcompiler=-fno-strict-aliasing -Xcompiler -fvisibility=hidden -Xcompiler=-Wconversion -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/library/include -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/library/src -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/util/include -L/home/klondenberg/local/cuda121/lib64 -L/home/klondenberg/local/cuda121/lib64/stubs -lcuda -lcudart -DGENERATE_STANDALONE_RUNNER -o performance_repro performance_repro.cu
Where:
* /home/klondenberg/github/pytorch/pytorch/third_party/cutlass is the CUTLASS v3.3.0 checkout directory
* /home/klondenberg/local/cuda121 is the CUDA 12.1 Toolkit path
* nvcc is from the CUDA 12.1 toolkit
To obtain the performance measurements, run "nsys profile performance_repro".
Results (from nsys):
* CUTLASS 3.3.0: ~80 ms
* CUTLASS 3.2.2: ~12 ms
When compiling with CUTLASS 3.3.0, I also get the following warning, which I don't get with CUTLASS 3.2.2:
ptxas info : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelI119cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_void_f16_64x32x64_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tmaEEvNT_6ParamsE'
/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/device_kernel.h: In function 'void cutlass::device_kernel(typename Operator::Params) [with Operator = cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_void_f16_64x32x64_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma]':
/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/device_kernel.h:104:1: note: the ABI for passing parameters with 64-byte alignment has changed in GCC 4.6
(Note: the kernel named in this output is a 64x32x64 ping-pong configuration, not the 256x128x64 cooperative kernel defined in the source below, so this output was presumably captured from a different build of the generated code.)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be | |
// exposed in the final shared library must be declared with PT_EXPORT to make | |
// them visible. | |
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) | |
#define PT_EXPORT __attribute__((__visibility__("default"))) | |
#else | |
#ifdef _WIN32 | |
#define PT_EXPORT __declspec(dllexport) | |
#else | |
#define PT_EXPORT | |
#endif | |
#endif | |
using bfloat16 = nv_bfloat16; | |
using namespace cute; | |
// CUTLASS_CHECK(status): evaluates `status` exactly once and, when it is not
// cutlass::Status::kSuccess, throws std::runtime_error whose message names
// this file, the CUTLASS status string and the line of the call site.
//
// The body is wrapped in `do { ... } while (0)` so `CUTLASS_CHECK(s);` is a
// single statement: it works inside an unbraced if/else. The temporary has a
// deliberately unusual name so a caller's own variable (e.g. one named
// `error`) is never shadowed and self-initialized by the expansion.
#define CUTLASS_CHECK(status)                                               \
  do {                                                                      \
    cutlass::Status cutlass_check_status_ = (status);                       \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) {               \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +    \
                 cutlassGetStatusString(cutlass_check_status_) + " at: " +  \
                 std::to_string(__LINE__);                                  \
      throw std::runtime_error(msg);                                        \
    }                                                                       \
  } while (0)
// Used as pass-through functor in EVT just for type casting / rounding | |
template <typename T> | |
struct identity_op { | |
CUTLASS_HOST_DEVICE | |
T operator()(T val) const { return val; } | |
}; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_epilogue = | |
typename cutlass::epilogue::collective::CollectiveBuilder< | |
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, | |
cute::Shape<cute::_256, cute::_128, cute::_64>, | |
cute::Shape<cute::_1,cute::_2,cute::_1>, | |
cutlass::epilogue::collective::EpilogueTileAuto, | |
float, float, | |
void, cutlass::layout::ColumnMajor, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
cutlass::epilogue::TmaWarpSpecializedCooperative | |
>::CollectiveOp; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_mainloop = | |
typename cutlass::gemm::collective::CollectiveBuilder< | |
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, 8, | |
float, | |
cute::Shape<cute::_256, cute::_128, cute::_64>, | |
cute::Shape<cute::_1,cute::_2,cute::_1>, | |
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage)>, | |
cutlass::gemm::KernelTmaWarpSpecializedCooperative | |
>::CollectiveOp; | |
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal< | |
cute::Shape<int,int,int,int>, | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_mainloop, | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_epilogue, | |
cutlass::gemm::PersistentScheduler>; | |
// Define named type | |
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma : | |
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_base { }; | |
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma>; | |
// When workspace_size is not a nullptr, populates requested workspace_size and returns. | |
// Otherwise, computes the Gemm kernel using the given workspace ptr. | |
extern "C" { | |
PT_EXPORT int cuda_fused_bmm_0(const half* X, const half* W, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) { | |
try { | |
int64_t B = 10L; | |
int64_t M = 10240L; | |
int64_t K = 2048L; | |
int64_t N = 10240L; | |
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator; | |
using coord_t = cutlass::gemm::GemmCoord::Index; | |
static cutlass::KernelHardwareInfo hw_info; | |
if (hw_info.sm_count == 0) { | |
// @TODO kadeng: Add support for Multi-GPU machines with heterogeneous SM counts | |
// for now we just pick the SM count of the first GPU | |
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); | |
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); | |
} | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments; | |
// Initialize GemmUniversal3xInstance arguments. | |
arguments = { | |
cutlass::gemm::GemmUniversalMode::kBatched, // GemmUniversalMode mode | |
{ | |
static_cast<coord_t>(M), | |
static_cast<coord_t>(N), | |
static_cast<coord_t>(K), | |
static_cast<coord_t>(B) | |
}, // ProblemShape problem_shape | |
{ | |
(cutlass::half_t*)(X), // ElementA const* ptr_A | |
{ | |
2048L /* stride_x0 */, | |
cute::Int<1>{} /* stride_x1 */, | |
20971520L /* batch_stride_x */ | |
}, // StrideA dA | |
(cutlass::half_t*)(W), // ElementB const* ptr_B | |
{ | |
cute::Int<1>{} /* stride_w1 */, | |
10240L /* stride_w0 */, | |
20971520L /* batch_stride_w */ | |
}, // StrideB dB | |
}, // MainloopArguments mainloop | |
// see https://tinyurl.com/4rk89z48 | |
{ | |
{}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) | |
nullptr, // ElementC const* ptr_C | |
{ | |
cute::Int<1>{} /* stride_bias0 */, | |
cute::Int<1>{} /* stride_bias1 */, | |
0 /* batch_stride_bias */ | |
}, // StrideC dC | |
(cutlass::half_t*)(Y), // ElementD const* ptr_D | |
{ | |
10240L /* stride_y0 */, | |
cute::Int<1>{} /* stride_y1 */, | |
104857600L /* batch_stride_y */ | |
}, // StrideD dD | |
}, // EpilogueArguments epilogue, | |
hw_info | |
}; | |
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma_device_type gemm_op; | |
if (workspace_size) { | |
*workspace_size = gemm_op.get_workspace_size(arguments); | |
return 0; | |
} | |
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers | |
{ | |
if (!X) { | |
int64_t X_size = 209715200L; | |
if (X_size > 0) { | |
throw std::runtime_error("input X is null but size is not 0!"); | |
} | |
} | |
} | |
{ | |
if (!W) { | |
int64_t W_size = 209715200L; | |
if (W_size > 0) { | |
throw std::runtime_error("input W is null but size is not 0!"); | |
} | |
} | |
} | |
{ | |
if (!Y) { | |
int64_t Y_size = 1048576000L; | |
if (Y_size > 0) { | |
throw std::runtime_error("input Y is null but size is not 0!"); | |
} | |
} | |
} | |
{ | |
auto status = gemm_op.can_implement(arguments); | |
CUTLASS_CHECK(status); | |
} | |
#ifdef CUTLASS_DEBUG_TRACE_LEVEL | |
#if CUTLASS_DEBUG_TRACE_LEVEL == 1 | |
{ | |
// Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 | |
// we don't need a print statement, it's happening inside the function. | |
gemm_op.maximum_active_blocks(); | |
} | |
#endif | |
#endif | |
{ | |
auto status = gemm_op.initialize(arguments, workspace, stream); | |
CUTLASS_CHECK(status); | |
} | |
{ | |
auto status = gemm_op(stream); | |
CUTLASS_CHECK(status); | |
} | |
} | |
catch (std::exception& e) { | |
std::cerr << "Runtime error: " << e.what() << std::endl; | |
return -1; | |
} | |
catch (...) { | |
return -1; | |
} | |
return 0; | |
} | |
} | |
#ifdef GENERATE_STANDALONE_RUNNER | |
/// Helper to initialize a block of device data | |
template <class Element> | |
bool initialize_block( | |
cutlass::DeviceAllocation<Element>& block, | |
uint64_t seed, float max=1.0, float min=-1.0) { | |
if (block.size()<=0) return false; | |
Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min)); | |
cutlass::reference::device::BlockFillRandomUniform( | |
block.get(), block.size(), seed, scope_max, scope_min, 0); | |
return true; | |
} | |
extern "C" int run_standalone(uint64_t seed, int repetitions) { | |
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl; | |
size_t workspace_size = 0; | |
size_t* workspace_size_ptr = &workspace_size; | |
using ElementA = cutlass::half_t; | |
using ElementB = cutlass::half_t; | |
using ElementC = uint8_t; // may not be void | |
using ElementD = cutlass::half_t; | |
cutlass::DeviceAllocation<ElementA> X_data(209715200); | |
initialize_block(X_data, seed++); | |
cutlass::DeviceAllocation<ElementB> W_data(209715200); | |
initialize_block(W_data, seed++); | |
cutlass::DeviceAllocation<ElementC> Bias_data(0); | |
initialize_block(Bias_data, seed++); | |
cutlass::DeviceAllocation<ElementD> Y_data(1048576000); | |
cutlass::DeviceAllocation<uint8_t> workspace_data; | |
// Call once with workspace_size_ptr set to get workspace size | |
std::cout << "Calling once to get workspace size" << std::endl; | |
cuda_fused_bmm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);; | |
// Allocate workspace if neccessary | |
if (workspace_size > 0) { | |
workspace_data.reset(workspace_size); | |
std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl; | |
} | |
std::cout << "Calling Kernel as cuda_fused_bmm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;" << std::endl; | |
workspace_size_ptr = nullptr; | |
for (int i=0; i<repetitions; i++) { | |
cuda_fused_bmm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);; | |
} | |
cudaError_t result = cudaDeviceSynchronize(); | |
if (result != cudaSuccess) { | |
std::cerr << "Device synchronize failed with error " | |
<< cudaGetErrorString(result) << std::endl; | |
return result; | |
} | |
return 0; | |
} | |
/// Runs a short warm-up pass, then the measured pass.
///
/// The warm-up (seed 1, 2 launches) absorbs one-time setup, such as the cached
/// SM-count query inside cuda_fused_bmm_0, before the measured run (seed 2,
/// 10 launches). A failing warm-up aborts with its status instead of being
/// silently ignored.
int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
  const int warmup_status = run_standalone(1, 2);
  if (warmup_status != 0) {
    return warmup_status;
  }
  return run_standalone(2, 10);
}
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment