#include <torch/extension.h>

#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/gemm/device/gemm_batched.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/device_memory.h>

// Batched GEMM: C = alpha * A @ B^T, fp16 inputs, fp32 output.
// A is (batch, M, K) row-major. B is passed as (batch, N, K) row-major,
// which CUTLASS consumes as a column-major K x N operand (i.e. the second
// operand is pre-transposed).
torch::Tensor bmm_fp16_fp16_f32(torch::Tensor A, torch::Tensor B, float alpha) {
  TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
  TORCH_CHECK(A.scalar_type() == torch::kHalf && B.scalar_type() == torch::kHalf,
              "A and B must be fp16");
  TORCH_CHECK(A.is_contiguous() && B.is_contiguous(), "A and B must be contiguous");

  int batch_size = A.size(0);
  int M = A.size(1);
  int N = B.size(1);
  int K = A.size(2);
  TORCH_CHECK(B.size(0) == batch_size && B.size(2) == K,
              "B must have shape (batch, N, K) matching A's (batch, M, K)");

  auto C = torch::empty({batch_size, M, N},
                        torch::dtype(torch::kFloat32).device(A.device()));

  int lda = A.size(2);  // row-major M x K: leading dimension K
  int ldb = B.size(2);  // column-major K x N: leading dimension K
  int ldc = C.size(2);  // row-major M x N: leading dimension N

  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;
  using ElementOutput = float;
  // CUTLASS templates require cutlass::half_t, not at::Half; the two types
  // are bit-compatible, so the ATen pointers are reinterpret_cast below.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementAccumulator = float;  // accumulate fp16 products in fp32
  using ElementComputeEpilogue = float;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementComputeEpilogue>;

  using Gemm = cutlass::gemm::device::GemmBatched<
      ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput,
      LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      // m16n8k16 is the Ampere tensor-core MMA shape for fp16; 16x8x32 is
      // the int8 shape and does not instantiate for half_t inputs.
      cutlass::gemm::GemmShape<16, 8, 16>, EpilogueOp>;

  long long int batch_stride_A = static_cast<long long int>(M) * K;
  long long int batch_stride_B = static_cast<long long int>(N) * K;
  long long int batch_stride_C = static_cast<long long int>(M) * N;

  Gemm gemm_op;
  typename Gemm::Arguments arguments{
      {M, N, K},
      {reinterpret_cast<ElementInputA const *>(A.data_ptr<at::Half>()), lda},
      batch_stride_A,
      {reinterpret_cast<ElementInputB const *>(B.data_ptr<at::Half>()), ldb},
      batch_stride_B,
      {C.data_ptr<ElementOutput>(), ldc},
      batch_stride_C,
      {C.data_ptr<ElementOutput>(), ldc},
      batch_stride_C,
      {alpha, 0.0f},  // epilogue computes alpha * acc + beta * C, beta = 0
      batch_size};

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }
  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }
  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot run");
  }
  return C;
}
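
The captured file ends at the C++ wrapper, so there is no Python entry point. A minimal binding sketch, assuming the file is built as a standard PyTorch C++ extension (the docstring text is mine, not from the gist):

// Hypothetical pybind11 binding; TORCH_EXTENSION_NAME is defined by
// PyTorch's extension build (setup.py or torch.utils.cpp_extension.load).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("bmm_fp16_fp16_f32", &bmm_fp16_fp16_f32,
        "Batched GEMM, fp16 inputs, fp32 output (CUTLASS, SM80)");
}

With the binding in place, torch.utils.cpp_extension.load can JIT-compile the file. CUTLASS is header-only, so its include directory (and tools/util/include, for device_memory.h) must be supplied via extra_include_paths, and the source needs a .cu extension so it is compiled by nvcc.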