#include <cublas_v2.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <stdexcept>
#include <torch/torch.h>
#include <torch/types.h>
#include <c10/util/Half.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the memory layout of the input and output matrices: Column Major for
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// The code section below describes the data types of the input and output matrices and of the computation
// between elements in the input matrices.
using ElementAccumulator = float;                  // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t;             // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t;             // <- data type of elements in input matrix B
using ElementOutput = float;                       // <- data type of elements in output matrix D
using MMAOp = cutlass::arch::OpClassTensorOp;      // <- use Tensor Core MMA operations
using SmArch = cutlass::arch::Sm80;                // <- target architecture: SM80 (Ampere)
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<256, 128, 64>; // <- threadblock tile M = 256, N = 128, K = 64
// This code section describes the tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64
// This code section describes the size of the MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA op tile M = 16, N = 8, K = 16
// This code section describes how threadblocks are scheduled (swizzled) across the GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                    // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per
                                                      // vectorized memory access. For the
                                                      // float output used here this is
                                                      // 128 / 32 = 4 elements. This also becomes
                                                      // the vector width of math instructions in
                                                      // the epilogue
    ElementAccumulator,                               // <- data type of accumulator
    ElementComputeEpilogue>;                          // <- data type of alpha/beta in the linear combination function
// Number of pipeline stages to use
constexpr int NumStages = 2;
// Put all the template parameters together to create the GEMM kernel type
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages>;
// This function targets SM80 (e.g. A100) GPUs, matching SmArch above.
// It computes D = alpha * A * B + beta * C with half-precision inputs and float output.
void gemm_fp16_cutlass(torch::Tensor& A,
                       torch::Tensor& B,
                       torch::Tensor& C,
                       torch::Tensor& D,
                       float alpha_val,
                       float beta_val)
{
    const int length_m = A.size(0);
    const int length_k = A.size(1);
    const int length_n = B.size(1);
    // Create a tuple of problem size for the matrix multiplication
    cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
    // Initialize alpha and beta for the linear combination
    ElementComputeEpilogue alpha = ElementComputeEpilogue(alpha_val);
    ElementComputeEpilogue beta = ElementComputeEpilogue(beta_val);
    // Split the K dimension into 1 partition (no split-K)
    int split_k_slices = 1;
    // Create a tuple of GEMM kernel arguments. This is later passed as arguments to launch the
    // instantiated CUTLASS kernel
    typename Gemm::Arguments arguments{
        problem_size,                                               // <- problem size of matrix multiplication
        {reinterpret_cast<cutlass::half_t*>(A.data_ptr<at::Half>()),
         LayoutInputA(length_m)},                                   // <- reference to matrix A on device (column-major, lda = M)
        {reinterpret_cast<cutlass::half_t*>(B.data_ptr<at::Half>()),
         LayoutInputB(length_n)},                                   // <- reference to matrix B on device (row-major, ldb = N)
        {C.data_ptr<ElementOutput>(), LayoutOutput(length_n)},      // <- reference to matrix C on device (row-major, ldc = N)
        {D.data_ptr<ElementOutput>(), LayoutOutput(length_n)},      // <- reference to matrix D on device (row-major, ldd = N)
        {alpha, beta},                                              // <- tuple of alpha and beta
        split_k_slices};                                            // <- k-dimension split factor
    // Using the arguments, query for the extra workspace required for the matrix multiplication computation
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    // Instantiate the CUTLASS kernel depending on the templates
    Gemm gemm_op;
    // Initialize the CUTLASS kernel with the arguments and workspace pointer
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    CUTLASS_CHECK(status);
    // Launch the initialized CUTLASS kernel
    status = gemm_op();
    CUTLASS_CHECK(status);
    // The result is written into D in place; nothing to return from a void function.
}
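
// ---------------------------------------------------------------------------
// Illustrative usage sketch, assuming the caller provides CUDA tensors that
// match the layouts declared above. The function name `example_usage` and the
// shapes are hypothetical, not part of the kernel itself: A is built as a
// column-major M x K tensor (to match LayoutInputA), while B, C and D are
// ordinary row-major CUDA tensors.
// ---------------------------------------------------------------------------
void example_usage()
{
    const int M = 512, N = 256, K = 128;
    auto half_opts  = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto float_opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
    // A must be column-major M x K: allocate K x M row-major and view it transposed.
    torch::Tensor A = torch::randn({K, M}, half_opts).t();
    torch::Tensor B = torch::randn({K, N}, half_opts);  // row-major K x N
    torch::Tensor C = torch::zeros({M, N}, float_opts); // row-major M x N
    torch::Tensor D = torch::empty({M, N}, float_opts); // row-major M x N, receives the result
    // D = 1.0 * A * B + 0.0 * C
    gemm_fp16_cutlass(A, B, C, D, /*alpha=*/1.0f, /*beta=*/0.0f);
    cudaDeviceSynchronize();
    std::cout << "D[0][0] = " << D[0][0].item<float>() << std::endl;
}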