benchmark_flash_comparison
#!/usr/bin/env python3
"""
Benchmark comparing FlashInfer, Flash Attention, PyTorch SDPA, and MAX attention
on a B200 GPU.

Sweeps batch sizes, sequence lengths, head counts, and head dimensions, and
reports latency statistics for each implementation.
"""
import argparse
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

try:
    import flashinfer

    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False
    print(
        "Warning: FlashInfer not available. Install with: pip install flashinfer-python"
    )

try:
    from flash_attn import flash_attn_func

    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print(
        "Warning: Flash Attention not available. Install with: pip install flash-attn --no-build-isolation"
    )

try:
    import max

    MAX_AVAILABLE = True
except ImportError:
    MAX_AVAILABLE = False
    print("Warning: MAX not available. Install with: pip install max")


@dataclass
class BenchmarkConfig:
    """Configuration for benchmark parameters."""

    batch_sizes: Sequence[int] | None = None
    seq_lengths: Sequence[int] | None = None
    num_heads_list: Sequence[int] | None = None
    head_dims: Sequence[int] | None = None
    causal: bool = True
    dtype: torch.dtype = torch.float16
    warmup_iters: int = 10
    benchmark_iters: int = 100


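# Illustrative, non-CLI use of BenchmarkConfig (the values below are arbitrary
# examples, not recommended settings; normally main() builds the config from
# command-line flags):
#
#   config = BenchmarkConfig(
#       batch_sizes=[1],
#       seq_lengths=[1024, 4096],
#       num_heads_list=[32],
#       head_dims=[128],
#       dtype=torch.bfloat16,
#   )
#   results = run_benchmark(config)

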
def benchmark_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark FlashInfer attention."""
    if not FLASHINFER_AVAILABLE:
        return None

    batch_size, num_heads, seq_len, head_dim = q.shape

    # Reshape for FlashInfer (batch_size * seq_len, num_heads, head_dim)
    q_fi = q.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
    k_fi = k.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
    v_fi = v.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)

    scale = 1.0 / (head_dim ** 0.5)

    # Warmup
    for _ in range(warmup_iters):
        with torch.no_grad():
            _ = flashinfer.single_prefill_with_kv_cache(
                q_fi,
                k_fi,
                v_fi,
                causal=causal,
                sm_scale=scale,
            )
    torch.cuda.synchronize()

    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            output = flashinfer.single_prefill_with_kv_cache(
                q_fi,
                k_fi,
                v_fi,
                causal=causal,
                sm_scale=scale,
            )
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()

    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))

    # Convert back to original shape for validation
    output = output.view(batch_size, seq_len, num_heads, head_dim)
    output = output.transpose(1, 2).contiguous()

    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output,
    }


def benchmark_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark Flash Attention."""
    if not FLASH_ATTN_AVAILABLE:
        return None

    batch_size, num_heads, seq_len, head_dim = q.shape

    # Reshape for Flash Attention (batch_size, seq_len, num_heads, head_dim)
    q_fa = q.transpose(1, 2).contiguous()
    k_fa = k.transpose(1, 2).contiguous()
    v_fa = v.transpose(1, 2).contiguous()

    # Warmup
    for _ in range(warmup_iters):
        with torch.no_grad():
            _ = flash_attn_func(q_fa, k_fa, v_fa, causal=causal)
    torch.cuda.synchronize()

    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            output = flash_attn_func(q_fa, k_fa, v_fa, causal=causal)
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()

    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))

    # Convert back to original shape for validation
    output = output.transpose(1, 2).contiguous()

    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output,
    }


def benchmark_torch_sdpa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark PyTorch's scaled dot-product attention (baseline)."""
    # Warmup
    for _ in range(warmup_iters):
        with torch.no_grad():
            _ = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    torch.cuda.synchronize()

    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            output = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()

    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))

    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output,
    }


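# Note: the benchmark_* functions above (and benchmark_max below) all build the
# same latency-summary dict from per-iteration CUDA-event timings. A small
# helper along these lines could remove that duplication (illustrative sketch
# only; it is not wired into the functions in this file):
#
#   def _summarize_latencies(latencies: list[float]) -> dict[str, float]:
#       return {
#           "mean_ms": float(np.mean(latencies)),
#           "p50_ms": float(np.percentile(latencies, 50)),
#           "p90_ms": float(np.percentile(latencies, 90)),
#           "p95_ms": float(np.percentile(latencies, 95)),
#           "p99_ms": float(np.percentile(latencies, 99)),
#           "std_ms": float(np.std(latencies)),
#           "median_ms": float(np.median(latencies)),
#           "min_ms": float(np.min(latencies)),
#           "max_ms": float(np.max(latencies)),
#       }

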
def benchmark_max(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark MAX attention using the Graph API with the flash_attention_gpu kernel."""
    if not MAX_AVAILABLE:
        return None

    try:
        from max.driver import Accelerator, Tensor
        from max.dtype import DType
        from max.engine import InferenceSession
        from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops
        from max.nn.attention.mask_config import MHAMaskVariant
        from max.nn.kernels import flash_attention_gpu
        from torch.utils.dlpack import from_dlpack
    except ImportError as e:
        print(f"Warning: Could not import MAX modules: {e}")
        return None

    _, num_heads, _, head_dim = q.shape

    # Convert torch tensors to the right format (batch, seq_len, num_heads, head_dim)
    q_max = q.transpose(1, 2).contiguous()
    k_max = k.transpose(1, 2).contiguous()
    v_max = v.transpose(1, 2).contiguous()

    dtype_map = {
        torch.float16: DType.float16,
        torch.bfloat16: DType.bfloat16,
        torch.float32: DType.float32,
    }
    max_dtype = dtype_map.get(q.dtype, DType.float32)

    device_ref, device = DeviceRef.GPU(), Accelerator()

    with Graph(
        "max_attention_benchmark",
        input_types=(
            TensorType(
                shape=("batch_size", "seq_len", num_heads, head_dim),
                dtype=max_dtype,
                device=device_ref,
            ),
            TensorType(
                shape=("batch_size", "seq_len", num_heads, head_dim),
                dtype=max_dtype,
                device=device_ref,
            ),
            TensorType(
                shape=("batch_size", "seq_len", num_heads, head_dim),
                dtype=max_dtype,
                device=device_ref,
            ),
        ),
    ) as graph:
        q_input, k_input, v_input = graph.inputs
        assert isinstance(q_input, TensorValue)
        assert isinstance(k_input, TensorValue)
        assert isinstance(v_input, TensorValue)

        mask_variant = MHAMaskVariant.CAUSAL_MASK if causal else MHAMaskVariant.NULL_MASK
        scale = 1.0 / np.sqrt(head_dim)

        output = flash_attention_gpu(
            q=q_input,
            k=k_input,
            v=v_input,
            mask_variant=mask_variant,
            scale=scale,
        )
        # Transpose back to (batch, num_heads, seq_len, head_dim)
        output = ops.transpose(output, 1, 2)
        graph.output(output)

    session = InferenceSession(devices=[device])
    compiled = session.load(graph)

    q_max = Tensor.from_dlpack(q_max).to(device)
    k_max = Tensor.from_dlpack(k_max).to(device)
    v_max = Tensor.from_dlpack(v_max).to(device)

    # Warmup
    for _ in range(warmup_iters):
        _ = compiled.execute(
            q_max,
            k_max,
            v_max,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = compiled.execute(
            q_max,
            k_max,
            v_max,
        )[0]
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()

    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))

    # Convert output back to torch tensor with original shape
    output_torch = from_dlpack(result)

    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output_torch,
    }


def validate_outputs(
    outputs_dict: dict[str, dict[str, Any]], tolerance: float = 1e-2
):
    """Validate that all implementations produce similar outputs."""
    outputs = {k: v["output"] for k, v in outputs_dict.items() if v is not None}
    if len(outputs) < 2:
        print("Not enough implementations to validate")
        return True

    reference_key = list(outputs.keys())[0]
    reference = outputs[reference_key]
    print(f"Reference: {reference_key}")

    valid = True
    for name, output in outputs.items():
        if name == reference_key:
            continue
        max_diff = torch.max(torch.abs(reference - output)).item()
        if max_diff > tolerance:
            valid = False
            print(f"Validation failed: {reference_key} vs {name}, max_diff={max_diff}")
    if valid:
        print("All outputs match within tolerance")
    return valid


def run_benchmark(config: BenchmarkConfig):
    """Run the complete benchmark suite."""
    results = []

    print(f"\n{'=' * 80}")
    print(f"Running benchmarks on {torch.cuda.get_device_name()}")
    print(f"{'=' * 80}\n")

    total_configs = (
        len(config.batch_sizes or [])
        * len(config.seq_lengths or [])
        * len(config.num_heads_list or [])
        * len(config.head_dims or [])
    )
    current = 0

    for batch_size in config.batch_sizes or []:
        for seq_len in config.seq_lengths or []:
            for num_heads in config.num_heads_list or []:
                for head_dim in config.head_dims or []:
                    current += 1
                    print(
                        f"[{current}/{total_configs}] Testing: batch_size={batch_size}, "
                        f"seq_len={seq_len}, num_heads={num_heads}, head_dim={head_dim}"
                    )

                    q = torch.randn(
                        batch_size,
                        num_heads,
                        seq_len,
                        head_dim,
                        dtype=config.dtype,
                        device="cuda",
                    )
                    k = torch.randn(
                        batch_size,
                        num_heads,
                        seq_len,
                        head_dim,
                        dtype=config.dtype,
                        device="cuda",
                    )
                    v = torch.randn(
                        batch_size,
                        num_heads,
                        seq_len,
                        head_dim,
                        dtype=config.dtype,
                        device="cuda",
                    )

                    bench_results = {}

                    # PyTorch SDPA (baseline)
                    print(" - Running PyTorch SDPA w/ CUDNN attention...")
                    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
                        sdpa_result = benchmark_torch_sdpa(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                    if sdpa_result:
                        bench_results["pytorch_sdpa"] = sdpa_result

                    # FlashInfer
                    if FLASHINFER_AVAILABLE:
                        print(" - Running FlashInfer...")
                        fi_result = benchmark_flashinfer(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                        if fi_result:
                            bench_results["flashinfer"] = fi_result

                    # Flash Attention
                    if FLASH_ATTN_AVAILABLE:
                        print(" - Running Flash Attention...")
                        fa_result = benchmark_flash_attention(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                        if fa_result:
                            bench_results["flash_attn"] = fa_result

                    # MAX
                    if MAX_AVAILABLE:
                        print(" - Running MAX...")
                        max_result = benchmark_max(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                        if max_result:
                            bench_results["max"] = max_result

                    valid = validate_outputs(bench_results)

                    for impl_name, impl_results in bench_results.items():
                        results.append(
                            {
                                "implementation": impl_name,
                                "batch_size": batch_size,
                                "seq_length": seq_len,
                                "num_heads": num_heads,
                                "head_dim": head_dim,
                                "mean_latency_ms": impl_results["mean_ms"],
                                "p50_latency_ms": impl_results["p50_ms"],
                                "p90_latency_ms": impl_results["p90_ms"],
                                "p95_latency_ms": impl_results["p95_ms"],
                                "p99_latency_ms": impl_results["p99_ms"],
                                "std_latency_ms": impl_results["std_ms"],
                                "median_latency_ms": impl_results["median_ms"],
                                "min_latency_ms": impl_results["min_ms"],
                                "max_latency_ms": impl_results["max_ms"],
                                "causal": config.causal,
                                "dtype": str(config.dtype),
                                "validation_passed": valid,
                            }
                        )

                    print(" Results:")
                    for impl_name, impl_results in bench_results.items():
                        print(
                            f" {impl_name:15s}: {impl_results['mean_ms']:.3f} ± "
                            f"{impl_results['std_ms']:.3f} ms"
                        )
                    print()

    return results


| def save_results(results: list[dict[str, Any]], output_dir: str = "."): | |
| """Save benchmark results to CSV and summary.""" | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| df = pd.DataFrame(results) | |
| csv_path = f"{output_dir}/flash_benchmark_results_{timestamp}.csv" | |
| df.to_csv(csv_path, index=False) | |
| print(f"Detailed results saved to: {csv_path}") | |
| if len(df["implementation"].unique()) > 1: | |
| pivot = df.pivot_table( | |
| values="mean_latency_ms", | |
| index=["seq_length", "num_heads", "head_dim"], | |
| columns="implementation", | |
| aggfunc="mean", | |
| ) | |
| # Calculate speedups relative to pytorch_sdpa baseline | |
| if "pytorch_sdpa" in pivot.columns: | |
| for col in pivot.columns: | |
| if col != "pytorch_sdpa": | |
| pivot[f"{col}_speedup"] = pivot["pytorch_sdpa"] / pivot[col] | |
| pivot_path = f"{output_dir}/flash_benchmark_comparison_{timestamp}.csv" | |
| pivot.to_csv(pivot_path) | |
| print(f"Comparison table saved to: {pivot_path}") | |
| # Print summary | |
| print("\n" + "=" * 80) | |
| print("BENCHMARK SUMMARY") | |
| print("=" * 80) | |
| print("\nAverage Latencies by Implementation:") | |
| avg_latencies = df.groupby("implementation")["mean_latency_ms"].mean().round(3) | |
| print(avg_latencies) | |
| # Print comprehensive comparison table | |
| # print("\n" + "=" * 80) | |
| # print("ALL FRAMEWORKS COMPARISON (mean latency in ms)") | |
| # print("=" * 80) | |
| # print("\n" + pivot.round(3).to_string()) | |
| if "pytorch_sdpa" in df["implementation"].unique(): | |
| print("\n" + "=" * 80) | |
| print("SPEEDUP SUMMARY (relative to PyTorch SDPA)") | |
| print("=" * 80) | |
| baseline_df = df[df["implementation"] == "pytorch_sdpa"] | |
| for impl in df["implementation"].unique(): | |
| if impl != "pytorch_sdpa": | |
| impl_df = df[df["implementation"] == impl] | |
| # Merge on common configurations | |
| merged = pd.merge( | |
| baseline_df[["seq_length", "num_heads", "head_dim", "mean_latency_ms"]], | |
| impl_df[["seq_length", "num_heads", "head_dim", "mean_latency_ms"]], | |
| on=["seq_length", "num_heads", "head_dim"], | |
| suffixes=("_baseline", f"_{impl}"), | |
| ) | |
| if not merged.empty: | |
| merged["speedup"] = merged["mean_latency_ms_baseline"] / merged[f"mean_latency_ms_{impl}"] | |
| print(f"\n{impl}:") | |
| print(f" Average speedup: {merged['speedup'].mean():.3f}x") | |
| print(f" Median speedup: {merged['speedup'].median():.3f}x") | |
| print(f" Max speedup: {merged['speedup'].max():.3f}x") | |
| print(f" Min speedup: {merged['speedup'].min():.3f}x") | |
| return df | |
def main():
    parser = argparse.ArgumentParser(
        description="Benchmark FlashInfer vs Flash Attention on B200"
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1],
        help="Batch sizes",
    )
    parser.add_argument(
        "--seq-lengths",
        type=int,
        nargs="+",
        default=[32, 64, 128, 256, 512, 1024, 1536, 2048, 4096, 8192, 16384],
        help="Sequence lengths to test",
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        nargs="+",
        default=[16, 32, 64, 128],
        help="Number of attention heads to test",
    )
    parser.add_argument(
        "--head-dims",
        type=int,
        nargs="+",
        default=[64, 128],
        help="Head dimensions to test",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=10,
        help="Number of warmup iterations",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations",
    )
    parser.add_argument(
        "--no-causal", action="store_true", help="Disable causal masking"
    )
    parser.add_argument(
        "--fp16", action="store_true", help="Use FP16 instead of BF16"
    )
    parser.add_argument(
        "--output-dir", type=str, default=".", help="Directory to save results"
    )
    args = parser.parse_args()

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("Error: CUDA is not available!")
        sys.exit(1)

    # Check for B200
    device_name = torch.cuda.get_device_name()
    if "B200" not in device_name and "Blackwell" not in device_name:
        print(f"Warning: Current GPU is {device_name}, not B200/Blackwell")
        response = input("Continue anyway? (y/n): ")
        if response.lower() != "y":
            sys.exit(0)

    config = BenchmarkConfig(
        batch_sizes=args.batch_sizes,
        seq_lengths=args.seq_lengths,
        num_heads_list=args.num_heads,
        head_dims=args.head_dims,
        causal=not args.no_causal,
        dtype=torch.float16 if args.fp16 else torch.bfloat16,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
    )

    print("\nBenchmark Configuration:")
    print(f" Batch sizes: {config.batch_sizes}")
    print(f" Sequence lengths: {config.seq_lengths}")
    print(f" Number of heads: {config.num_heads_list}")
    print(f" Head dimensions: {config.head_dims}")
    print(f" Causal: {config.causal}")
    print(f" Data type: {config.dtype}")
    print(f" Warmup iterations: {config.warmup_iters}")
    print(f" Benchmark iterations: {config.benchmark_iters}")

    # Run benchmark
    results = run_benchmark(config)

    # Save and display results
    if results:
        save_results(results, args.output_dir)
    else:
        print("No results collected!")


if __name__ == "__main__":
    main()
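
# Example invocation (illustrative; assumes this file is saved as
# benchmark_flash_comparison.py -- adjust flags to taste):
#
#   python benchmark_flash_comparison.py \
#       --batch-sizes 1 \
#       --seq-lengths 1024 4096 \
#       --num-heads 32 \
#       --head-dims 128 \
#       --benchmark-iters 50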