Skip to content

Instantly share code, notes, and snippets.

@gmarkall
Last active June 9, 2023 16:45
Show Gist options
  • Save gmarkall/79226342a330a06ce81fe1096318909c to your computer and use it in GitHub Desktop.
Numba CUDA Warp-aggregated atomics example. See PR #6911: https://github.com/numba/numba/pull/6911
$ python wagg.py
Running with 16777216 elements, of which approximately 25.0% are zero
There are 12584753 nonzeroes in:
[0.417022 0.72032449 0. ... 0.20570723 0.36716537 0.0979951 ]
The kernel found 12584753 elements, resulting in the array:
[0.14349547 0.43006714 0.48695992 ... 0. 0. 0. ]
Traceback (most recent call last):
File "wagg.py", line 104, in <module>
np.testing.assert_equal(np.ones(value_count, dtype=np.bool), result[:cuda_n] > 0)
File "/home/gmarkall/miniconda3/envs/numba/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 342, in assert_equal
return assert_array_equal(actual, desired, err_msg, verbose)
File "/home/gmarkall/miniconda3/envs/numba/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 930, in assert_array_equal
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
File "/home/gmarkall/miniconda3/envs/numba/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 840, in assert_array_compare
raise AssertionError(msg)
AssertionError:
Arrays are not equal
Mismatched elements: 12060435 / 12584753 (95.8%)
x: array([ True, True, True, ..., True, True, True])
y: array([ True, True, True, ..., False, False, False])
from numba import cuda
import numpy as np
# From:
#
# https://developer.nvidia.com/blog/cuda-pro-tip-optimized-filtering-warp-aggregated-atomics/
# __device__ int atomicAggInc(int *ctr) {
# unsigned int active = __activemask();
# int leader = __ffs(active) - 1;
# int change = __popc(active);
# unsigned int rank = __popc(active & __lanemask_lt());
# int warp_res;
# if(rank == 0)
# warp_res = atomicAdd(ctr, change);
# warp_res = __shfl_sync(active, warp_res, leader);
# return warp_res + rank;
# }
@cuda.jit(device=True)
def atomicAggInc(ctr):
    """Warp-aggregated atomic increment of ``ctr[0]``.

    All currently-active lanes cooperate: the lowest-numbered active lane
    (the leader) performs one atomic add of the active-lane count, then the
    old counter value is broadcast to every active lane.  Each lane returns
    a unique index: old counter value + its rank among the active lanes.
    """
    active = cuda.activemask()
    # Lowest set bit of the mask is the leader lane; cuda.ffs is 1-based
    # like CUDA's __ffs, hence the -1.  NOTE(review): confirm the installed
    # Numba returns 1-based ffs — older releases differed (see PR #6911).
    leader = cuda.ffs(active) - 1
    # Number of active lanes = amount to add to the counter.
    change = cuda.popc(active)
    # This lane's rank among active lanes with a lower lane id.
    rank = cuda.popc(active & cuda.lanemask_lt())
    # Only the leader (rank 0) touches the counter; other lanes get a
    # placeholder so warp_res is defined on every path before the shuffle.
    if rank == 0:
        warp_res = cuda.atomic.add(ctr, 0, change)
    else:
        warp_res = 0
    # BUG FIX: the shuffle must be executed unconditionally by every lane
    # in `active`, exactly as in the C reference above — shfl_sync requires
    # all lanes named in the mask to participate.  Previously the leader
    # skipped it and non-leaders shuffled an undefined value, producing the
    # mismatched results shown in the captured output.
    warp_res = cuda.shfl_sync(active, warp_res, leader)
    return warp_res + rank
# There is a bug in the example C code, nres should be passed in:
#
# __global__ void filter_k(int *dst, const int *src, int n) {
# int i = threadIdx.x + blockIdx.x * blockDim.x;
# if(i >= n)
# return;
# if(src[i] > 0)
# dst[atomicAggInc(nres)] = src[i];
# }
@cuda.jit
def filter_k(dst, nres, src):
    """Compact the positive elements of ``src`` into ``dst``.

    ``nres[0]`` accumulates the number of elements copied; it must start
    at zero.  Output order within ``dst`` is nondeterministic.
    """
    idx = cuda.grid(1)
    # Guard against the partial block at the end of the grid.
    if idx >= src.shape[0]:
        return
    value = src[idx]
    if value > 0:
        # Each passing thread claims a unique output slot.
        dst[atomicAggInc(nres)] = value
# Parameters for the run
N = 2 ** 24
zero_factor = 0.25
print(f'Running with {N} elements, of which approximately {zero_factor * 100}%'
      ' are zero\n')

# Seed the RNG for repeatability
np.random.seed(1)

# Create input data: each element is zeroed with probability ~zero_factor.
inputs = np.random.random(N)
zeros = np.zeros(N)
factors = np.random.random(N)
values = np.where(factors > zero_factor, inputs, zeros)

# Quick summary of the data
value_count = np.sum(values > 0)
print(f'There are {value_count} nonzeroes in:')
print(values)
print()

# Create outputs for the kernel: nres[0] is the running count of filtered
# elements; result receives them (in nondeterministic order).
nres = np.zeros(1, dtype=np.uint32)
result = np.zeros_like(values)

# Compute grid dimensions and launch the kernel.  Ceiling division so the
# grid covers all N elements even when n_threads does not divide N exactly
# (the kernel bounds-checks, so any extra threads are harmless).
n_threads = 128
n_blocks = (N + n_threads - 1) // n_threads
filter_k[n_blocks, n_threads](result, nres, values)

# Summarize the kernel output
cuda_n = nres[0]
print(f'The kernel found {cuda_n} elements, resulting in the array:')
print(result)
print()

# Some sanity checking:
# Did we filter the expected number of values?
np.testing.assert_equal(value_count, nres[0])
# Are the first cuda_n elements all nonzero?
# BUG FIX: np.bool was deprecated in NumPy 1.20 and removed in 1.24;
# the builtin bool is the documented replacement.
np.testing.assert_equal(np.ones(value_count, dtype=bool), result[:cuda_n] > 0)
# Were elements after the cuda_nth element left as zero?
np.testing.assert_equal(np.zeros(N - value_count), result[cuda_n:])
print('Sanity checks passed!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment