Created
April 27, 2025 21:03
-
-
Save pashu123/a3480076c26def192c3bf3381bab0c13 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <hip/hip_runtime.h> | |
#include <hip/hip_bf16.h> | |
#include <hip/hip_cooperative_groups.h> | |
#include <iostream> | |
#include <vector> | |
#include <cstdlib> | |
#include <cstdint> | |
// Cooperative-groups namespace | |
namespace cg = cooperative_groups; | |
// Fixed vocabulary size | |
constexpr int VOCAB_SIZE = 131072; | |
// Kernel: compute argmax of each [VOCAB_SIZE]-length row of BF16s | |
// Kernel: compute the argmax of each VOCAB_SIZE-length row of BF16 values.
//
// Launch layout: one block per batch row (grid.x == batch_size); any
// blockDim.x up to 512 (the shared scratch below holds 8 warp results,
// and warps/wavefronts are 64 lanes wide on GFX9/CDNA hardware).
//
// input  : [batch_size * VOCAB_SIZE] BF16 values, row-major. Rows are read
//          two values at a time through a uint32_t alias, which requires
//          4-byte alignment — guaranteed by hipMalloc.
// output : [batch_size] int; output[b] = index of the maximum of row b.
__global__ void argmax_bf16(const __hip_bfloat16* input, int* output, int batch_size) {
    int batch = blockIdx.x;
    if (batch >= batch_size) return;

    // Pointer to this row (batch entry).
    const __hip_bfloat16* row = input + batch * VOCAB_SIZE;
    // Reinterpret as 32-bit words so each load fetches two BF16 values.
    const uint32_t* row32 = reinterpret_cast<const uint32_t*>(row);

    // Number of packed pairs per row (2 BF16 per 32-bit word).
    const int N = VOCAB_SIZE / 2;

    // Per-thread running maximum. Seeded with -infinity so a thread that
    // sees no data (blockDim.x > N) can never win the reduction.
    float local_max = __uint_as_float(0xFF800000u);  // -inf, bit pattern
    int local_idx = 0;

    // Block-stride scan over the packed pairs.
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        uint32_t packed = row32[i];
        // A BF16 value is exactly the upper 16 bits of the corresponding
        // float32, so decoding is a shift into the high half + bit-cast.
        // (NOT an integer-to-float conversion of the raw bits.)
        float v0 = __uint_as_float((packed & 0x0000FFFFu) << 16);
        float v1 = __uint_as_float(packed & 0xFFFF0000u);
        if (v0 > local_max) {
            local_max = v0;
            local_idx = 2 * i;
        }
        if (v1 > local_max) {
            local_max = v1;
            local_idx = 2 * i + 1;
        }
    }

    // Warp-level reduction using cooperative groups (64-lane tile, matching
    // the GFX942 wavefront width).
    auto block_group = cg::this_thread_block();
    auto tile64 = cg::tiled_partition<64>(block_group);

    float max_val_f = local_max;
    int max_idx = local_idx;
    // Shuffle-down max reduction across the 64-lane tile; value and index
    // travel together so the winning lane's index survives.
    for (int offset = 32; offset > 0; offset >>= 1) {
        float other = tile64.shfl_down(max_val_f, offset);
        int other_idx = tile64.shfl_down(max_idx, offset);
        if (other > max_val_f) {
            max_val_f = other;
            max_idx = other_idx;
        }
    }

    // One slot per warp's result; supports up to 8 warps (blockDim <= 512).
    __shared__ float shared_vals[8];
    __shared__ int shared_idxs[8];

    // Lane 0 of each tile publishes its warp-level maximum.
    unsigned int warp_id = tile64.meta_group_rank();  // warp index in block
    if (tile64.thread_rank() == 0) {
        shared_vals[warp_id] = max_val_f;
        shared_idxs[warp_id] = max_idx;
    }
    __syncthreads();

    // Thread 0 folds the per-warp results into the block (row) answer.
    if (threadIdx.x == 0) {
        int numWarps = (blockDim.x + 63) / 64;
        float block_max = shared_vals[0];
        int block_idx = shared_idxs[0];
        for (int w = 1; w < numWarps; ++w) {
            if (shared_vals[w] > block_max) {
                block_max = shared_vals[w];
                block_idx = shared_idxs[w];
            }
        }
        output[batch] = block_idx;
    }
}
// Host driver: benchmarks the argmax kernel for batch sizes 1..4 and
// sanity-checks each result against a CPU reference computed on the host.
int main() {
    // Minimal HIP error check; aborts on the first failing API call so a
    // broken launch cannot silently print garbage results.
    auto check = [](hipError_t err, const char* what) {
        if (err != hipSuccess) {
            std::cerr << what << " failed: " << hipGetErrorString(err) << "\n";
            std::exit(EXIT_FAILURE);
        }
    };

    // Test batch sizes 1 through 4.
    const int batch_sizes[] = {1, 2, 3, 4};

    // Timing events for kernel execution.
    hipEvent_t start, stop;
    check(hipEventCreate(&start), "hipEventCreate(start)");
    check(hipEventCreate(&stop), "hipEventCreate(stop)");

    for (int batch : batch_sizes) {
        size_t total_elems = static_cast<size_t>(batch) * VOCAB_SIZE;

        // Host input: uniform random floats in [0, 1] converted to BF16.
        // rand() is unseeded on purpose, so runs are reproducible.
        std::vector<__hip_bfloat16> h_input(total_elems);
        for (size_t i = 0; i < total_elems; ++i) {
            float val = static_cast<float>(rand()) / RAND_MAX;
            h_input[i] = __hip_bfloat16(val);
        }

        // Device buffers.
        __hip_bfloat16* d_input = nullptr;
        int* d_output = nullptr;
        check(hipMalloc(&d_input, total_elems * sizeof(__hip_bfloat16)),
              "hipMalloc(d_input)");
        check(hipMalloc(&d_output, batch * sizeof(int)), "hipMalloc(d_output)");
        check(hipMemcpy(d_input, h_input.data(),
                        total_elems * sizeof(__hip_bfloat16),
                        hipMemcpyHostToDevice),
              "hipMemcpy H2D");

        // Launch: one block per row, 256 threads (4 wavefronts of 64).
        dim3 block(256);
        dim3 grid(batch);
        check(hipEventRecord(start, 0), "hipEventRecord(start)");
        hipLaunchKernelGGL(argmax_bf16, grid, block, 0, 0,
                           d_input, d_output, batch);
        check(hipGetLastError(), "argmax_bf16 launch");
        check(hipEventRecord(stop, 0), "hipEventRecord(stop)");
        check(hipEventSynchronize(stop), "hipEventSynchronize");

        // Retrieve results (blocking copy also guarantees kernel completion).
        std::vector<int> h_output(batch);
        check(hipMemcpy(h_output.data(), d_output, batch * sizeof(int),
                        hipMemcpyDeviceToHost),
              "hipMemcpy D2H");

        float time_ms = 0.0f;
        check(hipEventElapsedTime(&time_ms, start, stop),
              "hipEventElapsedTime");

        // CPU reference. Ties may legitimately resolve to different indices
        // on the GPU, so compare the max *values*, not the indices.
        bool ok = true;
        for (int b = 0; b < batch; ++b) {
            const __hip_bfloat16* row =
                h_input.data() + static_cast<size_t>(b) * VOCAB_SIZE;
            float ref_max = static_cast<float>(row[0]);
            for (int i = 1; i < VOCAB_SIZE; ++i) {
                float v = static_cast<float>(row[i]);
                if (v > ref_max) ref_max = v;
            }
            int idx = h_output[b];
            if (idx < 0 || idx >= VOCAB_SIZE ||
                static_cast<float>(row[idx]) != ref_max) {
                ok = false;
            }
        }

        std::cout << "Batch size " << batch << ": Argmax indices =";
        for (int i = 0; i < batch; ++i) {
            std::cout << " " << h_output[i];
        }
        std::cout << "; Kernel time = " << time_ms << " ms"
                  << (ok ? "" : "  [MISMATCH vs CPU reference]") << "\n";

        check(hipFree(d_input), "hipFree(d_input)");
        check(hipFree(d_output), "hipFree(d_output)");
    }

    check(hipEventDestroy(start), "hipEventDestroy(start)");
    check(hipEventDestroy(stop), "hipEventDestroy(stop)");
    return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment