Use this program to find out how a tensor core accumulator fragment is laid out across the registers of a warp.
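The kernel writes four 16x16 matrices. Elem records which fragment element i lands at each position of the tile, ThreadIdx records which lane of the warp owns it, and Row/Col are filled from the conjectured mapping in getWarpRow/getWarpCol: if that conjecture matches the hardware layout, every entry of Row equals its own row index and every entry of Col equals its own column index. The formulas below encode row = ((i/2) % 2) * 8 + lane/4 and col = (lane % 4) * 2 + i % 2 + (i/4) * 8 for the m16n16k16 FP32 accumulator on sm_75; treat them as the hypothesis being tested, not as documented behaviour.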
// Check the tensor core accumulator's warp register layout
// Compile: nvcc -arch=sm_75 tensorcore_mapping.cu -o mapping
// Run:     ./mapping

#include <stdio.h>
#include <mma.h>

using namespace nvcuda;

// Error checking macro: report the CUDA error string, file and line on failure.
#define cudaErrCheck(stat) { cudaErrCheck_((stat), __FILE__, __LINE__); }
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    if (stat != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
    }
}
// Conjectured row index, within the 16x16 accumulator tile, of fragment
// element i held by the calling thread.
__device__ int getWarpRow(int i) {
    int tid = threadIdx.x % 32;
    return ((i / 2) % 2) * 8 + tid / 4;
}

// Conjectured column index of fragment element i held by the calling thread.
__device__ int getWarpCol(int i) {
    int tid = threadIdx.x % 32;
    return (tid % 4) * 2 + i % 2 + (i / 4) * 8;
}
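// Worked example of the conjectured mapping above (not part of the original
// gist, just the two formulas evaluated by hand): for lane 5, tid / 4 == 1 and
// tid % 4 == 1, so elements 0..7 would land at
//   (1,2) (1,3) (9,2) (9,3) (1,10) (1,11) (9,10) (9,11)
// i.e. pairs of adjacent columns, repeated 8 rows down and 8 columns across.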
__global__ void wmma_example(float* elem, float* thread, float* row, float* col) {
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag;

    // elem: store each element's index within the fragment, revealing where
    // fragment element i ends up in the 16x16 tile.
    wmma::fill_fragment(acc_frag, 0.0f);
    for (int i = 0; i < acc_frag.num_elements; i++) {
        acc_frag.x[i] = i;
    }
    wmma::store_matrix_sync(elem, acc_frag, 16, wmma::mem_row_major);

    // thread: store the owning lane's index, revealing which thread holds
    // each position of the tile.
    wmma::fill_fragment(acc_frag, 0.0f);
    for (int i = 0; i < acc_frag.num_elements; i++) {
        acc_frag.x[i] = threadIdx.x;
    }
    wmma::store_matrix_sync(thread, acc_frag, 16, wmma::mem_row_major);

    // row: store the conjectured row index; if the mapping is correct, every
    // entry of this matrix equals its own row index.
    wmma::fill_fragment(acc_frag, 0.0f);
    for (int i = 0; i < acc_frag.num_elements; i++) {
        acc_frag.x[i] = getWarpRow(i);
    }
    wmma::store_matrix_sync(row, acc_frag, 16, wmma::mem_row_major);

    // col: store the conjectured column index; if the mapping is correct,
    // every entry equals its own column index.
    wmma::fill_fragment(acc_frag, 0.0f);
    for (int i = 0; i < acc_frag.num_elements; i++) {
        acc_frag.x[i] = getWarpCol(i);
    }
    wmma::store_matrix_sync(col, acc_frag, 16, wmma::mem_row_major);
}
int main(int argc, char* argv[]) {
    float *elem;
    float *thread;
    float *row;
    float *col;

    float *elem_host;
    float *thread_host;
    float *row_host;
    float *col_host;

    // Allocate the four 16x16 output matrices on the device and on the host.
    cudaErrCheck(cudaMalloc((void**)&elem, 16 * 16 * sizeof(float)));
    cudaErrCheck(cudaMalloc((void**)&thread, 16 * 16 * sizeof(float)));
    cudaErrCheck(cudaMalloc((void**)&row, 16 * 16 * sizeof(float)));
    cudaErrCheck(cudaMalloc((void**)&col, 16 * 16 * sizeof(float)));

    elem_host = (float*)malloc(16 * 16 * sizeof(float));
    thread_host = (float*)malloc(16 * 16 * sizeof(float));
    row_host = (float*)malloc(16 * 16 * sizeof(float));
    col_host = (float*)malloc(16 * 16 * sizeof(float));

    // Launch a single block holding exactly one warp.
    dim3 gridDim(1);
    dim3 blockDim(32);

    printf("Running with wmma...\n");
    wmma_example <<< gridDim, blockDim >>> (elem, thread, row, col);
    cudaErrCheck(cudaGetLastError());

    // Copy the results back (cudaMemcpy synchronizes with the kernel).
    printf("\nChecking results...\n");
    cudaErrCheck(cudaMemcpy(elem_host, elem, 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(thread_host, thread, 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(row_host, row, 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(col_host, col, 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost));
printf("Elem:\n"); | |
for (int i=0; i<16 ; i++) { | |
for (int j=0; j<16; j++) { | |
printf("%2d ", (int) elem_host[i*16+j]); | |
} | |
printf("\n"); | |
} | |
printf("ThreadIdx:\n"); | |
for (int i=0; i<16 ; i++) { | |
for (int j=0; j<16; j++) { | |
printf("%2d ", (int) thread_host[i*16+j]); | |
} | |
printf("\n"); | |
} | |
printf("Row:\n"); | |
for (int i=0; i<16 ; i++) { | |
for (int j=0; j<16; j++) { | |
printf("%2d ", (int) row_host[i*16+j]); | |
} | |
printf("\n"); | |
} | |
printf("Col:\n"); | |
for (int i=0; i<16 ; i++) { | |
for (int j=0; j<16; j++) { | |
printf("%2d ", (int) col_host[i*16+j]); | |
} | |
printf("\n"); | |
} | |
    cudaErrCheck(cudaFree(elem));
    cudaErrCheck(cudaFree(thread));
    cudaErrCheck(cudaFree(row));
    cudaErrCheck(cudaFree(col));

    free(elem_host);
    free(thread_host);
    free(row_host);
    free(col_host);

    cudaErrCheck(cudaDeviceReset());
    return 0;
}
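Reading the Row and Col matrices by eye works, but a small automatic check is easy to bolt on. The sketch below is not part of the gist; it assumes the host arrays from main() above (row_host, col_host), would be defined before main(), and would be called right after the cudaMemcpy calls. It simply verifies that every entry of Row equals its row index and every entry of Col equals its column index, i.e. that the conjectured mapping matches the hardware layout.

// Hypothetical helper (not in the original gist): returns true when the
// Row/Col matrices produced by the conjectured mapping reproduce their own
// row and column indices.
bool checkMapping(const float* row_host, const float* col_host) {
    bool ok = true;
    for (int r = 0; r < 16; r++) {
        for (int c = 0; c < 16; c++) {
            if ((int) row_host[r*16+c] != r || (int) col_host[r*16+c] != c) {
                printf("Mismatch at (%d, %d): row=%d col=%d\n",
                       r, c, (int) row_host[r*16+c], (int) col_host[r*16+c]);
                ok = false;
            }
        }
    }
    return ok;
}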