Skip to content

Instantly share code, notes, and snippets.

@davidberard98
Created July 28, 2023 20:59
Show Gist options
  • Save davidberard98/c0cc39f3a2324936abbfe5d8c98eba48 to your computer and use it in GitHub Desktop.
Save davidberard98/c0cc39f3a2324936abbfe5d8c98eba48 to your computer and use it in GitHub Desktop.
import torch
import triton
import triton.language as tl
@triton.jit
def dense_to_jagged_triton(
    in_ptr,
    offsets_ptr,
    inverse_offsets_ptr,
    out_ptr,
    JAGGED_TOTAL_LEN,
    MAX_SEQ_LEN,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy rows of a padded dense buffer into a packed (jagged) buffer.

    Each program instance handles ``BLOCK_SIZE`` consecutive flat output
    elements. The row width is hard-coded to 256 elements throughout.

    Args:
        in_ptr: Dense input; addressed as
            ``batch * 256 * MAX_SEQ_LEN + seq * 256 + col``.
        offsets_ptr: Per-batch start row in the jagged output; read at index
            ``batch``.
        inverse_offsets_ptr: For every jagged row, the batch index it belongs
            to; read at index ``row``.
        out_ptr: Jagged output; addressed as ``row * 256 + col``.
        JAGGED_TOTAL_LEN: Number of jagged rows; rows at or beyond this are
            masked out of every load and store.
        MAX_SEQ_LEN: Sequence length of the dense input; positions at or
            beyond it are written as 0.0 instead of being read.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Flat output element indices owned by this program instance.
    flat_idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col = flat_idx % 256
    row = flat_idx // 256
    in_bounds = row < JAGGED_TOTAL_LEN

    # Which batch does this jagged row belong to, and where does that batch
    # start in the jagged layout?
    batch_idx = tl.load(inverse_offsets_ptr + row, mask=in_bounds)
    batch_start = tl.load(offsets_ptr + batch_idx, mask=in_bounds)

    # Position of the row inside its own sequence.
    seq_idx = row - batch_start
    fits_dense = seq_idx < MAX_SEQ_LEN

    gathered = tl.load(
        in_ptr + col + seq_idx * 256 + batch_idx * 256 * MAX_SEQ_LEN,
        mask=in_bounds & fits_dense,
    )
    # Lanes that were not loaded carry no defined value; force them to zero.
    result = tl.where(fits_dense, gathered, 0.0)
    tl.store(out_ptr + col + row * 256, result, mask=in_bounds)
def dense_to_jagged(
    dense,
    offsets,
    inverse_offsets,
    jagged_total_length,
    print_ptx=False,
):
    """Pack a padded dense tensor into a jagged tensor with the Triton kernel.

    Args:
        dense: 3-D tensor whose last dimension is 256; ``dense.shape[1]`` is
            passed to the kernel as the maximum sequence length.
        offsets: Per-batch start rows in the jagged output (indexed by batch).
        inverse_offsets: For every jagged row, the batch index it belongs to.
        jagged_total_length: Number of rows in the jagged output.
        print_ptx: When true, print the available compilation stages plus the
            TTGIR and PTX of the compiled kernel.

    Returns:
        A ``(jagged_total_length, 256)`` tensor with the dtype and device of
        ``dense``.

    Raises:
        AssertionError: If ``dense`` is not 3-D or its last dim is not 256.
    """
    # The kernel indexes `dense` as batch * 256 * max_seq_len + seq * 256 + col,
    # which is only meaningful for a 3-D input.
    assert dense.dim() == 3, "dense must be 3-D (batch, max_seq_len, 256)"
    assert dense.shape[-1] == 256, "kernel is hard-coded for a last dim of 256"

    # The kernel works on raw element offsets, so strided views would be read
    # incorrectly. `.contiguous()` returns the same tensor when already packed.
    dense = dense.contiguous()
    offsets = offsets.contiguous()
    inverse_offsets = inverse_offsets.contiguous()

    output = torch.empty(
        (jagged_total_length, dense.shape[-1]),
        dtype=dense.dtype,
        device=dense.device,
    )

    BLOCK_SIZE = 1024
    num_warps = 4
    # One program per BLOCK_SIZE flat output elements.
    grid = (triton.cdiv(output.numel(), BLOCK_SIZE),)
    res = dense_to_jagged_triton[grid](
        dense,
        offsets,
        inverse_offsets,
        output,
        output.shape[0],
        dense.shape[1],
        BLOCK_SIZE,
        num_warps=num_warps,
    )

    if print_ptx:
        # `res.asm` maps stage names to their text; "ttir" and "llir" can be
        # printed the same way when needed.
        print(res.asm.keys())
        print(res.asm["ttgir"])
        print(res.asm["ptx"])
    return output
def generate_offsets(
    batch_size: int,
    max_seq_len: int,
    load_factor: float,
    offsets_dtype: torch.dtype,
    spread_radius: float,
) -> torch.Tensor:
    """Build random jagged offsets for ``batch_size`` sequences.

    When ``load_factor < 1`` every length is drawn uniformly from
    ``mean - spread .. mean + spread`` (inclusive), where
    ``mean = int(max_seq_len * load_factor)`` and
    ``spread = int(max_seq_len * spread_radius)``, then clamped to
    ``[0, max_seq_len]``. When ``load_factor == 1`` every sequence has length
    ``max_seq_len``.

    Args:
        batch_size: Number of sequences.
        max_seq_len: Upper bound for every sequence length.
        load_factor: Mean fill ratio, in ``[0, 1]``.
        offsets_dtype: dtype of the returned tensor.
        spread_radius: Jitter half-width as a fraction of ``max_seq_len``,
            in ``[0, 1]``.

    Returns:
        A 1-D tensor of ``batch_size + 1`` cumulative offsets starting at 0.

    Raises:
        AssertionError: If ``load_factor`` or ``spread_radius`` is outside
            ``[0, 1]``.
    """
    import random
    from itertools import accumulate

    assert 0 <= load_factor <= 1
    assert 0 <= spread_radius <= 1

    if load_factor < 1:
        spread = int(max_seq_len * spread_radius)
        mean = int(max_seq_len * load_factor)
        # random.randint includes BOTH endpoints, so the symmetric jitter
        # range is (-spread, spread); an upper bound of spread + 1 would skew
        # lengths upward and add jitter even when spread is 0.
        lengths = [
            max(min(mean + random.randint(-spread, spread), max_seq_len), 0)
            for _ in range(batch_size)
        ]
    else:
        lengths = [max_seq_len] * batch_size

    # Exclusive prefix sum: offsets[i] is where sequence i starts.
    offsets = list(accumulate(lengths, initial=0))
    return torch.tensor(offsets, dtype=offsets_dtype)
BATCH_SIZE = 1024
MAX_SEQ_LEN = 260
EMBEDDING_DIM = 256

# Padded dense input living on the GPU.
dense = torch.rand(
    (BATCH_SIZE, MAX_SEQ_LEN, EMBEDDING_DIM), device='cuda', dtype=torch.float16
)

# Offsets are generated on the CPU first so the bookkeeping below stays cheap.
offsets = generate_offsets(BATCH_SIZE, MAX_SEQ_LEN, 0.3, torch.int32, 0.1)
jagged_lengths = offsets[1:] - offsets[:-1]
jagged_total_length = offsets[-1].item()

# inverse_offsets[r] is the batch index owning jagged row r, i.e. batch i
# repeated jagged_lengths[i] times. A single repeat_interleave replaces the
# per-element Python loop of scalar tensor writes.
inverse_offsets = torch.repeat_interleave(
    torch.arange(BATCH_SIZE, dtype=torch.int32),
    jagged_lengths.to(torch.int64),
)

offsets = offsets.to('cuda')
inverse_offsets = inverse_offsets.to('cuda')
def run_fn():
    """Run one dense-to-jagged conversion on the module-level benchmark data."""
    dense_to_jagged(
        dense=dense,
        offsets=offsets,
        inverse_offsets=inverse_offsets,
        jagged_total_length=jagged_total_length,
    )
# One un-timed call that also dumps the generated TTGIR and PTX.
dense_to_jagged(
    dense,
    offsets,
    inverse_offsets,
    jagged_total_length,
    print_ptx=True,
)

# One value is returned per requested quantile, in order. Note that min_ms and
# max_ms therefore hold the 0.2 and 0.8 quantiles, not the true extremes.
ms, min_ms, max_ms = triton.testing.do_bench(
    run_fn,
    quantiles=[0.5, 0.2, 0.8],
)
print(ms)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment