Created
April 27, 2025 22:54
-
-
Save pashu123/2bdcc6774c767e08e532da37709372b4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <random>
#include <vector>
// Checks the hipError_t produced by `expr`, forwarding the call site so a
// failure can be reported with its file and line.
#define HIP_CHECK(expr) hipAssert(expr, __FILE__, __LINE__)

// Does nothing when `err` is hipSuccess. Otherwise writes
// "HIP error: <message> at <file>:<line>" to stderr and terminates the process
// with EXIT_FAILURE.
inline void hipAssert(hipError_t err, const char* file, int line) {
    if (err == hipSuccess) {
        return;
    }
    std::cerr << "HIP error: " << hipGetErrorString(err)
              << " at " << file << ":" << line << std::endl;
    exit(EXIT_FAILURE);
}
// Threads per block for argmax_bf16_kernel. The kernel sizes its __shared__
// arrays with this value and reduces by repeatedly halving blockDim.x, so it
// must be a power of two and within the 1024-thread block limit.
constexpr int BLOCK_SIZE = 256;
static_assert(BLOCK_SIZE > 0 && (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0,
              "BLOCK_SIZE must be a power of two for the tree reduction");
static_assert(BLOCK_SIZE <= 1024, "BLOCK_SIZE exceeds the maximum threads per block");
__global__ void argmax_bf16_kernel(const __hip_bfloat16* data, int size, int* block_results) { | |
__shared__ float sdata[BLOCK_SIZE]; | |
__shared__ int sidx[BLOCK_SIZE]; | |
int tid = threadIdx.x; | |
int gid = blockIdx.x * blockDim.x + tid; | |
// Load and convert to float for comparison | |
float val = (gid < size) ? static_cast<float>(data[gid]) : -INFINITY; | |
sdata[tid] = val; | |
sidx[tid] = (gid < size) ? gid : -1; | |
__syncthreads(); | |
// Parallel reduction | |
for (int s = blockDim.x/2; s > 0; s >>= 1) { | |
if (tid < s) { | |
if (sdata[tid + s] > sdata[tid]) { | |
sdata[tid] = sdata[tid + s]; | |
sidx[tid] = sidx[tid + s]; | |
} | |
} | |
__syncthreads(); | |
} | |
if (tid == 0) { | |
block_results[blockIdx.x] = sidx[0]; | |
} | |
} | |
int hip_argmax_bf16(const __hip_bfloat16* data, int size) { | |
const int block_size = BLOCK_SIZE; | |
int grid_size = (size + block_size - 1) / block_size; | |
__hip_bfloat16* d_data; | |
int* d_block_results; | |
HIP_CHECK(hipMalloc(&d_data, size * sizeof(__hip_bfloat16))); | |
HIP_CHECK(hipMalloc(&d_block_results, grid_size * sizeof(int))); | |
HIP_CHECK(hipMemcpy(d_data, data, size * sizeof(__hip_bfloat16), hipMemcpyHostToDevice)); | |
hipLaunchKernelGGL(argmax_bf16_kernel, | |
dim3(grid_size), dim3(block_size), | |
sizeof(float)*block_size + sizeof(int)*block_size, 0, | |
d_data, size, d_block_results); | |
std::vector<int> block_results(grid_size); | |
HIP_CHECK(hipMemcpy(block_results.data(), d_block_results, | |
grid_size * sizeof(int), hipMemcpyDeviceToHost)); | |
// Final host-side reduction | |
int max_index = block_results[0]; | |
float max_val = static_cast<float>(data[max_index]); | |
for (int i = 1; i < grid_size; i++) { | |
float current = static_cast<float>(data[block_results[i]]); | |
if (current > max_val) { | |
max_val = current; | |
max_index = block_results[i]; | |
} | |
} | |
HIP_CHECK(hipFree(d_data)); | |
HIP_CHECK(hipFree(d_block_results)); | |
return max_index; | |
} | |
int main() { | |
const int N = 1 << 17; // 1M elements | |
std::vector<__hip_bfloat16> data(N); | |
// Generate random data | |
std::mt19937 gen(42); | |
std::uniform_real_distribution<float> dist(-1000.0f, 1000.0f); | |
for (auto& val : data) val = static_cast<__hip_bfloat16>(dist(gen)); | |
// CPU reference | |
int cpu_argmax = 0; | |
float max_val = static_cast<float>(data[0]); | |
for (int i = 1; i < N; ++i) { | |
float current = static_cast<float>(data[i]); | |
if (current > max_val) { | |
max_val = current; | |
cpu_argmax = i; | |
} | |
} | |
// GPU implementation | |
hipEvent_t start, stop; | |
HIP_CHECK(hipEventCreate(&start)); | |
HIP_CHECK(hipEventCreate(&stop)); | |
// Warm-up | |
int gpu_argmax = hip_argmax_bf16(data.data(), N); | |
// Benchmark | |
const int trials = 100; | |
HIP_CHECK(hipEventRecord(start)); | |
for (int i = 0; i < trials; i++) { | |
gpu_argmax = hip_argmax_bf16(data.data(), N); | |
} | |
HIP_CHECK(hipEventRecord(stop)); | |
HIP_CHECK(hipEventSynchronize(stop)); | |
float milliseconds; | |
HIP_CHECK(hipEventElapsedTime(&milliseconds, start, stop)); | |
// Verify | |
if (static_cast<float>(data[cpu_argmax]) != static_cast<float>(data[gpu_argmax])) { | |
std::cerr << "Validation failed!" << std::endl; | |
return EXIT_FAILURE; | |
} | |
std::cout << "Results match!" << std::endl; | |
std::cout << "Average execution time: " << milliseconds/trials << " ms" << std::endl; | |
HIP_CHECK(hipEventDestroy(start)); | |
HIP_CHECK(hipEventDestroy(stop)); | |
return EXIT_SUCCESS; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment