simple static kv cache script
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional

device = "cuda"

# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_tokens(model, cur_token, cache_position):
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape

max_cache_length = 2048
max_new_tokens = 100
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
cache_position = torch.tensor([sequence_length], device=device)
with torch.no_grad():
    for i in range(100):
        if i == 0:  # prefill uses the vanilla (non-compiled) model
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                generated_ids.index_copy_(1, cache_position, input_id)
        cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"] |
Hey, sorry:
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = [
    "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
    "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]
prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
def decode_one_tokens(model, cur_token, input_pos, cache_position):
    logits = model(
        cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    model._setup_cache(StaticCache, 2, max_cache_len=4096)
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]

    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    cache_position = torch.tensor([seq_length + 1], device=torch_device)
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            with CaptureLogger(logging.get_logger(__name__)) as cl:
                next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
            self.assertNotIn("skipping cudagraphs due to", cl.out)
        generated_ids[:, cache_position] = next_token.int()
        cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
This works because the prompts have the same length here. In general you need to make sure the position ids are passed to the model and incremented along with the cache position (see the sketch below).
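One way to do that (a rough sketch of my own, not code from the thread; `decode_step` and the initialisation of `position_ids` are illustrative) is to keep a per-sequence `position_ids` tensor next to the shared `cache_position` and bump both after every decode step:

import torch

def decode_step(model, cur_token, position_ids, cache_position):
    # position_ids: per-sequence token positions (excluding padding), shape (batch, 1)
    # cache_position: slot index in the static cache, shared across the batch, shape (1,)
    logits = model(
        cur_token,
        position_ids=position_ids,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
    )[0]
    return torch.argmax(logits[:, -1], dim=-1)[:, None]

# Start position_ids at the number of real (non-padded) tokens in each prompt,
# then increment it together with cache_position after every decode step:
#   position_ids = inputs["attention_mask"].sum(dim=-1, keepdim=True)
#   ...
#   next_token = decode_step(model, next_token.clone(), position_ids, cache_position)
#   position_ids += 1
#   cache_position += 1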
Hey all!
This benchmark should help: https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03
We are fixing the last bits 😉
… On 27 Mar 2024, at 06:48, Billy Cao (@aliencaocao) wrote:
Hi, is the static KV cache supposed to require a lot of VRAM for batch size 1 and context 2048? I am trying to use it on LLaVA-NeXT 7B but it OOMs on a 24 GB card. I am already loading in 4-bit and there is plenty of VRAM left (some 10 GB).
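For reference, a rough back-of-envelope estimate of the static KV cache footprint at those settings (my own numbers, assuming a Llama-2-7B-style text backbone and a bf16/fp16 cache, which may not match LLaVA-NeXT exactly):

# Static KV cache bytes ≈ 2 (K and V) * layers * batch * max_cache_len * kv_heads * head_dim * bytes/elem
num_layers, num_kv_heads, head_dim = 32, 32, 128       # assumed 7B Llama-style config
batch_size, max_cache_len, bytes_per_elem = 1, 2048, 2  # bf16/fp16 cache
cache_bytes = 2 * num_layers * batch_size * max_cache_len * num_kv_heads * head_dim * bytes_per_elem
print(f"{cache_bytes / 2**30:.2f} GiB")  # 1.00 GiB

So the static cache buffers themselves only account for about 1 GiB here; an OOM at these settings most likely comes from somewhere else.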
Hi @ArthurZucker, the link is broken btw
Yes, thanks
Thanks for providing this sample! I am building a customized MLP block with Triton 2.2. It looks like my model cannot be processed by torch.compile, although it works well with CUDA graphs when no cache is used. Could you provide sample code that builds a CUDA graph with a static cache? The error I get is:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'layout'
target: triton_kernel_wrapper_functional
kwargs: {'kernel_idx': 0, 'grid': [(688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1)], 'kwargs': {'a_ptr': TensorBox(
View(
SliceView(
StorageBox(
ComputedBuffer(name='buf19', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 4096], stride=[4096, 4096, 1]), data=Pointwise(
'cuda',
torch.float16,
def inner_fn(index):
....
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Any suggestions on fixing this error would also be helpful!
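For what it's worth, the usual manual-capture pattern with torch.cuda.CUDAGraph looks roughly like the sketch below. This is my own illustration, not code from the gist: it assumes the eager (non-compiled) decode_one_tokens and the model._setup_cache(StaticCache, ...) call from the script above have already been run, and it may or may not avoid the Triton lowering error since the custom kernel still runs inside the captured region.

import torch

# CUDA graphs require fixed tensor addresses, so allocate static input buffers once.
static_token = torch.zeros((1, 1), dtype=torch.int, device="cuda")
static_cache_position = torch.tensor([0], device="cuda")

# Warm up on a side stream before capture (recommended for CUDA graph capture).
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream), torch.no_grad():
    for _ in range(3):
        static_out = decode_one_tokens(model, static_token, static_cache_position)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture a single decode step. Warm-up and capture write into cache slot 0,
# so do this before the real prefill (which will overwrite those slots).
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_out = decode_one_tokens(model, static_token, static_cache_position)

# At generation time: copy fresh inputs into the static buffers, then replay.
static_token.copy_(next_token)
static_cache_position.copy_(cache_position)
graph.replay()
next_token = static_out.clone()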
Hi @ArthurZucker, can you provide any kind of support on this issue?