Last active
June 9, 2023 13:19
-
-
Save airalcorn2/bba918c17b36442ec67ecfb5126150f9 to your computer and use it in GitHub Desktop.
A minimal example showing a race condition in a Numba CUDA kernel.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import os | |
import torch | |
from numba import cuda, int32, njit | |
@njit((int32[:, :], int32[:, :], int32[:, :], int32[:], int32[:, :], int32[:]))
def loop(keys, key_mask, key2idx, cur_idx, idx2key, idx_counts):
    # Sequential CPU reference: assign each unique (row, col) key a dense
    # index the first time it is seen, then tally occurrences per index.
    for r, c in keys:
        cell = (r, c)
        if not key_mask[cell]:
            # First sighting of this key: mark it and claim the next index.
            key_mask[cell] = 1
            new_idx = cur_idx[0]
            cur_idx[0] += 1
            idx2key[new_idx] = cell
            key2idx[cell] = new_idx
        idx_counts[key2idx[cell]] += 1
@cuda.jit(
    (int32[:, :], int32[:, :], int32[:, :], int32[:], int32[:, :], int32[:], int32)
)
def race_condition(keys, key_mask, key2idx, cur_idx, idx2key, idx_counts, use_sync):
    # One thread per key. Deliberately demonstrates a race: the thread that
    # wins the CAS assigns key2idx[key], while other threads holding the same
    # key may read key2idx[key] before that write lands.
    pos = cuda.grid(1)
    if pos >= len(keys):
        return
    (row, col) = keys[pos]
    key = (row, col)
    # Atomic compare-and-swap returns the PREVIOUS value of key_mask[key],
    # so exactly one thread per key observes 0 and performs the assignment.
    has_idx = cuda.atomic.cas(key_mask, key, 0, 1)
    if not has_idx:
        # Reserve the next dense index and publish the key <-> idx mapping.
        idx = cuda.atomic.add(cur_idx, 0, 1)
        idx2key[idx] = key
        key2idx[key] = idx
    if use_sync:
        # Grid-wide barrier (cooperative groups) so every key2idx write above
        # is visible before any thread reads the mapping below.
        cuda.cg.this_grid().sync()
    # idx can be -1 because of race condition.
    idx = key2idx[key]
    cuda.atomic.add(idx_counts, idx, 1)
def main():
    """Demonstrate the kernel race.

    Computes ground-truth per-index counts with the sequential CPU
    reference, then launches the CUDA kernel once without and once with a
    grid-wide sync, printing how many sorted counts disagree with the
    CPU result for each launch.
    """
    N = 70000
    grid_size = 10
    keys = np.random.randint(grid_size, size=(N, 2), dtype="int32")
    key_mask = np.zeros((grid_size, grid_size), dtype="int32")
    key2idx = np.full((grid_size, grid_size), -1, dtype="int32")
    idx2key = np.full((grid_size * grid_size, 2), -1, dtype="int32")
    cur_idx = np.zeros(1, dtype="int32")
    idx_counts = np.zeros(idx2key.shape[0], dtype="int32")

    # Sequential CPU reference produces the ground-truth counts.
    loop(keys, key_mask, key2idx, cur_idx, idx2key, idx_counts)
    true_idx_counts = idx_counts.copy()
    true_idx_counts.sort()

    keys_gpu = torch.IntTensor(keys).cuda()
    threadsperblock = 256
    blockspergrid = len(keys_gpu) // threadsperblock + 1
    for use_sync in [0, 1]:
        # BUG FIX: re-create EVERY buffer for each launch. The original code
        # reset only idx_counts, so the second launch reused the fully
        # populated key_mask/key2idx from the first launch — the CAS branch
        # could never be taken again, no race could occur, and the
        # use_sync=1 result said nothing about the grid sync.
        key_mask_gpu = torch.zeros((grid_size, grid_size), dtype=torch.int32).cuda()
        key2idx_gpu = torch.full((grid_size, grid_size), -1, dtype=torch.int32).cuda()
        idx2key_gpu = torch.full((grid_size * grid_size, 2), -1, dtype=torch.int32).cuda()
        cur_idx_gpu = torch.zeros(1, dtype=torch.int32).cuda()
        idx_counts_gpu = torch.zeros(idx2key.shape[0], dtype=torch.int32).cuda()
        race_condition[blockspergrid, threadsperblock](
            keys_gpu,
            key_mask_gpu,
            key2idx_gpu,
            cur_idx_gpu,
            idx2key_gpu,
            idx_counts_gpu,
            use_sync,
        )
        # Sorted comparison: index ASSIGNMENT order may legitimately differ
        # between CPU and GPU; only the multiset of counts must match.
        cuda_idx_counts = idx_counts_gpu.cpu().numpy()
        cuda_idx_counts.sort()
        print(f"use_sync: {bool(use_sync)}")
        print(f"Mismatches: {(true_idx_counts != cuda_idx_counts).sum()}")
# Script entry point.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment