@drisspg
Created October 19, 2024 00:51
sdpa.py
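"""Benchmark torch.nn.functional.scaled_dot_product_attention (SDPA).

Times the forward and backward pass over a grid of configurations: batch size,
number of heads, q/kv sequence length, embedding dim, dtype, causal vs.
non-causal, SDPA backend (FlashAttention, cuDNN), and contiguous vs. transposed
(non-contiguous) input layouts. Results are printed as a table of forward and
backward times in microseconds.

Requires a CUDA device plus the `tabulate` and `tqdm` packages.
Run with: python sdpa.py
"""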
import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple
from tabulate import tabulate
from tqdm import tqdm
import torch
import torch.utils.benchmark as benchmark
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention
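

# Time a callable with torch.utils.benchmark.Timer and return the median runtime in microseconds.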
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6
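

# A single benchmark configuration; head_dim is derived as embed_dim // num_heads.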
@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    backend: SDPBackend
    transposed: bool  # Whether inputs are created in (B, S, H, D) layout and transposed
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj
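

# Containers for timing results and for pairing a config with its results.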
@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        dict1 = asdict(self.config)
        dict2 = asdict(self.results)
        return {**dict1, **dict2}
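

# Build q, k, v for a config. With transposed=True the tensors are allocated as
# (batch, seq, heads, head_dim) and transposed to (batch, heads, seq, head_dim),
# yielding non-contiguous inputs; otherwise they are allocated directly in
# (batch, heads, seq, head_dim) layout.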
def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if config.transposed:
        q = torch.randn(
            (config.batch_size, config.q_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
        k = torch.randn(
            (config.batch_size, config.kv_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
        v = torch.randn(
            (config.batch_size, config.kv_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
    else:
        q = torch.randn(
            (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
        k = torch.randn(
            (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
        v = torch.randn(
            (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
    return q, k, v
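

# Benchmark forward and backward for one config, pinning the SDPA backend via
# sdpa_kernel when a backend is specified.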
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal
    context = (
        sdpa_kernel(config.backend) if config.backend is not None else nullcontext()
    )
    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        dOut = torch.randn_like(out_torch)
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, dOut, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )
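

# Cartesian product of all swept parameters -> list of ExperimentConfig.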
def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [32]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    transposed_configs = [True, False]
    embed_dims = [2048]
    # Use [None] to let SDPA choose among all enabled backends
    backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]
    dtypes = [
        torch.float16,
    ]
    is_causal = [True, False]
    all_configs = []
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
        backend,
        transpose,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends, transposed_configs
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
                backend=backend,
                transposed=transpose,
            )
        )
    return all_configs
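

# Collect results into columns and print them with tabulate, dropping the device
# column and the backend column when no backend was pinned.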
def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    del table_data["device"]
    if table_data["backend"][0] is None:
        del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))
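

# Run every generated config and print the results table.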
def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))
    print_results(results)


if __name__ == "__main__":
    main()