# Source: GitHub Gist remi-or/923d98b6e50b6597c00d7a935abadc23
# (https://gist.github.com/remi-or/923d98b6e50b6597c00d7a935abadc23),
# created by @remi-or on April 16, 2026 08:53.
"""Standalone paged-KV ``flash_attn_with_kvcache`` call with a sparse block table.

Every sequence in the batch owns exactly one KV-cache block. All remaining
``block_table`` slots are padded with -1. The kernel receives one new
query/key/value token per sequence and uses causal masking.

Debug output captured from the call this script reproduces:

    k_cache shape: torch.Size([1616, 256, 8, 128])
    v_cache shape: torch.Size([1616, 256, 8, 128])
    k shape: torch.Size([5, 1, 8, 128])
    v shape: torch.Size([5, 1, 8, 128])
    cache_seqlens shape: torch.Size([5])
    block_table shape: torch.Size([1, 5, 64])
    flash_kwargs: {'block_table': int32 tensor on cuda:0, printed as 5 rows x 64
                   columns; row i is [i, -1, -1, ..., -1], i.e. physical block i
                   followed by 63 padding entries of -1}

The script below uses a smaller pool than the capture: 128 blocks instead of
1616, and 8 block-table columns instead of 64. It keeps the same per-sequence
layout, batch size and k/v shapes.

NOTE(review): the reported ``block_table`` shape is 3-D ([1, 5, 64]), but the
tensor printed in ``flash_kwargs`` is 2-D (5 x 64). Possibly a leading dimension
was dropped before the call; confirm. This script passes a 2-D
(batch_size, max_blocks_per_seq) table.
"""
import math

import torch
from transformers.modeling_flash_attention_utils import lazy_import_paged_flash_attention

# Element [1] of the returned tuple is bound as flash_attn_with_kvcache.
# NOTE(review): the other tuple entries are not visible here; confirm against
# transformers.modeling_flash_attention_utils.
flash_attn_with_kvcache = lazy_import_paged_flash_attention("kernels-community/flash-attn2")[1]

# --- Problem size ----------------------------------------------------------
num_blocks = 128  # physical blocks in the KV-cache pool
block_size = 256  # token slots per block
max_blocks_per_seq = 8  # columns of block_table
num_heads = 32  # query heads
num_kv_heads = 8  # key/value heads (fewer than query heads: grouped-query attention)
head_dim = 128
batch_size = 5
device = "cuda"
dtype = torch.bfloat16

# Tokens already in the cache for each sequence, before this call.
cached_lengths = [101, 145, 110, 119, 70]

# Fail fast on an inconsistent configuration, so the kernel call below exercises
# the intended layout rather than a malformed one.
if len(cached_lengths) != batch_size:
    raise ValueError(f"expected {batch_size} cache lengths, got {len(cached_lengths)}")
if batch_size > num_blocks:
    raise ValueError(f"need one block per sequence: batch_size={batch_size} > num_blocks={num_blocks}")
if max_blocks_per_seq < 1:
    raise ValueError("max_blocks_per_seq must be at least 1")
if num_heads % num_kv_heads:
    # flash-attn documents that query heads must be a multiple of kv heads.
    raise ValueError(f"num_heads={num_heads} is not a multiple of num_kv_heads={num_kv_heads}")
# Each sequence owns a single block. Assuming the kernel appends the new k/v
# token at position cache_seqlens[i] (flash-attn's documented behaviour), the
# existing tokens plus that one must fit in block_size slots.
if max(cached_lengths, default=0) + 1 > block_size:
    raise ValueError(f"a sequence of {max(cached_lengths)} cached tokens + 1 new token overflows one block")


def _randn(*shape):
    """Return a standard-normal tensor of ``shape`` on ``device`` with ``dtype``."""
    return torch.randn(size=shape, device=device, dtype=dtype)


# One new token per sequence: (batch_size, 1, heads, head_dim).
q_state = _randn(batch_size, 1, num_heads, head_dim)
k_state = _randn(batch_size, 1, num_kv_heads, head_dim)
v_state = _randn(batch_size, 1, num_kv_heads, head_dim)
# Paged cache pool: (num_blocks, block_size, num_kv_heads, head_dim).
k_cache = _randn(num_blocks, block_size, num_kv_heads, head_dim)
v_cache = _randn(num_blocks, block_size, num_kv_heads, head_dim)

# All requests have exactly one allocated block: sequence i -> physical block i.
# Every other slot holds -1 as padding.
block_table = torch.full(
    size=(batch_size, max_blocks_per_seq), fill_value=-1, dtype=torch.int32, device=device
)
block_table[:, 0] = torch.arange(batch_size, dtype=torch.int32, device=device)
cache_seqlens = torch.tensor(data=cached_lengths, dtype=torch.int32, device=device)

out = flash_attn_with_kvcache(
    q_state,
    k_cache,
    v_cache,
    k=k_state,
    v=v_state,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    softmax_scale=1.0 / math.sqrt(head_dim),
    causal=True,
    window_size=(-1, -1),  # flash-attn convention: (-1, -1) means no sliding window
)
print(out.shape)
# End of gist.