# install flashinfer by running:
# pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
import torch
import flashinfer
from math import sqrt

torch.manual_seed(0)

num_layers = 32
num_qo_heads = 64  # grouped-query attention: 4 query heads per KV head
num_kv_heads = 16
head_dim = 128
softmax_scale = 1.0 / sqrt(head_dim)
max_num_pages = 128
page_size = 16  # tokens per KV-cache page

# allocate a 128MB workspace buffer for flashinfer's internal scheduling data
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"  # NHD layout: (num_tokens, num_heads, head_dim)
)
batch_size = 1  # a single request of 33 tokens; 33 % 16 != 0 exercises a partial last page
nnz_qo = 33  # total number of query/output tokens across the batch
# per-request KV sequence lengths
q_indexes = torch.tensor([nnz_qo], dtype=torch.int32, device="cuda:0")
# qo_indptr[i]:qo_indptr[i+1] is the query token range of request i
qo_indptr = torch.tensor([0, nnz_qo], dtype=torch.int32, device="cuda:0")
# each request needs ceil(seq_len / page_size) pages
num_pages_per_req = torch.ceil(q_indexes.float() / page_size).int()
paged_kv_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32, device="cuda:0"), torch.cumsum(num_pages_per_req, dim=0)]
).int()
# pages are allocated contiguously here for simplicity
paged_kv_indices = torch.arange(0, num_pages_per_req.sum().item(), device="cuda:0", dtype=torch.int32)
# number of valid tokens on each request's last page (assumes seq_len % page_size != 0)
paged_kv_last_page_len = (q_indexes % page_size).int()
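# Worked check of the bookkeeping above (my annotation, not in the original
# gist): 33 tokens at page_size=16 occupy ceil(33/16)=3 pages, with
# 33 % 16 = 1 valid token on the last page.
assert num_pages_per_req.tolist() == [3]
assert paged_kv_indptr.tolist() == [0, 3]
assert paged_kv_indices.tolist() == [0, 1, 2]
assert paged_kv_last_page_len.tolist() == [1]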
q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
# paged KV cache: (max_num_pages, 2, page_size, num_kv_heads, head_dim); dim 1 holds K (0) and V (1)
kv_data_at_layer = torch.zeros(
    max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
)
k = torch.randn(nnz_qo, num_kv_heads, head_dim).half().to("cuda:0")
v = torch.randn(nnz_qo, num_kv_heads, head_dim).half().to("cuda:0")
# store K and V in kv_data_at_layer; qo_indptr doubles as the append indptr
# because all 33 KV tokens are appended at once
layer_id = 0
flashinfer.append_paged_kv_cache(
    k,
    v,
    qo_indptr,
    kv_data_at_layer,
    paged_kv_indices,
    paged_kv_indptr,
    paged_kv_last_page_len,
)
def reconstruct(k_or_v: int):
    """Read the 33 tokens back out of the paged cache (0 = K, 1 = V)."""
    kv_data_reconstructed = torch.zeros_like(k)
    kv_data_reconstructed[:16, :, :] = kv_data_at_layer[0, k_or_v, :, :, :]  # first page (full)
    kv_data_reconstructed[16:32, :, :] = kv_data_at_layer[1, k_or_v, :, :, :]  # second page (full)
    kv_data_reconstructed[32:33, :, :] = kv_data_at_layer[2, k_or_v, :1, :, :]  # third page (1 valid token)
    return kv_data_reconstructed

# check that K and V were stored properly
k_data_reconstructed = reconstruct(0)
v_data_reconstructed = reconstruct(1)
assert torch.equal(k, k_data_reconstructed)
assert torch.equal(v, v_data_reconstructed)
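# A more general reconstruction sketch (my addition, not in the original
# gist): walk the page table instead of hardcoding page boundaries, so it
# works for any request in the batch.
def reconstruct_from_page_table(k_or_v: int, req: int = 0):
    lo, hi = paged_kv_indptr[req].item(), paged_kv_indptr[req + 1].item()
    pages = paged_kv_indices[lo:hi]
    chunks = [kv_data_at_layer[p, k_or_v] for p in pages.tolist()]
    chunks[-1] = chunks[-1][: paged_kv_last_page_len[req].item()]  # trim the partial last page
    return torch.cat(chunks, dim=0)

assert torch.equal(reconstruct_from_page_table(0), k)
assert torch.equal(reconstruct_from_page_table(1), v)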
# now run the flashinfer prefill; begin_forward plans the kernel once per
# batch, and the plan is reused by every forward call until end_forward
# (newer flashinfer releases rename begin_forward/forward/end_forward to plan/run)
prefill_wrapper.begin_forward(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
)
outputs = []  # per-layer outputs, filled by the multi-layer sketch below
q = q_at_layer[layer_id]
kv_data = kv_data_at_layer
# compute batch prefill attention, reusing the planned auxiliary data structures
output_flashinfer = prefill_wrapper.forward(
    q, kv_data, causal=True, sm_scale=softmax_scale, allow_fp16_qk_reduction=False
)
# clear auxiliary data structures
prefill_wrapper.end_forward()
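# Multi-layer sketch (my addition, not in the original gist): num_layers and
# q_at_layer above suggest the intended pattern of planning once per batch and
# running every layer against the same plan. For illustration the single
# KV cache is reused at every layer; a real model keeps one paged cache per layer.
prefill_wrapper.begin_forward(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
)
for i in range(num_layers):
    outputs.append(
        prefill_wrapper.forward(q_at_layer[i], kv_data_at_layer, causal=True, sm_scale=softmax_scale)
    )
prefill_wrapper.end_forward()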
import vllm_flash_attn

# do the same with vllm's flash attention; since query and KV cover the same
# 33 tokens, qo_indptr serves as both cu_seqlens_q and cu_seqlens_k
output_vllm = vllm_flash_attn.flash_attn_varlen_func(
    q,
    k,
    v,
    qo_indptr,  # cu_seqlens_q
    qo_indptr,  # cu_seqlens_k
    nnz_qo,  # max_seqlen_q
    nnz_qo,  # max_seqlen_k
    causal=True,
)
# compare the two implementations; expect a small fp16 numerical difference, not exact equality
print("Max Difference:", torch.max(torch.abs(output_vllm - output_flashinfer)).item())