# install flashinfer by running:
# pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
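#
# This script appends one 33-token prefill request into a flashinfer paged KV
# cache, runs batch prefill attention over that cache, and compares the output
# against vllm_flash_attn's varlen kernel running on the same (unpaged) K/V.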
import torch
import flashinfer
from math import sqrt, ceil
torch.manual_seed(0)
num_layers = 32
num_qo_heads = 64
num_kv_heads = 16
head_dim = 128
softmax_scale = 1.0 / sqrt(head_dim)
max_num_pages = 128
page_size = 16
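# paged KV cache layout ("NHD"): one tensor per layer of shape
#   (max_num_pages, 2, page_size, num_kv_heads, head_dim),
# where dim 1 selects K (index 0) or V (index 1) and each page holds up to
# page_size tokens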
# allocate 128MB workspace buffer
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"
)
batch_size = 1  # a single request
nnz_qo = 33     # total number of query/output tokens across the batch
# per-request sequence lengths (query length == kv length for prefill)
q_indexes = torch.tensor([33], dtype=torch.int32, device="cuda:0")
qo_indptr = torch.tensor(
    [0, 33], dtype=torch.int32, device="cuda:0"
)
num_pages_per_req = torch.ceil(q_indexes.float() / page_size).int()
paged_kv_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32, device="cuda:0"), torch.cumsum(num_pages_per_req, dim=0)]
).int()
paged_kv_indices = torch.arange(0, num_pages_per_req.sum().item(), device="cuda:0", dtype=torch.int32)
paged_kv_last_page_len = (q_indexes % page_size).int()
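# for this single request of 33 tokens with page_size = 16:
#   num_pages_per_req      = ceil(33 / 16)  -> [3]
#   paged_kv_indptr        = [0, 3]  (request 0 owns pages indices[0:3])
#   paged_kv_indices       = [0, 1, 2]
#   paged_kv_last_page_len = 33 % 16 -> [1] token used on the last page
# (note: the modulo assumes the length is not an exact multiple of page_size;
#  in that case last_page_len should be page_size, not 0)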
q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
kv_data_at_layer = torch.zeros(
    max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
)
k = torch.randn(nnz_qo, num_kv_heads, head_dim).half().to("cuda:0")
v = torch.randn(nnz_qo, num_kv_heads, head_dim).half().to("cuda:0")
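# k and v hold the packed (ragged) keys/values for all 33 tokens; q_at_layer holds
# the packed queries per layer, shape (num_layers, nnz_qo, num_qo_heads, head_dim)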
# store K and V into the paged KV cache (kv_data_at_layer) for this layer
layer_id = 0
flashinfer.append_paged_kv_cache(
    k,
    v,
    qo_indptr,
    kv_data_at_layer,
    paged_kv_indices,
    paged_kv_indptr,
    paged_kv_last_page_len,
)
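# with page_size = 16, the 33 appended tokens land as:
#   page 0 -> tokens 0..15, page 1 -> tokens 16..31, page 2 -> token 32 only,
# which is exactly what reconstruct() below reads back to verify the append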
def reconstruct(k_or_v: int):
    # gather the appended tokens back out of the paged cache (0 = K, 1 = V)
    kv_data_reconstructed = torch.zeros_like(k)
    kv_data_reconstructed[:16, :, :] = kv_data_at_layer[0, k_or_v, :, :, :]     # first page (tokens 0..15)
    kv_data_reconstructed[16:32, :, :] = kv_data_at_layer[1, k_or_v, :, :, :]   # second page (tokens 16..31)
    kv_data_reconstructed[32:33, :, :] = kv_data_at_layer[2, k_or_v, :1, :, :]  # third page (token 32)
    return kv_data_reconstructed
# check that K and V round-trip through the paged cache exactly
k_data_reconstructed = reconstruct(0)
v_data_reconstructed = reconstruct(1)
assert torch.max(torch.abs(k - k_data_reconstructed)).item() == 0
assert torch.max(torch.abs(v - v_data_reconstructed)).item() == 0
# now run the flashinfer batch prefill kernel over the paged cache
prefill_wrapper.begin_forward(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
)
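# begin_forward plans the kernel from the indptr/indices metadata only, so the same
# plan could be reused for every layer's q/kv tensors; only layer 0 is compared here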
q = q_at_layer[layer_id]
kv_data = kv_data_at_layer
# compute batch prefill attention, reuse auxiliary data structures
output_flashinfer = prefill_wrapper.forward(
    q, kv_data, causal=True, sm_scale=softmax_scale, allow_fp16_qk_reduction=False
)
# clear auxiliary data structures
prefill_wrapper.end_forward()
import vllm_flash_attn  # ships with vLLM (also available as the standalone vllm-flash-attn package)

# compute the same attention with vLLM's varlen flash attention kernel,
# reading K/V from the packed tensors rather than the paged cache
output_vllm = vllm_flash_attn.flash_attn_varlen_func(
    q,
    k,
    v,
    qo_indptr,  # cu_seqlens_q
    qo_indptr,  # cu_seqlens_k (query and kv lengths are identical here)
    33,         # max_seqlen_q
    33,         # max_seqlen_k
    causal=True,
)
# check the DIFF
print("Max Difference:", torch.max(torch.abs(output_vllm - output_flashinfer)))