@pashu123 · Created April 27, 2025 21:03
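// Build note (assumed, not part of the original gist): something like
//   hipcc --offload-arch=gfx942 argmax_bf16.cpp -o argmax_bf16
// should work on a recent ROCm installation (the file name and flags here are
// illustrative). The reduction below assumes a 64-wide wavefront (CDNA targets
// such as GFX942); it is not written for wave32 targets.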
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_cooperative_groups.h>
#include <cmath>     // INFINITY, used as the argmax sentinel
#include <iostream>
#include <vector>
#include <cstdlib>
#include <cstdint>
// Cooperative-groups namespace
namespace cg = cooperative_groups;
// Fixed vocabulary size
constexpr int VOCAB_SIZE = 131072;
// Kernel: compute argmax of each [VOCAB_SIZE]-length row of BF16s
__global__ void argmax_bf16(const __hip_bfloat16* input, int* output, int batch_size) {
  int batch = blockIdx.x;
  if (batch >= batch_size) return;

  // Pointer to this row (batch entry)
  const __hip_bfloat16* row = input + batch * VOCAB_SIZE;
  // Reinterpret as 32-bit words so each load fetches two packed BF16 values
  const uint32_t* row32 = reinterpret_cast<const uint32_t*>(row);

  int tid = threadIdx.x;
  // Number of 32-bit words (2 BF16 values per word)
  const int N = VOCAB_SIZE / 2;

  // Track the running maximum as float. A BF16 value is exactly the upper 16 bits
  // of the corresponding float32, so widening is a 16-bit left shift of the raw
  // bits (a numeric conversion constructor must not be used here).
  float local_max = -INFINITY;
  int local_idx = 0;

  // Each thread processes a strided sequence of packed loads
  for (int i = tid; i < N; i += blockDim.x) {
    uint32_t packed = row32[i];
    float v0 = __uint_as_float((packed & 0xFFFFu) << 16);  // low half  -> element 2*i
    float v1 = __uint_as_float(packed & 0xFFFF0000u);      // high half -> element 2*i + 1
    if (v0 > local_max) {
      local_max = v0;
      local_idx = 2 * i;
    }
    if (v1 > local_max) {
      local_max = v1;
      local_idx = 2 * i + 1;
    }
  }

  // Wavefront-level reduction using cooperative groups (64-thread tile for GFX942)
  auto block_group = cg::this_thread_block();
  auto tile64 = cg::tiled_partition<64>(block_group);

  float max_val_f = local_max;
  int max_idx = local_idx;

  // Wavefront-wide shuffle reduction (shfl_down) for the (max, index) pair
  for (int offset = 32; offset > 0; offset >>= 1) {
    float other = tile64.shfl_down(max_val_f, offset);
    int other_idx = tile64.shfl_down(max_idx, offset);
    if (other > max_val_f) {
      max_val_f = other;
      max_idx = other_idx;
    }
  }

  // Shared memory for each wavefront's partial result
  __shared__ float shared_vals[8];  // supports up to 8 wavefronts (blockDim <= 512)
  __shared__ int shared_idxs[8];

  // Lane 0 of each wavefront (tile) writes its partial max to shared memory
  unsigned int warp_id = tile64.meta_group_rank();  // wavefront index within the block
  if (tile64.thread_rank() == 0) {
    shared_vals[warp_id] = max_val_f;
    shared_idxs[warp_id] = max_idx;
  }
  __syncthreads();

  // Final reduction over wavefronts, done by thread 0 of the block
  if (threadIdx.x == 0) {
    int numWarps = (blockDim.x + 63) / 64;
    float block_max = shared_vals[0];
    int block_idx = shared_idxs[0];
    for (int w = 1; w < numWarps; ++w) {
      float val = shared_vals[w];
      if (val > block_max) {
        block_max = val;
        block_idx = shared_idxs[w];
      }
    }
    output[batch] = block_idx;
  }
}
int main() {
  // Test batch sizes 1 through 4
  const int batch_sizes[] = {1, 2, 3, 4};

  // Timing events for kernel execution
  hipEvent_t start, stop;
  hipEventCreate(&start);
  hipEventCreate(&stop);

  for (int bs : batch_sizes) {
    int batch = bs;
    size_t total_elems = static_cast<size_t>(batch) * VOCAB_SIZE;

    // Allocate and initialize host input (random floats converted to BF16)
    std::vector<__hip_bfloat16> h_input(total_elems);
    for (size_t i = 0; i < total_elems; ++i) {
      float val = static_cast<float>(rand()) / RAND_MAX;
      h_input[i] = __hip_bfloat16(val);
    }

    // Allocate device memory
    __hip_bfloat16* d_input = nullptr;
    int* d_output = nullptr;
    hipMalloc(&d_input, total_elems * sizeof(__hip_bfloat16));
    hipMalloc(&d_output, batch * sizeof(int));

    // Copy input BF16 data to the device
    hipMemcpy(d_input, h_input.data(),
              total_elems * sizeof(__hip_bfloat16),
              hipMemcpyHostToDevice);

    // Launch kernel: one block per row, 256 threads (4 wave64 wavefronts)
    dim3 block(256);
    dim3 grid(batch);
    hipEventRecord(start, 0);
    hipLaunchKernelGGL(argmax_bf16, grid, block, 0, 0,
                       d_input, d_output, batch);
    hipEventRecord(stop, 0);
    hipEventSynchronize(stop);

    // Retrieve and print results
    std::vector<int> h_output(batch);
    hipMemcpy(h_output.data(), d_output,
              batch * sizeof(int),
              hipMemcpyDeviceToHost);

    float time_ms;
    hipEventElapsedTime(&time_ms, start, stop);
    std::cout << "Batch size " << batch << ": Argmax indices =";
    for (int i = 0; i < batch; ++i) {
      std::cout << " " << h_output[i];
    }
    std::cout << "; Kernel time = " << time_ms << " ms\n";
    hipFree(d_input);
    hipFree(d_output);
  }

  hipEventDestroy(start);
  hipEventDestroy(stop);
  return 0;
}