Created June 7, 2024 08:39
Whisper static cache
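# Benchmark: Whisper generation throughput with the default dynamic cache (eager) vs. a
# static KV cache plus torch.compile. The script times a dynamic-cache baseline first,
# then enables cache_implementation="static", compiles model.forward with
# mode="reduce-overhead", and runs warmup followed by timed iterations per batch size.
# Commented-out sections preserve the author's StaticCache and profiler experiments.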
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor, StaticCache
import torch
import torch._dynamo.config
import torch._inductor.config
import time
from tqdm import tqdm
import logging

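# Inductor / Dynamo settings for the compiled runs: coordinate-descent autotuning of
# kernel configs, readable (unique) Triton kernel names for profiling, the FX graph
# cache to reuse compiled artifacts, a larger recompile budget, and logging of
# recompiles and graph breaks so shape-related recompilation is visible.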
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._dynamo.config.cache_size_limit = 32
torch._logging.set_logs(recompiles=True, graph_breaks=True)
# torch.set_float32_matmul_precision('high')
torch.set_printoptions(linewidth=200)  # so you can better see how the mask is shaped
NUM_TOKENS = 100
NUM_WARMUP = 3
NUM_ITERS = 5
ATTN_IMPLEMENTATION = "sdpa"
MODEL_ID = "openai/whisper-medium.en"
BATCH_SIZES = [1]

torch_device = "cuda:0"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, attn_implementation=ATTN_IMPLEMENTATION)
model.to(torch_device, dtype=torch_dtype)

is_multilingual = getattr(model.generation_config, "is_multilingual", False)
language = "en" if is_multilingual else None

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
sample = dataset[30]["audio"]
inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").to(torch_device)
input_features = inputs.input_features.to(torch_dtype)

all_dynamic_results = {}
all_static_results = {}
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
import datetime

with torch.no_grad():
    # Baseline: eager generation with the default (dynamic) cache.
    for _ in range(NUM_ITERS):
        torch.cuda.synchronize()
        start = datetime.datetime.now()
        out = model.generate(input_features, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
        torch.cuda.synchronize()
        print(processor.tokenizer.batch_decode(out))
        runtime = datetime.datetime.now() - start
        print(f"Inference took {runtime} seconds")
        tok_per_s = (out.shape[1] * 1) / runtime.total_seconds()
        print(f"Dynamic bsz {1} - {tok_per_s} tok/s")
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    for batch_size in BATCH_SIZES:
        # cache = (StaticCache(model.config, batch_size, (NUM_WARMUP + NUM_ITERS) * NUM_TOKENS, device="cuda:0", dtype=torch.float16), StaticCache(model.config, batch_size, (NUM_WARMUP + NUM_ITERS) * NUM_TOKENS, device="cuda:0", dtype=torch.float16))
        cache = None
        # cache = StaticCache(model.config, batch_size, (NUM_WARMUP + NUM_ITERS) * NUM_TOKENS, device=model.device, dtype=torch.float16)
        input_features_batch = input_features.repeat(batch_size, 1, 1)

        # Warmup: the first calls pay the compilation cost and are excluded from the timing.
        for _ in range(NUM_WARMUP):
            start = datetime.datetime.now()
            torch.cuda.synchronize()
            out = model.generate(input_features_batch, past_key_values=cache, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
            torch.cuda.synchronize()
            # cache[1].reset()
            # cache[0].reset()
            print(processor.tokenizer.batch_decode(out))
            print(f"Warmup Inference took {datetime.datetime.now() - start} seconds")
        # start = time.time()
        # with profile(
        #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        #     on_trace_ready=tensorboard_trace_handler(f"./tb_logs/tb_{datetime.datetime.now().strftime('%Y-%m-%d_%Hh%Mm%Ss')}"),
        #     record_shapes=False,
        #     profile_memory=False,
        #     with_stack=False
        # ) as prof:
        #     for _ in range(NUM_ITERS):
        #         with torch.inference_mode():
        #             torch.cuda.synchronize()
        #             start = datetime.datetime.now()
        #             with record_function("generate"):
        #                 out = model.generate(input_features, past_key_values=cache, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
        #             torch.cuda.synchronize()
        #             print(processor.tokenizer.batch_decode(out))
        #             print(f"Profiled Inference took {datetime.datetime.now() - start} seconds")
        # Timed runs with the static cache and the compiled forward.
        for _ in range(NUM_ITERS):
            torch.cuda.synchronize()
            start = datetime.datetime.now()
            out = model.generate(input_features_batch, past_key_values=cache, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
            torch.cuda.synchronize()
            print(processor.tokenizer.batch_decode(out))
            runtime = datetime.datetime.now() - start
            print(f"Inference took {runtime} seconds")
            tok_per_s = (out.shape[1] * batch_size) / runtime.total_seconds()
            print(f"Static bsz {batch_size} - {tok_per_s} tok/s")
        # runtime = time.time() - start
        # tok_per_s = (NUM_TOKENS * batch_size) / runtime
        # all_dynamic_results[batch_size] = tok_per_s

# Alternative time.time()-based version of the static-cache benchmark, kept commented out:
# model.generation_config.cache_implementation = "static"
# model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
# for batch_size in BATCH_SIZES:
#     input_features_batch = input_features.repeat(batch_size, 1, 1)
#     for _ in range(NUM_WARMUP):
#         model.generate(input_features, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
#     start = time.time()
#     for _ in range(NUM_ITERS):
#         model.generate(input_features, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
#     torch.cuda.synchronize()
#     runtime = time.time() - start
#     tok_per_s = (NUM_TOKENS * batch_size) / runtime
#     all_static_results[batch_size] = tok_per_s
#     print(f"Static bsz {batch_size} - {tok_per_s} tok/s")