Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created April 27, 2025 22:54
Show Gist options
  • Save pashu123/2bdcc6774c767e08e532da37709372b4 to your computer and use it in GitHub Desktop.
Save pashu123/2bdcc6774c767e08e532da37709372b4 to your computer and use it in GitHub Desktop.
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <iostream>
#include <vector>
#include <random>
// Wraps a HIP runtime call: forwards its hipError_t result to hipAssert
// together with the file and line of the call site.
#define HIP_CHECK(err) hipAssert(err, __FILE__, __LINE__)
// Checks the status of a HIP call.
//   err  - status returned by the HIP call
//   file - source file of the call site (used only in the diagnostic)
//   line - source line of the call site (used only in the diagnostic)
// Returns normally on hipSuccess; otherwise prints the HIP error string and
// the call site to stderr and terminates the process with EXIT_FAILURE.
inline void hipAssert(hipError_t err, const char* file, int line) {
    if (err == hipSuccess) {
        return;
    }
    std::cerr << "HIP error: " << hipGetErrorString(err)
              << " at " << file << ":" << line << std::endl;
    exit(EXIT_FAILURE);
}
// Threads per block used when launching argmax_bf16_kernel; also the length
// of that kernel's two shared-memory scratch arrays.
constexpr int BLOCK_SIZE = 256;
// Per-block argmax over a bf16 array.
//
// Launch layout: 1-D grid of 1-D blocks. blockDim.x must be a power of two
// and must not exceed BLOCK_SIZE, because the shared scratch arrays are sized
// by BLOCK_SIZE and the tree reduction halves the active range each step.
// Shared memory: two statically sized arrays (float + int, BLOCK_SIZE entries
// each); no dynamic shared memory is required.
//
//   data          - device pointer to `size` bf16 values
//   size          - number of valid elements in `data`
//   block_results - device pointer to gridDim.x ints; entry b receives the
//                   global index of the largest element covered by block b,
//                   or -1 if block b covers no valid element
//
// For NaN-free input, ties are resolved to the smallest index, so the result
// agrees with a sequential left-to-right scan that uses a strict `>`.
// NaN inputs are not treated specially (every comparison against NaN is false).
__global__ void argmax_bf16_kernel(const __hip_bfloat16* data, int size, int* block_results) {
    __shared__ float sdata[BLOCK_SIZE];
    __shared__ int sidx[BLOCK_SIZE];

    const unsigned int tid = threadIdx.x;
    // Keep the global index unsigned for the bounds check so that a grid
    // overshooting `size` near INT_MAX cannot wrap to a negative int.
    const unsigned int gid = blockIdx.x * blockDim.x + tid;
    const bool in_range = size > 0 && gid < static_cast<unsigned int>(size);

    // Compare in float; out-of-range lanes contribute (-inf, -1) so they can
    // never beat a valid element.
    sdata[tid] = in_range ? static_cast<float>(data[gid]) : -INFINITY;
    sidx[tid] = in_range ? static_cast<int>(gid) : -1;
    __syncthreads();

    // Tree reduction in shared memory: each step folds the upper half of the
    // active range into the lower half.
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            const float my_val = sdata[tid];
            const int my_idx = sidx[tid];
            const float other_val = sdata[tid + s];
            const int other_idx = sidx[tid + s];

            // On equal values prefer the smaller valid index; a slot does not
            // hold a contiguous index range after the first step, so "keep
            // the lower slot" alone would not yield the first occurrence.
            const bool wins_tie = (other_val == my_val) && (other_idx >= 0) &&
                                  (my_idx < 0 || other_idx < my_idx);
            if (other_val > my_val || wins_tie) {
                sdata[tid] = other_val;
                sidx[tid] = other_idx;
            }
        }
        // Barrier is outside the `tid < s` branch so every thread reaches it.
        __syncthreads();
    }

    if (tid == 0) {
        block_results[blockIdx.x] = sidx[0];
    }
}
int hip_argmax_bf16(const __hip_bfloat16* data, int size) {
const int block_size = BLOCK_SIZE;
int grid_size = (size + block_size - 1) / block_size;
__hip_bfloat16* d_data;
int* d_block_results;
HIP_CHECK(hipMalloc(&d_data, size * sizeof(__hip_bfloat16)));
HIP_CHECK(hipMalloc(&d_block_results, grid_size * sizeof(int)));
HIP_CHECK(hipMemcpy(d_data, data, size * sizeof(__hip_bfloat16), hipMemcpyHostToDevice));
hipLaunchKernelGGL(argmax_bf16_kernel,
dim3(grid_size), dim3(block_size),
sizeof(float)*block_size + sizeof(int)*block_size, 0,
d_data, size, d_block_results);
std::vector<int> block_results(grid_size);
HIP_CHECK(hipMemcpy(block_results.data(), d_block_results,
grid_size * sizeof(int), hipMemcpyDeviceToHost));
// Final host-side reduction
int max_index = block_results[0];
float max_val = static_cast<float>(data[max_index]);
for (int i = 1; i < grid_size; i++) {
float current = static_cast<float>(data[block_results[i]]);
if (current > max_val) {
max_val = current;
max_index = block_results[i];
}
}
HIP_CHECK(hipFree(d_data));
HIP_CHECK(hipFree(d_block_results));
return max_index;
}
// Test/benchmark driver: fills a bf16 array with seeded random values,
// computes a CPU reference argmax, times repeated calls to hip_argmax_bf16,
// and checks that the GPU result points at the same maximum value.
// Returns EXIT_SUCCESS when the values match, EXIT_FAILURE otherwise.
int main() {
    const int N = 1 << 17;  // 131072 elements
    std::vector<__hip_bfloat16> data(N);

    // Generate random data (fixed seed for reproducibility)
    std::mt19937 gen(42);
    std::uniform_real_distribution<float> dist(-1000.0f, 1000.0f);
    for (auto& val : data) val = static_cast<__hip_bfloat16>(dist(gen));

    // CPU reference: first occurrence of the maximum
    int cpu_argmax = 0;
    float max_val = static_cast<float>(data[0]);
    for (int i = 1; i < N; ++i) {
        float current = static_cast<float>(data[i]);
        if (current > max_val) {
            max_val = current;
            cpu_argmax = i;
        }
    }

    // GPU implementation
    hipEvent_t start, stop;
    HIP_CHECK(hipEventCreate(&start));
    HIP_CHECK(hipEventCreate(&stop));

    // Warm-up
    int gpu_argmax = hip_argmax_bf16(data.data(), N);

    // Benchmark. The timed region covers whole hip_argmax_bf16 calls, i.e.
    // everything that function does, not only the kernel.
    const int trials = 100;
    HIP_CHECK(hipEventRecord(start));
    for (int i = 0; i < trials; i++) {
        gpu_argmax = hip_argmax_bf16(data.data(), N);
    }
    HIP_CHECK(hipEventRecord(stop));
    HIP_CHECK(hipEventSynchronize(stop));
    float milliseconds = 0.0f;
    HIP_CHECK(hipEventElapsedTime(&milliseconds, start, stop));

    // Release the events before any return path so they are not leaked when
    // validation fails.
    HIP_CHECK(hipEventDestroy(start));
    HIP_CHECK(hipEventDestroy(stop));

    // Verify: compare the values the two indices point at (equal maxima at
    // different positions are accepted). Reject an unusable index up front
    // instead of reading out of bounds.
    const bool index_ok = gpu_argmax >= 0 && gpu_argmax < N;
    if (!index_ok ||
        static_cast<float>(data[cpu_argmax]) != static_cast<float>(data[gpu_argmax])) {
        std::cerr << "Validation failed!" << std::endl;
        return EXIT_FAILURE;
    }
    std::cout << "Results match!" << std::endl;
    std::cout << "Average execution time: " << milliseconds/trials << " ms" << std::endl;
    return EXIT_SUCCESS;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment