@eqy
Last active July 5, 2025 05:04
softmax.cu
# PyTorch baseline: row-wise softmax over a 4096 x 65536 CUDA tensor
import torch
inp = torch.rand(4096, 65536, device='cuda')
for _ in range(10):
    torch.softmax(inp, dim=1)
#include <cuda_runtime.h>
#include <math_constants.h>
#include <curand_kernel.h>
#include <stdio.h>
// Numerically stable softmax kernel (three-pass)
extern "C" __global__ void softmax_kernel(
const float* input,
float* output,
int batch_size,
int num_features)
{
int batch_idx = blockIdx.x;
int tid = threadIdx.x;
if (batch_idx >= batch_size) return;
const float* input_row = input + batch_idx * num_features;
float* output_row = output + batch_idx * num_features;
// Shared memory for reduction operations
extern __shared__ float shared_mem[];
float* max_vals = shared_mem;
float* sum_vals = shared_mem + blockDim.x;
// Pass 1: find maximum for numerical stability
float thread_max = -CUDART_INF_F;
for (int i = tid; i < num_features; i += blockDim.x) {
thread_max = fmaxf(thread_max, input_row[i]);
}
max_vals[tid] = thread_max;
__syncthreads();
// Reduce to global max
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
max_vals[tid] = fmaxf(max_vals[tid], max_vals[tid + stride]);
}
__syncthreads();
}
float row_max = max_vals[0];
__syncthreads();
// Pass 2: compute exponentials and local sum
float thread_sum = 0.0f;
for (int i = tid; i < num_features; i += blockDim.x) {
float exp_val = expf(input_row[i] - row_max);
output_row[i] = exp_val;
thread_sum += exp_val;
}
sum_vals[tid] = thread_sum;
__syncthreads();
// Reduce to global sum
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
sum_vals[tid] += sum_vals[tid + stride];
}
__syncthreads();
}
float row_sum = sum_vals[0];
__syncthreads();
// Pass 3: normalise
for (int i = tid; i < num_features; i += blockDim.x) {
output_row[i] /= row_sum;
}
}
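// (Added sketch, not part of the original gist) Host-side reference for
// spot-checking one row of the kernel output; `check_row` is a hypothetical
// helper, e.g. called after cudaMemcpy'ing the first input and output rows
// back to the host.
#include <math.h>
void check_row(const float* h_in, const float* h_out, int num_features)
{
    // Same max-subtraction trick as the kernel, accumulated in double precision.
    float m = h_in[0];
    for (int i = 1; i < num_features; i++) m = fmaxf(m, h_in[i]);
    double sum = 0.0;
    for (int i = 0; i < num_features; i++) sum += exp((double)h_in[i] - m);
    double max_err = 0.0;
    for (int i = 0; i < num_features; i++) {
        double ref = exp((double)h_in[i] - m) / sum;
        double err = fabs(ref - (double)h_out[i]);
        if (err > max_err) max_err = err;
    }
    printf("max abs error vs. CPU reference: %g\n", max_err);
}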
__global__ void fill_rand(float* out, size_t n, unsigned long long seed = 1234)
{
    const size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    // one RNG state per thread, init cheaply
    curandStatePhilox4_32_10_t state;
    curand_init(seed, /*subsequence=*/idx, /*offset=*/0, &state);
    out[idx] = curand_uniform(&state); // in (0,1]
}
// host helper: launch & sync
void gen_random(float* d_out, size_t n)
{
    const int threads = 256;
    const int blocks = (int)((n + threads - 1) / threads);
    fill_rand<<<blocks, threads>>>(d_out, n);
    cudaDeviceSynchronize();
}
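// (Added sketch, not part of the original gist) Alternative: the cuRAND host
// API can fill the whole buffer in one call instead of initializing a
// per-thread device generator. It requires linking with -lcurand, so it is
// only compiled when USE_CURAND_HOST_API is defined.
#ifdef USE_CURAND_HOST_API
#include <curand.h>
void gen_random_host_api(float* d_out, size_t n)
{
    curandGenerator_t gen;
    curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_PHILOX4_32_10);
    curandSetPseudoRandomGeneratorSeed(gen, 1234ULL);
    curandGenerateUniform(gen, d_out, n); // values in (0,1]
    curandDestroyGenerator(gen);
}
#endif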
int main()
{
    float* input;
    float* output;
    size_t batch_size = 4096;
    size_t num_features = 65536;
    size_t n = batch_size * num_features;
    size_t threads = 1024;
    cudaMalloc(&input, n * sizeof(float));
    cudaMalloc(&output, n * sizeof(float));
    gen_random(input, n);
    cudaDeviceSynchronize();
    printf("CUDA Error: %s\n", cudaGetErrorString(cudaGetLastError()));
    for (int i = 0; i < 10; i++)
        softmax_kernel<<<batch_size, threads, 2 * threads * sizeof(float)>>>(input, output, batch_size, num_features);
    cudaDeviceSynchronize();
    printf("CUDA Error: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(input);
    cudaFree(output);
    return 0;
}
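// Build/run sketch (assumption, not stated in the gist):
//   nvcc -O3 softmax.cu -o softmax && ./softmax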