Created
July 28, 2023 20:59
-
-
Save davidberard98/c0cc39f3a2324936abbfe5d8c98eba48 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import triton | |
import triton.language as tl | |
@triton.jit
def dense_to_jagged_triton(
    in_ptr,
    offsets_ptr,
    inverse_offsets_ptr,
    out_ptr,
    JAGGED_TOTAL_LEN,
    MAX_SEQ_LEN,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather rows of a padded dense buffer into a packed (jagged) buffer.

    The output is addressed as ``JAGGED_TOTAL_LEN`` rows of 256 elements.
    For every output row the kernel looks up the owning batch entry in
    ``inverse_offsets_ptr``, reads that batch's first row from
    ``offsets_ptr``, and copies the matching row of the input, which is
    addressed as ``(batch, MAX_SEQ_LEN, 256)`` in row-major order.
    Rows whose position inside the sequence is not below ``MAX_SEQ_LEN``
    are written as zeros instead of being read.
    """
    # Each program instance covers BLOCK_SIZE consecutive flat output elements.
    block_start = tl.program_id(0) * BLOCK_SIZE
    flat_idx = block_start + tl.arange(0, BLOCK_SIZE)

    # Split the flat index into (row, column); the row width is fixed at 256.
    col = flat_idx % 256
    row = flat_idx // 256
    in_bounds = row < JAGGED_TOTAL_LEN

    # Which batch entry owns this output row, and where that entry starts.
    batch = tl.load(inverse_offsets_ptr + row, in_bounds)
    row_start = tl.load(offsets_ptr + batch, in_bounds)
    pos_in_seq = row - row_start
    fits_dense = pos_in_seq < MAX_SEQ_LEN

    src_ptrs = in_ptr + col + pos_in_seq * 256 + batch * 256 * MAX_SEQ_LEN
    loaded = tl.load(src_ptrs, in_bounds & fits_dense)
    # Lanes that did not fit in the dense buffer were never loaded; zero them.
    result = tl.where(fits_dense, loaded, 0.0)
    tl.store(out_ptr + col + row * 256, result, in_bounds)
def dense_to_jagged(
    dense,
    offsets,
    inverse_offsets,
    jagged_total_length,
    print_ptx=False,
):
    """Pack a padded dense tensor into a jagged ``(total_length, 256)`` tensor.

    Args:
        dense: Input tensor whose last dimension must be 256. The kernel
            reads it as ``(batch, dense.shape[1], 256)``.
        offsets: Per-batch start rows in the jagged output, passed straight
            to the kernel.
        inverse_offsets: For every jagged row, the index of the batch entry
            that owns it, passed straight to the kernel.
        jagged_total_length: Number of rows in the jagged output.
        print_ptx: When true, print the available compilation stages and
            the ``ttgir`` and ``ptx`` listings of the launched kernel.

    Returns:
        A new tensor of shape ``(jagged_total_length, 256)`` with the dtype
        and device of ``dense``.
    """
    assert dense.shape[-1] == 256
    # The kernel derives input addresses from hard-coded row-major strides,
    # so a strided view would be read incorrectly. ``contiguous()`` is a
    # no-op for tensors that are already laid out that way.
    dense = dense.contiguous()

    output = torch.empty(
        (jagged_total_length, dense.shape[-1]),
        dtype=dense.dtype,
        device=dense.device,
    )

    BLOCK_SIZE = 1024
    num_warps = 4
    # One program per BLOCK_SIZE flat output elements.
    grid = (triton.cdiv(output.numel(), BLOCK_SIZE),)
    res = dense_to_jagged_triton[grid](
        dense,
        offsets,
        inverse_offsets,
        output,
        output.shape[0],
        dense.shape[1],
        BLOCK_SIZE,
        num_warps=num_warps,
    )

    if print_ptx:
        print(res.asm.keys())
        # res.asm["ttir"] and res.asm["llir"] can be printed the same way.
        print(res.asm["ttgir"])
        print(res.asm["ptx"])
    return output
def generate_offsets(
    batch_size: int,
    max_seq_len: int,
    load_factor: float,
    offsets_dtype: torch.dtype,
    spread_radius: float,
) -> torch.Tensor:
    """Build random jagged offsets for ``batch_size`` sequences.

    Each sequence length is drawn uniformly from
    ``mean - spread .. mean + spread`` (inclusive) and then clamped to
    ``[0, max_seq_len]``, where ``mean = int(max_seq_len * load_factor)``
    and ``spread = int(max_seq_len * spread_radius)``. When ``load_factor``
    is exactly 1 every sequence has length ``max_seq_len``.

    Args:
        batch_size: Number of sequences.
        max_seq_len: Upper bound for any single sequence length.
        load_factor: Mean length as a fraction of ``max_seq_len``, in [0, 1].
        offsets_dtype: dtype of the returned tensor.
        spread_radius: Half-width of the length distribution as a fraction
            of ``max_seq_len``, in [0, 1].

    Returns:
        A 1-D tensor of ``batch_size + 1`` cumulative offsets starting at 0.

    Raises:
        AssertionError: If ``load_factor`` or ``spread_radius`` is outside
            [0, 1].
    """
    import random
    from itertools import accumulate

    assert 0 <= load_factor <= 1
    assert 0 <= spread_radius <= 1

    if load_factor < 1:
        spread = int(max_seq_len * spread_radius)
        mean = int(max_seq_len * load_factor)
        # random.randint includes BOTH endpoints, so the symmetric window is
        # (-spread, spread); an upper bound of spread + 1 would overshoot the
        # radius by one and bias the mean upward.
        lengths = [
            max(min(mean + random.randint(-spread, spread), max_seq_len), 0)
            for _ in range(batch_size)
        ]
    else:
        lengths = [max_seq_len] * batch_size

    # Prefix sum with a leading 0: offsets[i] is the first row of sequence i.
    offsets = list(accumulate(lengths, initial=0))
    return torch.tensor(offsets, dtype=offsets_dtype)
# Benchmark configuration.
BATCH_SIZE = 1024
MAX_SEQ_LEN = 260
EMBEDDING_DIM = 256

dense = torch.rand(
    (BATCH_SIZE, MAX_SEQ_LEN, EMBEDDING_DIM), device='cuda', dtype=torch.float16
)
offsets = generate_offsets(BATCH_SIZE, MAX_SEQ_LEN, 0.3, torch.int32, 0.1)
jagged_lengths = offsets[1:] - offsets[:-1]
jagged_total_length = offsets[-1].item()

# inverse_offsets[r] is the batch index that owns jagged row r: batch i is
# repeated jagged_lengths[i] times. Built in one vectorized call rather than
# one tensor write per row; repeats are cast to int64 for compatibility.
inverse_offsets = torch.repeat_interleave(
    torch.arange(BATCH_SIZE, dtype=torch.int32), jagged_lengths.long()
)

offsets = offsets.to('cuda')
inverse_offsets = inverse_offsets.to('cuda')


def run_fn():
    """Run one dense-to-jagged conversion; used as the benchmark body."""
    dense_to_jagged(dense, offsets, inverse_offsets, jagged_total_length)


# First call runs the kernel once and dumps its intermediate representations.
dense_to_jagged(dense, offsets, inverse_offsets, jagged_total_length, True)
ms, min_ms, max_ms = triton.testing.do_bench(run_fn, quantiles=[0.5, 0.2, 0.8])
print(ms)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment