Created October 19, 2024 00:51
sdpa.py
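"""Benchmark for torch.nn.functional.scaled_dot_product_attention.

Sweeps batch size, sequence length, tensor layout (contiguous vs. transposed),
causality, and SDPA backend, timing the forward and backward passes in
microseconds and printing the results as a table.
"""
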
import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple

import torch
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention
from tqdm import tqdm

def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # Warmup: the first few calls pay one-time costs (CUDA context creation,
    # kernel autotuning) that should not count toward the steady-state time.
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    # adaptive_autorange returns a Measurement; .median is in seconds.
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6

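# Note: unlike timeit, torch.utils.benchmark.Timer synchronizes CUDA around
# its measurements, so the medians reported above reflect actual kernel
# execution time rather than just launch overhead.
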
@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    backend: SDPBackend
    transposed: bool  # Controls the memory layout of q/k/v; see get_input().
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj

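# With the sweep values in generate_experiment_configs() (embed_dim=2048,
# num_heads=32), head_dim works out to 2048 // 32 = 64, a head size that
# both benchmarked backends (FLASH_ATTENTION and CUDNN_ATTENTION) accept.
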
@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        # Use the config's own asdict() so the derived head_dim column
        # shows up in the results table.
        dict1 = self.config.asdict()
        dict2 = self.results.asdict()
        return {**dict1, **dict2}

def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if config.transposed:
        q = torch.randn(
            (config.batch_size, config.q_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
        k = torch.randn(
            (config.batch_size, config.kv_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
        v = torch.randn(
            (config.batch_size, config.kv_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
    else:
        q = torch.randn(
            (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
        k = torch.randn(
            (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
        v = torch.randn(
            (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
    return q, k, v

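# Layout note: with transposed=True the tensors are allocated as
# (batch, seq_len, num_heads, head_dim) and then viewed as
# (batch, num_heads, seq_len, head_dim) via transpose(1, 2), so SDPA sees
# non-contiguous inputs. This mirrors the common transformer pattern of
# projecting to (B, S, H * D) and transposing, and lets the benchmark reveal
# whether a backend pays a penalty for the non-contiguous layout.
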
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal
    # Pin SDPA to the requested backend; with backend=None, fall through to
    # PyTorch's normal backend selection.
    context = (
        sdpa_kernel(config.backend) if config.backend is not None else nullcontext()
    )
    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        dOut = torch.randn_like(out_torch)
        # retain_graph=True lets the timer call backward repeatedly on the
        # same graph; gradient accumulation does not affect the timing.
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, dOut, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )

def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [1, 8]
    num_heads = [32]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    transposed_configs = [True, False]
    embed_dims = [2048]
    # Use [None] instead to let PyTorch choose among all available backends.
    backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]
    dtypes = [torch.float16]
    is_causal = [True, False]
    all_configs = []
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
        backend,
        transpose,
    ) in itertools.product(
        batch_sizes,
        num_heads,
        q_kv_seq_lens,
        embed_dims,
        is_causal,
        dtypes,
        backends,
        transposed_configs,
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
                backend=backend,
                transposed=transpose,
            )
        )
    return all_configs

def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    del table_data["device"]  # Identical for every row, so drop the column.
    if table_data["backend"][0] is None:
        del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))

def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))
    print_results(results)


if __name__ == "__main__":
    main()
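
For a quick smoke test before launching the full sweep, a standalone snippet along these lines (illustrative, not part of the gist; it assumes a CUDA device with the flash attention backend available) runs a single configuration:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention

# One fp16 configuration from the sweep: batch=8, heads=32, seq=1024, head_dim=64.
q, k, v = (
    torch.randn(8, 32, 1024, 64, dtype=torch.float16, device="cuda")
    for _ in range(3)
)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([8, 32, 1024, 64])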