#include <cublas_v2.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include <stdexcept>

// CUTLASS_CHECK is normally provided by the helper header in the CUTLASS
// examples; it is defined here so this file is self-contained.
#define CUTLASS_CHECK(status)                                  \
  {                                                            \
    cutlass::Status error = (status);                          \
    if (error != cutlass::Status::kSuccess) {                  \
      throw std::runtime_error(cutlassGetStatusString(error)); \
    }                                                          \
  }
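
// A plausible build line (all paths, the file name, and the exact library list
// are illustrative and depend on the local CUTLASS checkout and libtorch
// installation):
//   nvcc -std=c++17 -arch=sm_80 gemm_splitk.cu \
//     -I/path/to/cutlass/include -I/path/to/cutlass/tools/util/include \
//     -I/path/to/libtorch/include -I/path/to/libtorch/include/torch/csrc/api/include \
//     -L/path/to/libtorch/lib -ltorch -ltorch_cpu -ltorch_cuda -lc10 -lc10_cuda -lcublas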
// The code section below describes the matrix layouts of the input and output
// matrices: row major for A, B, and C, matching PyTorch's default contiguous
// memory format.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// The code section below describes the data types of the input and output
// matrices and of the computation between elements of the input matrices.
using ElementAccumulator = float;                  // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t;             // <- data type of elements in input matrix A (bit-compatible with at::Half)
using ElementInputB = cutlass::half_t;             // <- data type of elements in input matrix B (bit-compatible with at::Half)
using ElementOutput = float;                       // <- data type of elements in output matrix D
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32
// This code section describes the tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
// This code section describes the size of a single Tensor Core MMA instruction.
// 16x8x16 is the fp16 Tensor Core instruction shape on Sm80 (8x8x4 is the Sm70 shape).
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;
// This code section describes the epilogue: a linear combination of the
// accumulator with the source matrix C.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                    // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value, // <- number of elements per vectorized
                                                      //    memory access. For the float output
                                                      //    used here, that is 4 elements. This
                                                      //    also becomes the vector width of math
                                                      //    instructions in the epilogue
    ElementAccumulator,                               // <- data type of accumulator
    ElementComputeEpilogue>;                          // <- data type for alpha/beta in linear combination function
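// Concretely, for each output element the epilogue computes
//   D[i, j] = alpha * (sum over k of A[i, k] * B[k, j]) + beta * C[i, j]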
// Put all the template variables created above together into the
// GemmSplitKParallel template
using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
                                                       LayoutInputA,
                                                       ElementInputB,
                                                       LayoutInputB,
                                                       ElementOutput,
                                                       LayoutOutput,
                                                       ElementAccumulator,
                                                       MMAOp,
                                                       SmArch,
                                                       ShapeMMAThreadBlock,
                                                       ShapeMMAWarp,
                                                       ShapeMMAOp,
                                                       EpilogueOp>;
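// How parallel split-K works, in brief: the K dimension is partitioned into
// split_k_slices chunks, each partition computes a partial GEMM over its chunk
// of K into a float workspace, and a separate reduction kernel sums the
// partials into D. This exposes more parallelism when M and N are small
// relative to K.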
// Computes D = alpha * (A @ B) with fp16 inputs and an fp32 output. The GEMM
// defined above is specialized for Sm80, so this function requires an
// Ampere-class GPU such as the A100.
torch::Tensor bmm_fp16_cutlass(torch::Tensor& A, torch::Tensor& B, torch::Tensor& C, float alpha) {
    const int length_m = A.size(0);
    const int length_k = A.size(1);
    const int length_n = B.size(1);
    auto D = torch::empty({length_m, length_n}, torch::dtype(torch::kFloat32).device(A.device()));
    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
    // Number of partitions of the K dimension; each partition produces a
    // partial result that a second kernel reduces into D
    int split_k_slices = 16;
    // Initialize alpha and beta for the epilogue's linear combination. Beta is
    // zero, so C does not contribute to the result.
    ElementComputeEpilogue alpha_scale = ElementComputeEpilogue(alpha);
    ElementComputeEpilogue beta_scale = ElementComputeEpilogue(0);
    // Leading dimensions of the row-major operands
    const int lda = length_k;
    const int ldb = length_n;
    const int ldc = length_n;
    // Create a tuple of gemm kernel arguments. This is later passed to launch
    // the instantiated CUTLASS kernel
    typename Gemm::Arguments arguments{
        problem_size,                                                          // <- problem size of matrix multiplication
        {reinterpret_cast<ElementInputA const*>(A.data_ptr<at::Half>()), lda}, // <- reference to matrix A on device
        {reinterpret_cast<ElementInputB const*>(B.data_ptr<at::Half>()), ldb}, // <- reference to matrix B on device
        {C.data_ptr<ElementOutput>(), ldc},                                    // <- reference to matrix C on device
        {D.data_ptr<ElementOutput>(), ldc},                                    // <- reference to matrix D on device
        {alpha_scale, beta_scale},                                             // <- tuple of alpha and beta
        split_k_slices};                                                       // <- k-dimension split factor
    // Using the arguments, query for the extra workspace required for the
    // matrix multiplication computation
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    // Instantiate the CUTLASS kernel from the templates above
    Gemm gemm_op;
    // Initialize the CUTLASS kernel with the arguments and workspace pointer
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    CUTLASS_CHECK(status);
    // Launch the initialized CUTLASS kernel
    status = gemm_op();
    CUTLASS_CHECK(status);
    return D;
}
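
// ---------------------------------------------------------------------------
// Minimal usage sketch, illustrative only: assumes a CUDA device and a linked
// libtorch build. The sizes are arbitrary, but M, N, and K are kept multiples
// of 8 to satisfy the fp16 alignment of the kernel above.
// ---------------------------------------------------------------------------
int main() {
    auto half_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
    auto float_opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);

    torch::Tensor A = torch::randn({512, 1024}, half_opts);
    torch::Tensor B = torch::randn({1024, 256}, half_opts);
    torch::Tensor C = torch::zeros({512, 256}, float_opts); // unused source, since beta = 0

    torch::Tensor D = bmm_fp16_cutlass(A, B, C, 1.0f);

    // Rough correctness check against a float matmul in PyTorch.
    torch::Tensor ref = torch::matmul(A.to(torch::kFloat32), B.to(torch::kFloat32));
    std::cout << "max abs diff: " << (D - ref).abs().max().item<float>() << std::endl;
    return 0;
}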