#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"  // CUDA error-checking helpers from the CUTLASS examples tree

// Half-precision type used for the GEMM operands
using cutlass::half_t;
// Computes C = alpha * A * B + beta * C in FP16 on tensor cores, accumulating
// in FP32. A is row-major (M x K), B is column-major (K x N), C is row-major
// (M x N). Expects the operands to already be synced to the device; copies the
// result back to the host before returning. Returns false if the kernel is not
// supported on this device or fails to launch.
bool gemm_fp16_tensorcore(
    int length_m,
    int length_n,
    int length_k,
    cutlass::HostTensor<half_t, cutlass::layout::RowMajor>& tensor_a,
    cutlass::HostTensor<half_t, cutlass::layout::ColumnMajor>& tensor_b,
    cutlass::HostTensor<half_t, cutlass::layout::RowMajor>& tensor_c) {
  // Data types: FP16 inputs and output, FP32 accumulation and epilogue math
  using ElementAccumulator = float;
  using ElementComputeEpilogue = float;
  using ElementInputA = half_t;
  using ElementInputB = half_t;
  using ElementOutput = half_t;

  // Layouts: A row-major, B column-major, C/D row-major
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Use tensor core MMA instructions, targeting SM75 (Turing)
  using MMAOp = cutlass::arch::OpClassTensorOp;
  using SmArch = cutlass::arch::Sm75;

  // Tile shapes <M, N, K>: per threadblock, per warp, per MMA instruction
  using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 256, 64>;
  using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>;
  using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>;

  // Default threadblock-to-output-tile mapping
  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

  // Epilogue: D = alpha * accum + beta * C, with 128-bit vectorized access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementComputeEpilogue>;

  // Number of software pipeline stages in the mainloop
  constexpr int NumStages = 2;
  using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                           LayoutInputA,
                                           ElementInputB,
                                           LayoutInputB,
                                           ElementOutput,
                                           LayoutOutput,
                                           ElementAccumulator,
                                           MMAOp,
                                           SmArch,
                                           ShapeMMAThreadBlock,
                                           ShapeMMAWarp,
                                           ShapeMMAOp,
                                           EpilogueOp,
                                           SwizzleThreadBlock,
                                           NumStages>;
  // GEMM problem size (M, N, K)
  cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);

  // Epilogue scalars: D = 1.0 * A * B + 0.0 * C
  ElementComputeEpilogue alpha = 1.0f;
  ElementComputeEpilogue beta = 0.0f;

  // 1 means no split-K parallelism along the reduction dimension
  int split_k_slices = 1;

  typename Gemm::Arguments arguments{
      problem_size,
      tensor_a.device_ref(),
      tensor_b.device_ref(),
      tensor_c.device_ref(),
      tensor_c.device_ref(),  // tensor C serves as both source C and output D
      {alpha, beta},
      split_k_slices};
  // Allocate any scratch space the kernel needs (e.g., for split-K)
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  Gemm gemm_op;

  // Verify the problem is compatible with this kernel configuration
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  // Launch the GEMM kernel
  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  // Wait for completion and surface any launch/runtime errors
  if (cudaDeviceSynchronize() != cudaSuccess) {
    return false;
  }

  // Copy the result back to the host side of tensor C
  tensor_c.sync_host();
  return true;
}
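
// ----------------------------------------------------------------------------
// Usage sketch (not part of the original gist): a minimal driver, assuming a
// Turing (SM75+) GPU and the CUTLASS utility headers included above. The
// problem extents, seeds, and random-fill ranges are arbitrary choices for
// illustration. Fills A and B with small random values, runs the tensor core
// GEMM, and prints one element of the result.
// ----------------------------------------------------------------------------
int main() {
  int M = 512, N = 512, K = 512;
  cutlass::gemm::GemmCoord problem_size(M, N, K);

  // Paired host/device tensors in the layouts gemm_fp16_tensorcore expects
  cutlass::HostTensor<half_t, cutlass::layout::RowMajor> tensor_a(problem_size.mk());
  cutlass::HostTensor<half_t, cutlass::layout::ColumnMajor> tensor_b(problem_size.kn());
  cutlass::HostTensor<half_t, cutlass::layout::RowMajor> tensor_c(problem_size.mn());

  // Small random integer-valued operands; C starts at zero (beta is 0 anyway)
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(), /*seed=*/2023, /*max=*/4, /*min=*/-4, /*bits=*/0);
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(), /*seed=*/2024, /*max=*/4, /*min=*/-4, /*bits=*/0);
  cutlass::reference::host::TensorFill(tensor_c.host_view());

  // Copy operands to the device before launching
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();

  if (!gemm_fp16_tensorcore(M, N, K, tensor_a, tensor_b, tensor_c)) {
    std::cerr << "GEMM failed (unsupported device or launch error)" << std::endl;
    return -1;
  }

  std::cout << "C(0, 0) = " << float(tensor_c.host_view().at({0, 0})) << std::endl;
  return 0;
}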