Created
February 16, 2024 00:23
-
-
Save chengscott/d97ca744c72c8b8cd27abbecbe3d48f7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
// Inclusive running sum of `val` across the lanes of the calling warp:
// lane k receives the sum of the values passed in by lanes 0..k.
//
// Each step pulls the partial sum from the lane `delta` positions lower via
// __shfl_up_sync and adds it in, doubling `delta` (1, 2, 4, 8, 16).
// The shuffle uses the full mask 0xffffffff, so all 32 lanes of the warp must
// call this function together. The lane index is taken from threadIdx.x, so a
// 1-D block is assumed.
__device__ int warpInclusiveScan(int val) {
  const int lane = threadIdx.x % warpSize;
#pragma unroll
  for (int delta = 1; delta < 32; delta *= 2) {
    const int fromBelow = __shfl_up_sync(0xffffffff, val, delta);
    // Lanes below `delta` have no source lane; they keep their value.
    val = (lane >= delta) ? val + fromBelow : val;
  }
  return val;
}
// Single-block inclusive prefix sum: output[i] = input[0] + ... + input[i]
// for every 0 <= i < n.
//
// Launch requirements:
//   * one block (gridDim.x == 1); any additional blocks would redundantly
//     recompute and rewrite the same results;
//   * a 1-D block whose size is a multiple of warpSize and at most 1024
//     threads, so every warp is full (warpInclusiveScan shuffles with the
//     full-lane mask) and all per-warp totals fit in warpSums[32] and can be
//     scanned by warp 0.
//
// n may be larger than blockDim.x: the input is processed in consecutive
// tiles of blockDim.x elements, and a running carry adds the totals of all
// earlier tiles. (Previously, elements at index >= blockDim.x were silently
// never written.)
__global__ void inclusivePrefixSumKernel(const int *input, int *output, int n) {
  __shared__ int warpSums[32];  // per-warp totals, then their inclusive scan

  const int tx = threadIdx.x;
  const int warpId = tx / warpSize;
  const int laneId = tx % warpSize;
  const int numWarps = blockDim.x / warpSize;

  int carry = 0;  // sum of all elements in tiles already processed
  // `base` and `n` are the same for every thread, so all threads run the same
  // number of iterations and reach every __syncthreads() below.
  for (int base = 0; base < n; base += blockDim.x) {
    const int idx = base + tx;

    // intra-warp: scan this thread's element within its warp
    int v = warpInclusiveScan(idx < n ? input[idx] : 0);
    if (laneId == warpSize - 1) {
      warpSums[warpId] = v;  // the last lane holds the warp's total
    }
    __syncthreads();

    // inter-warp: warp 0 replaces the warp totals with their inclusive scan
    if (warpId == 0) {
      const bool isActive = laneId < numWarps;
      const int warpSum = warpInclusiveScan(isActive ? warpSums[laneId] : 0);
      if (isActive) {
        warpSums[laneId] = warpSum;
      }
    }
    __syncthreads();

    // combine: add the totals of earlier warps in this tile and of earlier tiles
    if (warpId > 0) {
      v += warpSums[warpId - 1];
    }
    if (idx < n) {
      output[idx] = carry + v;
    }
    carry += warpSums[numWarps - 1];  // inclusive scan of the last warp = tile total

    // All threads must finish reading warpSums before the next tile
    // overwrites it.
    __syncthreads();
  }
}
// Reports a failed CUDA runtime call on stderr. Returns true iff err is
// cudaSuccess, so calls can be chained with && and stop at the first failure.
static bool cudaOk(cudaError_t err, const char *what) {
  if (err == cudaSuccess) return true;
  std::cerr << "CUDA error in " << what << ": " << cudaGetErrorString(err)
            << std::endl;
  return false;
}

// Scans an array of n ones on the GPU with a single block of n threads, prints
// the result, and checks it against a host-side running sum.
// Returns 0 on success, 1 on any CUDA error or result mismatch.
int main() {
  const int n = 1024;  // one thread per element, launched as a single block
  const size_t bytes = n * sizeof(int);
  int h_input[n], h_output[n];

  // Initialize input array
  for (int i = 0; i < n; ++i) {
    h_input[i] = 1;
  }

  int *d_input = nullptr, *d_output = nullptr;
  // && short-circuits, so nothing runs after the first failing call.
  bool ok = cudaOk(cudaMalloc(&d_input, bytes), "cudaMalloc(d_input)") &&
            cudaOk(cudaMalloc(&d_output, bytes), "cudaMalloc(d_output)") &&
            cudaOk(cudaMemcpy(d_input, h_input, bytes, cudaMemcpyHostToDevice),
                   "cudaMemcpy(host->device)");
  if (ok) {
    // Execute kernel
    inclusivePrefixSumKernel<<<1, n>>>(d_input, d_output, n);
    // Launch errors are only visible via cudaGetLastError(); execution errors
    // surface at the blocking cudaMemcpy that follows.
    ok = cudaOk(cudaGetLastError(), "kernel launch") &&
         cudaOk(cudaMemcpy(h_output, d_output, bytes, cudaMemcpyDeviceToHost),
                "cudaMemcpy(device->host)");
  }
  // Both pointers start as nullptr, and cudaFree(nullptr) is a no-op, so this
  // is safe on every path.
  cudaFree(d_input);
  cudaFree(d_output);
  if (!ok) {
    return 1;
  }

  // Print the result
  for (int i = 0; i < n; ++i) {
    std::cout << h_output[i] << " ";
  }
  std::cout << std::endl;

  // Verify against a sequential running sum computed on the host.
  int expected = 0;
  for (int i = 0; i < n; ++i) {
    expected += h_input[i];
    if (h_output[i] != expected) {
      std::cerr << "Mismatch at index " << i << ": expected " << expected
                << ", got " << h_output[i] << std::endl;
      return 1;
    }
  }
  return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment