- Problem: Some blocks are scheduled later than others, so a block can read tile_rowmax before it holds the "true max value" it needs.
- Direction: We need a way to wait for all threads of all blocks before using the result.
- Solutions:
  - Split into 2 kernels (sketched after the full script below)
  - Use cooperative groups: https://numba.readthedocs.io/en/stable/cuda/cooperative_groups.html (a minimal sketch follows this list)
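A minimal sketch of the cooperative-groups route, mirroring the tile_rowmax pattern in the script below: every block contributes to a global max with cuda.atomic.max, the whole grid synchronizes, and only then is the max read. The names here (scale_by_global_max, gmax) are illustrative, and it assumes a device, driver and Numba version that support cooperative launches; per the linked docs, Numba handles the cooperative launch itself when the kernel uses grid sync.

import numpy as np
from numba import cuda

@cuda.jit
def scale_by_global_max(x, gmax, out):
    i = cuda.grid(1)
    g = cuda.cg.this_grid()                # handle to the whole grid
    if i < x.size:
        cuda.atomic.max(gmax, 0, x[i])     # every block contributes its value
    g.sync()                               # wait for ALL blocks, not just this block's threads
    if i < x.size:
        out[i] = x[i] / gmax[0]            # gmax is final for every block at this point

x = cuda.to_device(np.arange(1.0, 9.0))
gmax = cuda.to_device(np.zeros(1))
out = cuda.device_array_like(x)
scale_by_global_max[4, 2](x, gmax, out)    # 4 blocks x 2 threads
print(out.copy_to_host())                  # expected: [0.125 0.25 ... 1.]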
A run where the GPU results match the CPU reference (under cuda-memcheck):

python sandbox/race_condition.py
========= CUDA-MEMCHECK
========= This tool is deprecated and will be removed in a future release of the CUDA toolkit
========= Please use the compute-sanitizer tool as a drop-in replacement
=============== CPU ====================
---- sO ----
tensor([[0.3000, 0.3600, 0.4200],
        [0.6600, 0.8100, 0.9600],
        [1.0200, 1.2600, 1.5000]], dtype=torch.float64)
---- tile_rowmax ----
[0.42 0.96 1.5 ]
---- tile_numerator ----
[[0.72 1.32 1.92]
 [1.08 1.77 2.46]
 [1.44 2.22 3. ]]
=============== GPU ====================
blockspergrid: (3, 3), threadsperblock: (1, 1)
/home/f.mom/sync/.pyenv/versions/3.9.12/envs/env-flash-attention-numba/lib/python3.9/site-packages/numba/cuda/dispatcher.py:488: NumbaPerformanceWarning: Grid size 9 will likely result in GPU under-utilization due to low occupancy.
  warn(NumbaPerformanceWarning(msg))
---- tmp_O ----
[[0.3  0.36 0.42]
 [0.66 0.81 0.96]
 [1.02 1.26 1.5 ]]
---- tile_rowmax ----
[0.42 0.96 1.5 ]
---- tile_numerator ----
[[0.72 1.32 1.92]
 [1.08 1.77 2.46]
 [1.44 2.22 3. ]]
========= ERROR SUMMARY: 0 errors
A run that hits the race: tile_numerator is computed with a stale tile_rowmax[0] (0.36 instead of 0.42, i.e. read before every block had pushed its atomic.max), so its first column is off by 0.06 and the final assert fails:

python sandbox/race_condition.py
=============== CPU ====================
---- sO ----
tensor([[0.3000, 0.3600, 0.4200],
        [0.6600, 0.8100, 0.9600],
        [1.0200, 1.2600, 1.5000]], dtype=torch.float64)
---- tile_rowmax ----
[0.42 0.96 1.5 ]
---- tile_numerator ----
[[0.72 1.32 1.92]
 [1.08 1.77 2.46]
 [1.44 2.22 3. ]]
=============== GPU ====================
blockspergrid: (3, 3), threadsperblock: (1, 1)
/home/f.mom/sync/.pyenv/versions/3.9.12/envs/env-flash-attention-numba/lib/python3.9/site-packages/numba/cuda/dispatcher.py:488: NumbaPerformanceWarning: Grid size 9 will likely result in GPU under-utilization due to low occupancy.
  warn(NumbaPerformanceWarning(msg))
---- tmp_O ----
[[0.3  0.36 0.42]
 [0.66 0.81 0.96]
 [1.02 1.26 1.5 ]]
---- tile_rowmax ----
[0.42 0.96 1.5 ]
---- tile_numerator ----
[[0.66 1.32 1.92]
 [1.02 1.77 2.46]
 [1.38 2.22 3. ]]
Traceback (most recent call last):
  File "/mnt/nfs/home/f.mom/flash_attention_numba/sandbox/race_condition.py", line 128, in <module>
    assert np.allclose(h_tile_numerator, tile_numerator_cpu)
AssertionError
sandbox/race_condition.py (the version with the cooperative-groups fix, marked with "FIX"):
import random
import os
import numpy as np
import torch
from numba import cuda, float64
import math


def seed_everything():
    seed = 42
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def cpu(A, B):
    sO = A @ B
    print("---- sO ----")
    print(sO)
    tile_rowmax = torch.max(sO, dim=1).values
    tile_numerator = sO + tile_rowmax
    return tile_rowmax, tile_numerator


BLOCKSIZE = 1


@cuda.jit
def kernel(A, B, tile_rowmax, tile_numerator, tmp_O):
    sA = cuda.shared.array(shape=(BLOCKSIZE, BLOCKSIZE), dtype=float64)
    sB = cuda.shared.array(shape=(BLOCKSIZE, BLOCKSIZE), dtype=float64)
    sO = cuda.shared.array(shape=(BLOCKSIZE, BLOCKSIZE), dtype=float64)

    col = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
    row = cuda.threadIdx.y + cuda.blockIdx.y * cuda.blockDim.y
    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y

    g = cuda.cg.this_grid()  # <=== FIX

    for blockId in range(cuda.gridDim.x):
        # Load a tile of A and B into shared memory
        if row < A.shape[0] and tx + blockId * BLOCKSIZE < A.shape[1]:
            sA[ty, tx] = A[row, tx + blockId * BLOCKSIZE]
        if col < B.shape[1] and ty + blockId * BLOCKSIZE < B.shape[0]:
            sB[ty, tx] = B[ty + blockId * BLOCKSIZE, col]
        cuda.syncthreads()

        # Matmul on the current tile
        for k in range(BLOCKSIZE):
            sO[ty, tx] += sA[ty, k] * sB[k, tx]
        # Each block pushes its (partial) result into the per-row max
        cuda.atomic.max(tile_rowmax, row, sO[ty, tx])
        cuda.syncthreads()

    tmp_O[row, col] = sO[ty, tx]

    # Without this grid-wide barrier, a block can read tile_rowmax below
    # before other blocks have finished their atomic.max on it.
    g.sync()  # <=== FIX

    tile_numerator[row, col] = sO[ty, tx] + tile_rowmax[col]


if __name__ == "__main__":
    seed_everything()

    N = 3

    h_A = (torch.arange(N * N, dtype=torch.float64).reshape(N, N) + 1.) / 10
    h_B = (torch.arange(N * N, dtype=torch.float64).reshape(N, N) + 1.) / 10
    h_tile_rowmax = torch.zeros(N, dtype=torch.float64)
    h_tile_numerator = torch.zeros(N * N, dtype=torch.float64).reshape(N, N)

    d_A = cuda.to_device(h_A)
    d_B = cuda.to_device(h_B)
    d_tile_rowmax = cuda.to_device(h_tile_rowmax)
    d_tile_numerator = cuda.to_device(h_tile_numerator)

    print("=============== CPU ====================")
    tile_rowmax_cpu, tile_numerator_cpu = cpu(h_A.clone(), h_B.clone())
    tile_rowmax_cpu = tile_rowmax_cpu.numpy()
    tile_numerator_cpu = tile_numerator_cpu.numpy()
    print("---- tile_rowmax ----")
    print(tile_rowmax_cpu)
    print("---- tile_numerator ----")
    print(tile_numerator_cpu)

    print("=============== GPU ====================")
    threadsperblock = (BLOCKSIZE, BLOCKSIZE)
    blockspergrid_x = math.ceil(N / threadsperblock[0])
    blockspergrid_y = math.ceil(N / threadsperblock[1])
    blockspergrid = (blockspergrid_x, blockspergrid_y)
    print(f"blockspergrid: {blockspergrid}, threadsperblock: {threadsperblock}")

    h_tmp_O = torch.zeros(N * N, dtype=torch.float64).reshape(N, N)
    d_tmp_O = cuda.to_device(h_tmp_O)

    kernel[blockspergrid, threadsperblock](
        d_A,
        d_B,
        d_tile_rowmax,
        d_tile_numerator,
        d_tmp_O,
    )

    h_tmp_O = d_tmp_O.copy_to_host()
    h_tile_rowmax = d_tile_rowmax.copy_to_host()
    h_tile_numerator = d_tile_numerator.copy_to_host()

    print("---- tmp_O ----")
    print(h_tmp_O)
    print("---- tile_rowmax ----")
    print(h_tile_rowmax)
    print("---- tile_numerator ----")
    print(h_tile_numerator)

    assert np.allclose(h_tile_rowmax, tile_rowmax_cpu)
    assert np.allclose(h_tile_numerator, tile_numerator_cpu)
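For completeness, a sketch of the other option from the notes: split the work into two kernels. Launches on the same stream run in order, so by the time the second kernel starts, every block of the first has finished its atomic.max and tile_rowmax is final; no grid sync is needed. The kernel names and the un-tiled matmul are illustrative simplifications, not the gist's implementation.

import math
import numpy as np
from numba import cuda

N = 3

@cuda.jit
def matmul_and_rowmax(A, B, O, tile_rowmax):
    col, row = cuda.grid(2)
    if row < O.shape[0] and col < O.shape[1]:
        acc = 0.0
        for k in range(A.shape[1]):
            acc += A[row, k] * B[k, col]
        O[row, col] = acc
        cuda.atomic.max(tile_rowmax, row, acc)   # per-row max across all blocks

@cuda.jit
def add_rowmax(O, tile_rowmax, tile_numerator):
    col, row = cuda.grid(2)
    if row < O.shape[0] and col < O.shape[1]:
        # Same broadcast as the CPU reference: sO + tile_rowmax adds tile_rowmax[col]
        tile_numerator[row, col] = O[row, col] + tile_rowmax[col]

A = cuda.to_device((np.arange(N * N, dtype=np.float64).reshape(N, N) + 1.) / 10)
B = cuda.to_device((np.arange(N * N, dtype=np.float64).reshape(N, N) + 1.) / 10)
O = cuda.device_array((N, N), dtype=np.float64)
tile_rowmax = cuda.to_device(np.zeros(N, dtype=np.float64))
tile_numerator = cuda.device_array((N, N), dtype=np.float64)

threads = (1, 1)
blocks = (math.ceil(N / threads[0]), math.ceil(N / threads[1]))
matmul_and_rowmax[blocks, threads](A, B, O, tile_rowmax)
add_rowmax[blocks, threads](O, tile_rowmax, tile_numerator)   # ordered after kernel 1 on the same stream
print(tile_numerator.copy_to_host())   # should match tile_numerator_cpu above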
Dependencies:
-f https://download.pytorch.org/whl/cu117/torch_stable.html
torch==1.13.1+cu117
numba
pygount
pdbpp
pytest