Skip to content

Instantly share code, notes, and snippets.

@kadeng
Last active December 6, 2023 16:43
Show Gist options
  • Save kadeng/31df46a19d093bdfb36977892f578e1c to your computer and use it in GitHub Desktop.
Save kadeng/31df46a19d093bdfb36977892f578e1c to your computer and use it in GitHub Desktop.
CUTLASS bug report: out-of-range shared memory address in an SM90 ping-pong GEMM with an epilogue visitor tree
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// Every model is built with -fvisibility=hidden, so a symbol that has to be
// visible in the resulting shared library must be annotated with PT_EXPORT.
#if defined(__GNUC__) // GCC, clang, and any other compiler with GNU extensions
#define PT_EXPORT __attribute__((__visibility__("default")))
#elif defined(_WIN32)
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
// Short alias for CUDA's bfloat16 type. It is not referenced anywhere else in this file.
using bfloat16 = nv_bfloat16;
// NOTE(review): file-scope using-directive. Every cute name below is already
// written fully qualified (cute::Shape, cute::Int, ...), so this is likely
// removable. Confirm before dropping it.
using namespace cute;
// Evaluates a cutlass::Status expression exactly once. If the result is not
// kSuccess, it throws std::runtime_error, tagged with the file and line of the
// call site.
// The do { } while (0) wrapper makes `CUTLASS_CHECK(x);` a single statement,
// so it is safe inside an unbraced if/else. The uniquely named local avoids
// self-initialization when a caller passes a variable named `error`.
#define CUTLASS_CHECK(status)                                                  \
  do {                                                                         \
    cutlass::Status cutlass_check_status_ = (status);                          \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) {                  \
      auto cutlass_check_msg_ = std::string("[") + __FILE__ +                  \
                                "] Got cutlass error: " +                      \
                                cutlassGetStatusString(cutlass_check_status_) + \
                                " at: " + std::to_string(__LINE__);            \
      throw std::runtime_error(cutlass_check_msg_);                            \
    }                                                                          \
  } while (0)
// Pass-through functor that returns its argument unchanged. It appears in the
// epilogue visitor tree (EVT) only for the type casting / rounding that
// happens around it.
template <class T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  auto operator()(T value) const -> T {
    return value;
  }
};
// Epilogue schedule shared by the epilogue descriptor and the epilogue collective below.
// The static_assert limits it to the two TMA warp-specialized schedules,
// because the epilogue visitor trees used in this file require one of them.
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
"Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
// Rounding mode passed to every Sm90Compute node of the visitor trees below.
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
// Accumulator/compute type and output type D are both fp16.
// There is no source operand C (void); ptr_C is passed as nullptr at launch.
using ElementAcc = cutlass::half_t;
using ElementD = cutlass::half_t;
using ElementC = void;
// CTA tile of 64x64x32 (M x N x K) and a single-CTA cluster.
using TileShapeMNK = cute::Shape<cute::_64, cute::_64, cute::_32>;
using ClusterShapeMNK = cute::Shape<cute::_1,cute::_1,cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
// Epilogue descriptor used by arg0_1AuxLoadDesc below. The row broadcast
// takes its Stages, Element and Stride members from there.
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShapeMNK,
EpilogueTileType,
ElementC,
ElementD,
EpilogueScheduleType
>;
// Visitor tree for a classic addmm epilogue: D = beta * C + alpha * acc.
// NOTE(review): not referenced anywhere else in this file; the kernel uses the
// tree built from EVT_expr_* below. It also reads C through Sm90SrcFetch while
// ElementC is void here. Presumably this is leftover generator output; confirm
// before reusing it.
using ADDMM_EVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
ElementD, ElementAcc, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
ElementAcc, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
>>;
// Epilogue visitor tree (EVT) used by the kernel, built bottom-up:
//   EVT_expr_1 = acc                          (matmul accumulator)
//   EVT_expr_2 = ElementAcc(arg0_1)           (aux operand with stride (0, 1, 0)
//                                              over (M, N, L): one value per column n)
//   EVT_expr_3 = EVT_expr_1 + EVT_expr_2
//   EVT_expr_4 = maximum(EVT_expr_3, scalar)  (scalar broadcast, generated with
//                                              value 0.0, i.e. a ReLU)
//   functor    = ElementD(EVT_expr_4)         (final cast/rounding to D)
// Every Sm90Compute node uses RoundStyle.
using EVT_expr_1 = cutlass::epilogue::fusion::Sm90AccFetch /* :=buf0 (matmul output in accumulator) */;
using arg0_1AuxLoadDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>, cutlass::half_t>;
using EVT_expr_2 = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementAcc, typename arg0_1AuxLoadDesc::Element, RoundStyle>,
    cutlass::epilogue::fusion::Sm90RowBroadcast<arg0_1AuxLoadDesc::Stages, TileShapeMNK, typename arg0_1AuxLoadDesc::Element, typename arg0_1AuxLoadDesc::Stride>> /* :=arg0_1 as aux operand, cast to accumulator dtype */;
using EVT_expr_3 = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<cutlass::plus, ElementAcc, ElementAcc, RoundStyle>,
    EVT_expr_1,
    EVT_expr_2>;
using EVT_expr_4 = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,
    EVT_expr_3,
    cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value=0.0 (float32 in the torch graph; broadcast here as ElementAcc) */>;
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,
    EVT_expr_4>;
// Epilogue collective for SM90 tensor ops, built from:
//   - the shared tile, cluster and epilogue-tile types;
//   - fp16 accumulator/compute types;
//   - no C operand (void, alignment 8);
//   - an fp16 RowMajor D output with 8-element alignment;
//   - the TMA warp-specialized schedule;
//   - the visitor tree above as its fusion operation.
// Its SharedStorage size is carved out of the mainloop's shared-memory budget below.
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShapeMNK,
ClusterShapeMNK,
EpilogueTileType,
cutlass::half_t, cutlass::half_t,
void, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
EpilogueScheduleType,
cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor
>::CollectiveOp;
// Mainloop collective: SM90 TMA warp-specialized mainloop with the ping-pong
// kernel schedule. A and B are fp16 (RowMajor, 8-element alignment) and the
// accumulator is ElementAcc.
// The tile shape, cluster shape and accumulator type reuse the aliases that
// the epilogue is built from, instead of repeating the literals, so the two
// collectives cannot silently diverge.
// The stage count is chosen automatically from the shared memory left after
// carving out sizeof(epilogue SharedStorage). The sanitizer trace in this
// report shows 27 stages.
// NOTE(review): the ping-pong kernel also keeps an OrderedSequenceBarrier in
// shared memory. That barrier is where the out-of-range shared address is
// reported. Whether this carveout leaves enough headroom for it is unverified.
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop =
    typename cutlass::gemm::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        ElementAcc,
        TileShapeMNK,
        ClusterShapeMNK,
        cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue::SharedStorage)>,
        cutlass::gemm::KernelTmaWarpSpecializedPingpong
    >::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma
// Kernel type: GemmUniversal over a runtime (M, N, K, L) problem shape. It
// combines the mainloop and epilogue collectives above with the persistent
// tile scheduler.
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop,
cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue,
cutlass::gemm::PersistentScheduler>;
// Define named type
// A distinct struct (rather than a plain alias) gives the kernel a short name.
// device_kernel<...> is instantiated with this name, as the sanitizer trace shows.
struct cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma :
public cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base { };
// Host-side adapter used by cuda_cutlass_gemm_0. It provides Arguments,
// get_workspace_size, can_implement, initialize and the launch operator().
using cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma>;
// Fused GEMM entry point generated for a fixed problem size M = N = K = 256,
// batch 1. The epilogue adds the aux operand aux_arg0_1 (one value per output
// column), takes the elementwise maximum with a 0 scalar (ReLU), and writes
// fp16 Y.
//
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
//
// Returns 0 on success. Returns -1 on any error, after printing a message to
// std::cerr. No exception escapes this C ABI function.
extern "C" {
PT_EXPORT int cuda_cutlass_gemm_0(const half* X, const half* W, const half* aux_arg0_1, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
  using DeviceGemm = cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type;
  try {
    // A null pointer is only acceptable for an operand that has no elements.
    auto check_not_null = [](const void* ptr, int64_t num_elements, const char* name) {
      if (ptr == nullptr && num_elements > 0) {
        throw std::runtime_error(std::string("input ") + name + " is null but size is not 0!");
      }
    };
    check_not_null(X, 65536L, "X");
    check_not_null(W, 65536L, "W");
    check_not_null(Y, 65536L, "Y");
    check_not_null(aux_arg0_1, 256L, "aux_arg0_1");

    const int64_t B = 1;
    const int64_t M = 256L;
    const int64_t K = 256L;
    const int64_t N = 256L;
    using coord_t = cutlass::gemm::GemmCoord::Index;

    // Queried once per process. Initialization of a function-local static is
    // thread-safe, unlike the previous unsynchronized check-then-write.
    // @TODO kadeng: Add support for Multi-GPU machines with heterogeneous SM counts
    // for now we just pick the SM count of the first GPU
    static const cutlass::KernelHardwareInfo hw_info = [] {
      cutlass::KernelHardwareInfo info;
      info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST("Query result for SM count per device: " << info.sm_count);
      return info;
    }();

    // Initialize GemmUniversal3xInstance arguments.
    DeviceGemm::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
        {
            static_cast<coord_t>(M),
            static_cast<coord_t>(N),
            static_cast<coord_t>(K),
            static_cast<coord_t>(B)
        }, // ProblemShape problem_shape
        {
            reinterpret_cast<const cutlass::half_t*>(X), // ElementA const* ptr_A
            {
                256L /* stride_x0 */,
                cute::Int<1>{} /* stride_x1 */,
                0 /* batch_stride_x */
            }, // StrideA dA
            reinterpret_cast<const cutlass::half_t*>(W), // ElementB const* ptr_B
            {
                cute::Int<1>{} /* stride_w1 */,
                256L /* stride_w0 */,
                0 /* batch_stride_w */
            }, // StrideB dB
        }, // MainloopArguments mainloop
        // see https://tinyurl.com/4rk89z48
        {
            // Fusion (visitor tree) arguments, kept exactly as generated. The
            // trailing 0.0 is presumably the Sm90ScalarBroadcast value of the
            // maximum node (ReLU).
            {{{ /*plus: */ {}, { ((cutlass::half_t*)(aux_arg0_1)), cutlass::half_t(0), { cute::Int<0L>{}, cute::Int<1L>{}, cute::Int<0L>{} } } /* arg0_1 data pointer incl. offset, zero element value and strides for MNL (L=batch) dims */ }, { static_cast<ElementAcc>(0.0) }}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
            nullptr, // ElementC const* ptr_C
            {
                cute::Int<1>{} /* stride_bias0 */,
                cute::Int<1>{} /* stride_bias1 */,
                0 /* batch_stride_bias */
            }, // StrideC dC
            reinterpret_cast<cutlass::half_t*>(Y), // ElementD* ptr_D
            {
                256L /* stride_y0 */,
                cute::Int<1>{} /* stride_y1 */,
                0 /* batch_stride_y */
            }, // StrideD dD
        }, // EpilogueArguments epilogue
        // Previously computed but never passed; forward the cached SM count.
        hw_info, // KernelHardwareInfo hw_info
    };

    DeviceGemm gemm_op;
    if (workspace_size) {
      *workspace_size = gemm_op.get_workspace_size(arguments);
      return 0;
    }
    CUTLASS_CHECK(gemm_op.can_implement(arguments));
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
    {
      // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1
      // we don't need a print statement, it's happening inside the function.
      gemm_op.maximum_active_blocks();
    }
#endif
#endif
    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
    CUTLASS_CHECK(gemm_op(stream));
  } catch (const std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  } catch (...) {
    std::cerr << "Runtime error: unknown exception" << std::endl;
    return -1;
  }
  return 0;
}
} // extern "C"
#ifdef GENERATE_STANDALONE_RUNNER
/// Fills a device allocation with uniformly distributed random values between
/// `min` and `max`, using CUTLASS's device-side reference filler seeded with
/// `seed`.
/// Returns true when the allocation was filled, and false (doing nothing)
/// when it is empty.
template <class Element>
bool initialize_block(
    cutlass::DeviceAllocation<Element>& block,
    uint64_t seed, float max=1.0, float min=-1.0) {
  if (block.size() > 0) {
    const Element upper = static_cast<Element>(max);
    const Element lower = static_cast<Element>(min);
    cutlass::reference::device::BlockFillRandomUniform(
        block.get(), block.size(), seed, upper, lower, 0);
    return true;
  }
  return false;
}
extern "C" int run_standalone(uint64_t seed) {
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
size_t workspace_size = 0;
size_t* workspace_size_ptr = &workspace_size;
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = uint8_t; // may not be void
using ElementD = cutlass::half_t;
using Element_arg0_1 = cutlass::half_t;
cutlass::DeviceAllocation<ElementA> X_data(65536);
initialize_block(X_data, seed++);
cutlass::DeviceAllocation<ElementB> W_data(65536);
initialize_block(W_data, seed++);
cutlass::DeviceAllocation<ElementC> Bias_data(0);
initialize_block(Bias_data, seed++);
cutlass::DeviceAllocation<ElementD> Y_data(65536);
cutlass::DeviceAllocation<Element_arg0_1> aux_arg0_1_data(256);
initialize_block(aux_arg0_1_data, seed++);
cutlass::DeviceAllocation<uint8_t> workspace_data;
// Call once with workspace_size_ptr set to get workspace size
std::cout << "Calling once to get workspace size" << std::endl;
cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)aux_arg0_1_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
// Allocate workspace if neccessary
if (workspace_size > 0) {
workspace_data.reset(workspace_size);
std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
}
std::cout << "Calling Kernel as cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)aux_arg0_1_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;" << std::endl;
workspace_size_ptr = nullptr;
cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)aux_arg0_1_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Device synchronize failed with error "
<< cudaGetErrorString(result) << std::endl;
return result;
}
return 0;
}
/// Entry point of the standalone runner: runs the test once with seed 1.
/// Command-line arguments are ignored.
int main(int argc, char** argv) {
  (void)argc;
  (void)argv;
  const uint64_t seed = 1;
  return run_standalone(seed);
}
#endif
Out-of-range shared or local address
========= at 0xbd0 in /home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/arch/barrier.h:169:cutlass::arch::ClusterBarrier::init(const unsigned long *, unsigned int)
========= by thread (0,0,0) in block (0,1,0)
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/arch/barrier.h:127:cutlass::arch::ClusterBarrier::init(unsigned int) const [0xb20]
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp:1073:cutlass::OrderedSequenceBarrier<(int)1, (int)2>::OrderedSequenceBarrier(cutlass::OrderedSequenceBarrier<(int)1, (int)2>::SharedStorage &, const cutlass::OrderedSequenceBarrier<(int)1, (int)2>::Params &) [0xb20]
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp:382:cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int, int>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<(int)27, cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cutlass::gemm::KernelTmaWarpSpecializedPingpong>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<cute::C<(int)1>, long, long>, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x64x16_F16F16F16_SS<(cute::GMMA::Major)0, (cute::GMMA::Major)1, (cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<(int)2, (int)2, (int)16, (bool)0>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cute::tuple<cute::C<(int)64>, cute::C<(int)32>>, void, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, 
cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::plus, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90AccFetch, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90RowBroadcast<(int)2, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)1>, cute::C<(int)0>>, (int)8, (bool)1>>>, cutlass::epilogue::fusion::Sm90ScalarBroadcast<cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>, (int)1, cutlass::multiplies>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM75_U32x4_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM90_U32x4_STSM_N>, cutlass::gemm::PersistentScheduler, void>::operator ()(const cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int, int>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<(int)27, cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cutlass::gemm::KernelTmaWarpSpecializedPingpong>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, 
cute::C<(int)32>>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<cute::C<(int)1>, long, long>, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x64x16_F16F16F16_SS<(cute::GMMA::Major)0, (cute::GMMA::Major)1, (cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<(int)2, (int)2, (int)16, (bool)0>, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cute::tuple<cute::C<(int)64>, cute::C<(int)32>>, void, cute::tuple<long, cute::C<(int)1>, long>, cutlass::half_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::plus, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, 
cutlass::epilogue::fusion::Sm90AccFetch, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<identity_op, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90RowBroadcast<(int)2, cute::tuple<cute::C<(int)64>, cute::C<(int)64>, cute::C<(int)32>>, cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)1>, cute::C<(int)0>>, (int)8, (bool)1>>>, cutlass::epilogue::fusion::Sm90ScalarBroadcast<cutlass::half_t, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>, (int)1, cutlass::multiplies>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM75_U32x4_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM90_U32x4_STSM_N>, cutlass::gemm::PersistentScheduler, void>::Params &, char *) [0xad0]
========= Device Frame:/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include/cutlass/device_kernel.h:109:void cutlass::device_kernel<cutlass3x_sm90_tensorop_h64x64x16gemm_f16_f16_f16_void_f16_64x64x32_1x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma>(T1::Params) [0x20]
Environment:
* Linux x64, NVIDIA H100 GPU
* CUDA 12.1
* CUTLASS v3.3.0 (tagged release)
Example build command:
nvcc -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -w -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O1 -std=c++17 --expt-relaxed-constexpr -lineinfo -g -DCUTLASS_DEBUG_TRACE_LEVEL=1 -Xcompiler=-fPIC -Xcompiler=-fno-strict-aliasing -Xcompiler -fvisibility=hidden -Xcompiler=-Wconversion -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/include -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/library/include -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/library/src -I/home/klondenberg/github/pytorch/pytorch/third_party/cutlass/tools/util/include -L/home/klondenberg/local/cuda121/lib64 -L/home/klondenberg/local/cuda121/lib64/stubs -lcuda -lcudart -DGENERATE_STANDALONE_RUNNER -o broken5 broken5.cu
Where
* /home/klondenberg/github/pytorch/pytorch/third_party/cutlass is the CUTLASS v3.3.0 checkout directory
* /home/klondenberg/local/cuda121 is the CUDA 12.1 Toolkit path
* nvcc is from CUDA 12.1 toolkit
To obtain the error trace above, run the compiled executable under compute-sanitizer (for example, `compute-sanitizer ./broken5`).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment