Optimized CUDA matrix multiplication (WMMA Tensor Core GEMM)
#include <cstddef>
#include <cuda_runtime.h>
#include <mma.h>

// Note: load_data_to_shared_memory_transposed_vectorized and
// CHECK_LAST_CUDA_ERROR are assumed to be defined elsewhere in the
// accompanying GEMM utility code.

// GEMM kernel v07.
// Each warp in the block computes a WARP_TILE_SIZE_Y x WARP_TILE_SIZE_X tile
// of the output using Tensor Core WMMA tiles of shape WMMA_TILE_SIZE_Y x
// WMMA_TILE_SIZE_X x WMMA_TILE_SIZE_K. Number of threads per block:
// (BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y) * (BLOCK_TILE_SIZE_X /
// WARP_TILE_SIZE_X) * 32. The epilogue performs no bounds checks, so m and n
// are effectively assumed to be multiples of BLOCK_TILE_SIZE_Y and
// BLOCK_TILE_SIZE_X, respectively.
template <typename T, size_t BLOCK_TILE_SIZE_X, size_t BLOCK_TILE_SIZE_Y,
          size_t BLOCK_TILE_SIZE_K, size_t BLOCK_TILE_SKEW_SIZE_X,
          size_t BLOCK_TILE_SKEW_SIZE_Y, size_t WARP_TILE_SIZE_X,
          size_t WARP_TILE_SIZE_Y, size_t WMMA_TILE_SIZE_X,
          size_t WMMA_TILE_SIZE_Y, size_t WMMA_TILE_SIZE_K, size_t NUM_THREADS>
__global__ void gemm_v07_vectorized(size_t m, size_t n, size_t k, T alpha,
                                    T const* A, size_t lda, T const* B,
                                    size_t ldb, T beta, T* C, size_t ldc)
{
    constexpr size_t NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X};
    static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U);
    static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U);
    // Cache a tile of A and B in shared memory for data reuse.
    __shared__ T A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K]
                                               [BLOCK_TILE_SIZE_Y +
                                                BLOCK_TILE_SKEW_SIZE_Y];
    __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X +
                                                        BLOCK_TILE_SKEW_SIZE_X];
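    // The skew padding offsets consecutive shared-memory rows relative to the
    // 32 four-byte banks so that WMMA tile loads do not repeatedly hit the
    // same bank. For example, assuming T = __half (2 bytes), a 16-element
    // skew shifts each row by 32 bytes, i.e. 8 banks.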
    constexpr size_t NUM_WMMA_TILES_X{WARP_TILE_SIZE_X / WMMA_TILE_SIZE_X};
    static_assert(WARP_TILE_SIZE_X % WMMA_TILE_SIZE_X == 0U);
    constexpr size_t NUM_WMMA_TILES_Y{WARP_TILE_SIZE_Y / WMMA_TILE_SIZE_Y};
    static_assert(WARP_TILE_SIZE_Y % WMMA_TILE_SIZE_Y == 0U);
    constexpr size_t NUM_WMMA_TILES_K{BLOCK_TILE_SIZE_K / WMMA_TILE_SIZE_K};
    static_assert(BLOCK_TILE_SIZE_K % WMMA_TILE_SIZE_K == 0U);
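    // With the launch defaults below (WARP_TILE 32 x 64, WMMA tiles
    // 16 x 16 x 16, BLOCK_TILE_SIZE_K = 16), each warp iterates over
    // NUM_WMMA_TILES_X = 2, NUM_WMMA_TILES_Y = 4, and NUM_WMMA_TILES_K = 1
    // WMMA tiles.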
    // Declare the fragments.
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMA_TILE_SIZE_Y,
                           WMMA_TILE_SIZE_X, WMMA_TILE_SIZE_K, T,
                           nvcuda::wmma::col_major>
        a_frags[NUM_WMMA_TILES_Y];
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_TILE_SIZE_Y,
                           WMMA_TILE_SIZE_X, WMMA_TILE_SIZE_K, T,
                           nvcuda::wmma::row_major>
        b_frags[NUM_WMMA_TILES_X];
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_TILE_SIZE_Y,
                           WMMA_TILE_SIZE_X, WMMA_TILE_SIZE_K, T>
        acc_frags[NUM_WMMA_TILES_Y][NUM_WMMA_TILES_X];
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_TILE_SIZE_Y,
                           WMMA_TILE_SIZE_X, WMMA_TILE_SIZE_K, T>
        c_frag;
    // Make sure the accumulator starts from 0.
    #pragma unroll
    for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y;
         ++wmma_tile_row_idx)
    {
        for (size_t wmma_tile_col_idx{0U};
             wmma_tile_col_idx < NUM_WMMA_TILES_X; ++wmma_tile_col_idx)
        {
            nvcuda::wmma::fill_fragment(
                acc_frags[wmma_tile_row_idx][wmma_tile_col_idx],
                static_cast<T>(0));
        }
    }
    size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x};
    size_t const warp_linear_idx{thread_linear_idx / 32U};
    size_t const warp_row_idx{warp_linear_idx / NUM_WARPS_X};
    size_t const warp_col_idx{warp_linear_idx % NUM_WARPS_X};
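    // Warps are laid out row-major over the block tile. With the launch
    // defaults (NUM_WARPS_X = 4), warp_linear_idx 0..3 map to warp tile
    // row 0, columns 0..3, and warp_linear_idx 4..7 map to row 1,
    // columns 0..3.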
    // Number of outer loop iterations to perform the sum of inner products.
    // C_thread_block_tile =
    //     \sum_{thread_block_tile_idx=0}^{num_thread_block_tiles-1}
    //     A[:, thread_block_tile_idx:BLOCK_TILE_SIZE_K] *
    //     B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :]
    size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) /
                                        BLOCK_TILE_SIZE_K};
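    // For example, with k = 4096 and BLOCK_TILE_SIZE_K = 16, each block
    // performs 256 iterations of the main loop.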
    for (size_t thread_block_tile_idx{0U};
         thread_block_tile_idx < num_thread_block_tiles;
         ++thread_block_tile_idx)
    {
        load_data_to_shared_memory_transposed_vectorized<
            T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K,
            NUM_THREADS, BLOCK_TILE_SKEW_SIZE_X, BLOCK_TILE_SKEW_SIZE_Y>(
            A, lda, B, ldb, A_thread_block_tile_transposed,
            B_thread_block_tile, thread_block_tile_idx, thread_linear_idx, m,
            n, k);
        __syncthreads();
        // Perform A[:, thread_block_tile_idx:BLOCK_TILE_SIZE_K] *
        // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :], where both operands
        // are cached in shared memory as A_thread_block_tile_transposed and
        // B_thread_block_tile, respectively. This product is decomposed into
        // NUM_WMMA_TILES_K Tensor Core tile multiplications along the K
        // dimension; the operand tiles are loaded into register-resident
        // WMMA fragments before each mma_sync.
        #pragma unroll
        for (size_t k_i{0U}; k_i < NUM_WMMA_TILES_K; ++k_i)
        {
            #pragma unroll
            for (size_t wmma_tile_row_idx{0U};
                 wmma_tile_row_idx < NUM_WMMA_TILES_Y; ++wmma_tile_row_idx)
            {
                nvcuda::wmma::load_matrix_sync(
                    a_frags[wmma_tile_row_idx],
                    &A_thread_block_tile_transposed[k_i * WMMA_TILE_SIZE_K]
                                                   [warp_row_idx *
                                                        WARP_TILE_SIZE_Y +
                                                    wmma_tile_row_idx *
                                                        WMMA_TILE_SIZE_Y],
                    BLOCK_TILE_SIZE_Y + BLOCK_TILE_SKEW_SIZE_Y);
                #pragma unroll
                for (size_t wmma_tile_col_idx{0U};
                     wmma_tile_col_idx < NUM_WMMA_TILES_X;
                     ++wmma_tile_col_idx)
                {
                    // Load the fragment from shared memory. These loads are
                    // extremely slow for some reason, which significantly
                    // affects performance.
                    nvcuda::wmma::load_matrix_sync(
                        b_frags[wmma_tile_col_idx],
                        &B_thread_block_tile[k_i * WMMA_TILE_SIZE_K]
                                            [warp_col_idx * WARP_TILE_SIZE_X +
                                             wmma_tile_col_idx *
                                                 WMMA_TILE_SIZE_X],
                        BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X);
                    // Perform the matrix multiplication.
                    nvcuda::wmma::mma_sync(
                        acc_frags[wmma_tile_row_idx][wmma_tile_col_idx],
                        a_frags[wmma_tile_row_idx], b_frags[wmma_tile_col_idx],
                        acc_frags[wmma_tile_row_idx][wmma_tile_col_idx]);
                }
            }
        }
        __syncthreads();
    }
    // Write the results to DRAM.
    #pragma unroll
    for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y;
         ++wmma_tile_row_idx)
    {
        #pragma unroll
        for (size_t wmma_tile_col_idx{0U};
             wmma_tile_col_idx < NUM_WMMA_TILES_X; ++wmma_tile_col_idx)
        {
            // Load the current C tile from global memory.
            nvcuda::wmma::load_matrix_sync(
                c_frag,
                &C[(blockIdx.y * BLOCK_TILE_SIZE_Y +
                    warp_row_idx * WARP_TILE_SIZE_Y +
                    wmma_tile_row_idx * WMMA_TILE_SIZE_Y) *
                       ldc +
                   blockIdx.x * BLOCK_TILE_SIZE_X +
                   warp_col_idx * WARP_TILE_SIZE_X +
                   wmma_tile_col_idx * WMMA_TILE_SIZE_X],
                ldc, nvcuda::wmma::mem_row_major);
            // Perform scaling and addition.
            for (size_t i{0}; i < c_frag.num_elements; ++i)
            {
                c_frag.x[i] =
                    alpha *
                        acc_frags[wmma_tile_row_idx][wmma_tile_col_idx].x[i] +
                    beta * c_frag.x[i];
            }
            // Store the scaled tile back to global memory.
            nvcuda::wmma::store_matrix_sync(
                &C[(blockIdx.y * BLOCK_TILE_SIZE_Y +
                    warp_row_idx * WARP_TILE_SIZE_Y +
                    wmma_tile_row_idx * WMMA_TILE_SIZE_Y) *
                       ldc +
                   blockIdx.x * BLOCK_TILE_SIZE_X +
                   warp_col_idx * WARP_TILE_SIZE_X +
                   wmma_tile_col_idx * WMMA_TILE_SIZE_X],
                c_frag, ldc, nvcuda::wmma::mem_row_major);
        }
    }
}
} | |
template <typename T>
void launch_gemm_kernel_v07_vectorized(size_t m, size_t n, size_t k,
                                       T const* alpha, T const* A, size_t lda,
                                       T const* B, size_t ldb, T const* beta,
                                       T* C, size_t ldc, cudaStream_t stream)
{
    // Feel free to tune the block tile sizes; correctness is preserved for
    // any combination that satisfies the static assertions.
    constexpr unsigned int BLOCK_TILE_SIZE_X{128U};
    constexpr unsigned int BLOCK_TILE_SIZE_Y{128U};
    constexpr unsigned int BLOCK_TILE_SIZE_K{16U};
    // The skew size is used to avoid bank conflicts in shared memory.
    constexpr size_t BLOCK_TILE_SKEW_SIZE_X{16U};
    constexpr size_t BLOCK_TILE_SKEW_SIZE_Y{16U};
    constexpr unsigned int WARP_TILE_SIZE_X{32U};
    constexpr unsigned int WARP_TILE_SIZE_Y{64U};
    constexpr unsigned int NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X};
    constexpr unsigned int NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y};
    static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U);
    static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U);
    constexpr unsigned int WMMA_TILE_SIZE_X{16U};
    constexpr unsigned int WMMA_TILE_SIZE_Y{16U};
    constexpr unsigned int WMMA_TILE_SIZE_K{16U};
    constexpr unsigned int NUM_THREADS_PER_BLOCK{NUM_WARPS_X * NUM_WARPS_Y *
                                                 32U};
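    // With these defaults: NUM_WARPS_X = 128 / 32 = 4, NUM_WARPS_Y =
    // 128 / 64 = 2, so each block launches 4 * 2 * 32 = 256 threads
    // (8 warps).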
    dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U};
    dim3 const grid_dim{
        (static_cast<unsigned int>(n) + BLOCK_TILE_SIZE_X - 1U) /
            BLOCK_TILE_SIZE_X,
        (static_cast<unsigned int>(m) + BLOCK_TILE_SIZE_Y - 1U) /
            BLOCK_TILE_SIZE_Y,
        1U};
    gemm_v07_vectorized<T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y,
                        BLOCK_TILE_SIZE_K, BLOCK_TILE_SKEW_SIZE_X,
                        BLOCK_TILE_SKEW_SIZE_Y, WARP_TILE_SIZE_X,
                        WARP_TILE_SIZE_Y, WMMA_TILE_SIZE_X, WMMA_TILE_SIZE_Y,
                        WMMA_TILE_SIZE_K, NUM_THREADS_PER_BLOCK>
        <<<grid_dim, block_dim, 0U, stream>>>(m, n, k, *alpha, A, lda, B, ldb,
                                              *beta, C, ldc);
    CHECK_LAST_CUDA_ERROR();
}
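
// A minimal usage sketch (an illustration added here, not part of the
// original gist): it launches the __half specialization to compute
// C = alpha * A * B + beta * C for contiguous row-major A (m x k), B (k x n),
// and C (m x n), so lda = k, ldb = n, ldc = n. Error checking is elided for
// brevity.
#include <cuda_fp16.h>

void example_gemm_v07_usage()
{
    size_t const m{1024U}, n{1024U}, k{1024U};
    __half *A{nullptr}, *B{nullptr}, *C{nullptr};
    cudaMalloc(&A, m * k * sizeof(__half));
    cudaMalloc(&B, k * n * sizeof(__half));
    cudaMalloc(&C, m * n * sizeof(__half));
    // ... initialize A, B, and C on the device ...
    __half const alpha{__float2half(1.0f)};
    __half const beta{__float2half(0.0f)};
    cudaStream_t stream{nullptr};
    cudaStreamCreate(&stream);
    // Leading dimensions for contiguous row-major storage.
    launch_gemm_kernel_v07_vectorized<__half>(m, n, k, &alpha, A, k, B, n,
                                              &beta, C, n, stream);
    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    cudaFree(A);
    cudaFree(B);
    cudaFree(C);
}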