Cutlass bug report
Source file (broken5.cu):
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <cuda_fp16.h> // for half (normally also reached transitively via the CUTLASS headers below)
#include <cuda_bf16.h> // for nv_bfloat16 (likewise normally pulled in transitively)
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
// Only headers not already included above are needed for the standalone runner.
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using bfloat16 = nv_bfloat16;
using namespace cute;
#define CUTLASS_CHECK(status)                                                 \
  {                                                                           \
    cutlass::Status error = status;                                           \
    if (error != cutlass::Status::kSuccess) {                                 \
      auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +      \
                 cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
      throw std::runtime_error(msg);                                          \
    }                                                                         \
  }
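// CUTLASS_CHECK wraps any call that returns a cutlass::Status and throws on failure,
// e.g. CUTLASS_CHECK(gemm_op.can_implement(arguments)); -- see its uses further below.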
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T val) const { return val; }
};
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
              cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
              "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = cutlass::half_t;
using ElementD = cutlass::half_t;
using ElementC = void;
using TileShapeMNK = cute::Shape<cute::_64, cute::_64, cute::_32>;
using ClusterShapeMNK = cute::Shape<cute::_1, cute::_1, cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    TileShapeMNK,
    EpilogueTileType,
    ElementC,
    ElementD,
    EpilogueScheduleType
>;
using ADDMM_EVT = // alpha * acc + beta * C
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementAcc, RoundStyle>, // beta * C + (alpha * acc)
        cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // beta
        cutlass::epilogue::fusion::Sm90SrcFetch, // C
        cutlass::epilogue::fusion::Sm90EVT<
            cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc, ElementAcc, RoundStyle>, // alpha * acc
            cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // alpha
            cutlass::epilogue::fusion::Sm90AccFetch // acc
        >>;
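// Note: ADDMM_EVT above is defined but not used by this kernel; its epilogue is
// assembled from EVT_expr_1..EVT_expr_4 below instead.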
using EVT_expr_1 = cutlass::epilogue::fusion::Sm90AccFetch /* :=buf0 (matmul output in accumulator) */;
using arg0_1AuxLoadDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor<
    EpilogueDescriptor, cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>, cutlass::half_t>;
using EVT_expr_2 = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementAcc, typename arg0_1AuxLoadDesc::Element, RoundStyle>,
    cutlass::epilogue::fusion::Sm90RowBroadcast<arg0_1AuxLoadDesc::Stages, TileShapeMNK, typename arg0_1AuxLoadDesc::Element, typename arg0_1AuxLoadDesc::Stride>> /* :=arg0_1 as aux operand, cast to accumulator dtype */;
using EVT_expr_3 = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<cutlass::plus, ElementAcc, ElementAcc, RoundStyle>, EVT_expr_1, EVT_expr_2>;
using EVT_expr_4 = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>, EVT_expr_3,
    cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value=0.0, dtype=torch.float32 */>;
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor =
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>, EVT_expr_4>;
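// Taken together, the epilogue visitor tree computes
//   D = half_cast( max( acc + row_broadcast(arg0_1), 0 ) )
// i.e. a bias add of the row vector arg0_1 followed by a ReLU.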
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShapeMNK,
    ClusterShapeMNK,
    EpilogueTileType,
    cutlass::half_t, cutlass::half_t,
    void, cutlass::layout::ColumnMajor, 8,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    EpilogueScheduleType,
    cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor
  >::CollectiveOp;
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    cutlass::half_t,
    cute::Shape<cute::_64, cute::_64, cute::_32>,
    cute::Shape<cute::_1, cute::_1, cute::_1>,
    cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue::SharedStorage)>,
    cutlass::gemm::KernelTmaWarpSpecializedPingpong
  >::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,
    cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop,
    cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue,
    cutlass::gemm::PersistentScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma :
  public cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base { };
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma>;
// When workspace_size is not a nullptr, populates the requested workspace size and returns.
// Otherwise, runs the GEMM kernel using the given workspace pointer.
extern "C" { | |
PT_EXPORT int cuda_cutlass_gemm_0(const half* X, const half* W, const half* aux_arg0_1, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) { | |
try { | |
{ | |
if (!X) { | |
int64_t X_size = 65536L; | |
if (X_size > 0) { | |
throw std::runtime_error("input X is null but size is not 0!"); | |
} | |
} | |
} | |
{ | |
if (!W) { | |
int64_t W_size = 65536L; | |
if (W_size > 0) { | |
throw std::runtime_error("input W is null but size is not 0!"); | |
} | |
} | |
} | |
{ | |
if (!Y) { | |
int64_t Y_size = 65536L; | |
if (Y_size > 0) { | |
throw std::runtime_error("input Y is null but size is not 0!"); | |
} | |
} | |
} | |
{ | |
if (!aux_arg0_1) { | |
int64_t aux_arg0_1_size = 256L; | |
if (aux_arg0_1_size > 0) { | |
throw std::runtime_error("input aux_arg0_1 is null but size is not 0!"); | |
} | |
} | |
} | |
int64_t B = 1; | |
int64_t M = 256L; | |
int64_t K = 256L; | |
int64_t N = 256L; | |
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type::ElementAccumulator; | |
using coord_t = cutlass::gemm::GemmCoord::Index; | |
static cutlass::KernelHardwareInfo hw_info; | |
if (hw_info.sm_count == 0) { | |
// @TODO kadeng: Add support for Multi-GPU machines with heterogeneous SM counts | |
// for now we just pick the SM count of the first GPU | |
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); | |
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); | |
} | |
cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type::Arguments arguments; | |
// Initialize GemmUniversal3xInstance arguments. | |
    arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
      {
        static_cast<coord_t>(M),
        static_cast<coord_t>(N),
        static_cast<coord_t>(K),
        static_cast<coord_t>(B)
      },  // ProblemShape problem_shape
      {
        (cutlass::half_t*)(X),  // ElementA const* ptr_A
        {
          256L /* stride_x0 */,
          cute::Int<1>{} /* stride_x1 */,
          0 /* batch_stride_x */
        },  // StrideA dA
        (cutlass::half_t*)(W),  // ElementB const* ptr_B
        {
          cute::Int<1>{} /* stride_w1 */,
          256L /* stride_w0 */,
          0 /* batch_stride_w */
        },  // StrideB dB
      },  // MainloopArguments mainloop
      // see https://tinyurl.com/4rk89z48
      {
        {{{ /*plus: */ {}, { ((cutlass::half_t*)(aux_arg0_1)), cutlass::half_t(0), { cute::Int<0L>{}, cute::Int<1L>{}, cute::Int<0L>{} } } /* arg0_1 data pointer incl. offset, zero element value and strides for MNL (L=batch) dims */ }, { static_cast<ElementAcc>(0.0) }}},  // thread: typename FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
        nullptr,  // ElementC const* ptr_C
        {
          cute::Int<1>{} /* stride_bias0 */,
          cute::Int<1>{} /* stride_bias1 */,
          0 /* batch_stride_bias */
        },  // StrideC dC
        (cutlass::half_t*)(Y),  // ElementD* ptr_D
        {
          256L /* stride_y0 */,
          cute::Int<1>{} /* stride_y1 */,
          0 /* batch_stride_y */
        },  // StrideD dD
      },  // EpilogueArguments epilogue
    };
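    // `arguments` now describes a single (batch B=1) 256x256x256 half-precision GEMM
    // whose epilogue adds the row-broadcast aux operand arg0_1 and applies ReLU
    // before writing the result to Y.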
    cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }
    {
      auto status = gemm_op.can_implement(arguments);
      CUTLASS_CHECK(status);
    }
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // If CUTLASS_DEBUG_TRACE_LEVEL == 1, report the maximum number of active blocks
      // per SM for the kernel; the function prints the result itself, so no extra
      // print statement is needed here.
      gemm_op.maximum_active_blocks();
    }
#endif
#endif
    {
      auto status = gemm_op.initialize(arguments, workspace, stream);
      CUTLASS_CHECK(status);
    }
    {
      auto status = gemm_op(stream);
      CUTLASS_CHECK(status);
    }
  }
  catch (std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    return -1;
  }
  return 0;
}
}
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
    cutlass::DeviceAllocation<Element>& block,
    uint64_t seed, float max = 1.0, float min = -1.0) {
  if (block.size() <= 0) return false;
  Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
  cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, scope_max, scope_min, 0);
  return true;
}
extern "C" int run_standalone(uint64_t seed) { | |
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl; | |
size_t workspace_size = 0; | |
size_t* workspace_size_ptr = &workspace_size; | |
using ElementA = cutlass::half_t; | |
using ElementB = cutlass::half_t; | |
using ElementC = uint8_t; // may not be void | |
using ElementD = cutlass::half_t; | |
using Element_arg0_1 = cutlass::half_t; | |
cutlass::DeviceAllocation<ElementA> X_data(65536); | |
initialize_block(X_data, seed++); | |
cutlass::DeviceAllocation<ElementB> W_data(65536); | |
initialize_block(W_data, seed++); | |
cutlass::DeviceAllocation<ElementC> Bias_data(0); | |
initialize_block(Bias_data, seed++); | |
cutlass::DeviceAllocation<ElementD> Y_data(65536); | |
cutlass::DeviceAllocation<Element_arg0_1> aux_arg0_1_data(256); | |
initialize_block(aux_arg0_1_data, seed++); | |
cutlass::DeviceAllocation<uint8_t> workspace_data; | |
// Call once with workspace_size_ptr set to get workspace size | |
std::cout << "Calling once to get workspace size" << std::endl; | |
cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)aux_arg0_1_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);; | |
// Allocate workspace if neccessary | |
if (workspace_size > 0) { | |
workspace_data.reset(workspace_size); | |
std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl; | |
} | |
std::cout << "Calling Kernel as cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)aux_arg0_1_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;" << std::endl; | |
workspace_size_ptr = nullptr; | |
cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)aux_arg0_1_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);; | |
cudaError_t result = cudaDeviceSynchronize(); | |
if (result != cudaSuccess) { | |
std::cerr << "Device synchronize failed with error " | |
<< cudaGetErrorString(result) << std::endl; | |
return result; | |
} | |
return 0; | |
} | |
int main(int argc, char** argv) { | |
return run_standalone(1); | |
} | |
#endif |
compute-sanitizer error trace:
Out-of-range shared or local address
========= at 0xbd0 in /home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/arch/barrier.h:169:cutlass::arch::ClusterBarrier::init(const unsigned long *, unsigned int)
========= by thread (0,0,0) in block (0,1,0)
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/arch/barrier.h:127:cutlass::arch::ClusterBarrier::init(unsigned int) const [0xb20]
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp:1073:cutlass::OrderedSequenceBarrier<(int)1, (int)2>::OrderedSequenceBarrier(cutlass::OrderedSequenceBarrier<(int)1, (int)2>::SharedStorage &, const cutlass::OrderedSequenceBarrier<(int)1, (int)2>::Params &) [0xb20]
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp:382:cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int, int>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<(int)27, cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cutlass::gemm::KernelTmaWarpSpecializedPingpong>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<cute::C<(int)1>, long, long>, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x64x16_F16F16F16_SS<(cute::GMMA::Major)0, (cute::GMMA::Major)1, (cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<(int)2, (int)2, (int)16, (bool)0>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cute::tuple<cute::C<(int)64>, cute::C<(int)32>>, void, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::plus, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90AccFetch, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90RowBroadcast<(int)2, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)1>, cute::C<(int)0>>, (int)8, (bool)1>>>, cutlass::epilogue::fusion::Sm90ScalarBroadcast<cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>, (int)1, cutlass::multiplies>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM75_U32x4_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, 
cute::SM90_U32x4_STSM_N>, cutlass::gemm::PersistentScheduler, void>::operator ()(const cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int, int>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<(int)27, cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cutlass::gemm::KernelTmaWarpSpecializedPingpong>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<cute::C<(int)1>, long, long>, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x64x16_F16F16F16_SS<(cute::GMMA::Major)0, (cute::GMMA::Major)1, (cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<(int)2, (int)2, (int)16, (bool)0>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cute::tuple<cute::C<(int)64>, cute::C<(int)32>>, void, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::plus, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90AccFetch, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90RowBroadcast<(int)2, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)1>, cute::C<(int)0>>, (int)8, (bool)1>>>, cutlass::epilogue::fusion::Sm90ScalarBroadcast<cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>, (int)1, cutlass::multiplies>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM75_U32x4_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM90_U32x4_STSM_N>, cutlass::gemm::PersistentScheduler, 
void>::Params &, char *) [0xad0]
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/device_kernel.h:109:void cutlass::device_kernel<cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma>(T1::Params) [0x20]
Environment:
* Linux x64, NVIDIA H100 GPU
* CUDA 12.1
* Cutlass v3.3.0 (tagged release)
Command (example):
nvcc -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -w -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O1 -std=c++17 --expt-relaxed-constexpr -lineinfo -g -DCUTLASS_DEBUG_TRACE_LEVEL=1 -Xcompiler=-fPIC -Xcompiler=-fno-strict-aliasing -Xcompiler -fvisibility=hidden -Xcompiler=-Wconversion -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/library/include -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/library/src -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/util/include -L/home/klondenberg/local/cuda121/lib64 -L/home/klondenberg/local/cuda121/lib64/stubs -lcuda -lcudart -DGENERATE_STANDALONE_RUNNER -o broken5 broken5.cu
where
* /home/klondenberg/github/pytorch/pytorch/third_party/cutlass is the Cutlass v3.3.0 checkout directory
* /home/klondenberg/local/cuda121 is the CUDA 12.1 Toolkit path
* nvcc is from the CUDA 12.1 toolkit
To obtain the error trace above, run the compiled executable under compute-sanitizer, for example:
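compute-sanitizer ./broken5
Here ./broken5 is the output name passed to nvcc via -o above; compute-sanitizer runs its default memcheck tool, which produces the "Out-of-range shared or local address" report and device-side stack trace shown above.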