@noklam
Created March 25, 2025 13:10
Traverse a linked list in O(log N) with CUDA
# Source: https://www.linkedin.com/posts/yidewang_traverse-a-linked-list-of-n-nodes-using-cuda-activity-7310193902676811778-xIID
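# Pointer jumping ("pointer doubling"): in each round every node replaces its
# successor pointer with its successor's successor, halving the remaining
# distance to the tail, so the traversal converges in O(log N) parallel steps.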
import torch
import triton
import triton.language as tl


@triton.jit
def pointer_jump_kernel(chum_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pid < n
    # Load this node's current successor (-1 marks the end of the list).
    my_chum = tl.load(chum_ptr + pid, mask=mask, other=-1)
    valid_mask = (my_chum != -1) & mask
    # Load the successor's successor for nodes that still have one.
    chums_chum = tl.full([BLOCK_SIZE], -1, dtype=tl.int32)
    chums_chum = tl.where(valid_mask, tl.load(chum_ptr + my_chum, mask=valid_mask, other=-1), chums_chum)
    # Jump: point directly at the successor's successor, skipping past-the-end
    # pointers and self-loops.
    update_mask = valid_mask & (chums_chum != -1) & (chums_chum != my_chum)
    my_chum = tl.where(update_mask, chums_chum, my_chum)
    tl.store(chum_ptr + pid, my_chum, mask=mask)

def find_end_of_list(next_list):
    n = len(next_list)
    chum_tensor = torch.tensor(next_list, dtype=torch.int32, device="cuda")
    print("initial chum_tensor", chum_tensor)
    BLOCK_SIZE = 32
    grid = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
    # ceil(log2(n)) rounds suffice: each round at least halves the distance
    # from any node to the end of the list.
    max_iterations = int(torch.log2(torch.tensor(float(n))).item()) + 1
    for i in range(max_iterations):
        pointer_jump_kernel[(grid,)](chum_tensor, n, BLOCK_SIZE=BLOCK_SIZE)
        print(f"After iteration {i}:", chum_tensor)
    return chum_tensor.cpu().numpy()

if __name__ == "__main__":
    next_list = [1, 2, 3, 4, 5, 6, 7, -1]
    end_list = find_end_of_list(next_list)
    print("Node\tEnd of list from this node")
    for i, end in enumerate(end_list):
        print(f"{i}\t{end}")