@younesbelkada
Last active November 25, 2024 16:54
Benchmark FA2 + transformers integration
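The script below loads meta-llama/Llama-2-7b-hf twice, once with the default ("native") attention and once with FlashAttention-2, times the forward pass (or generate, with --bench-generate) over batch sizes of max/4, max/2 and max, optionally times the backward pass (--bench-backward), and saves two plots (timing_plot.jpg and speedup_plot.jpg) under flash-attn-2-benchmarks/. A typical invocation would be something like python benchmark_fa2.py --bench-generate --max-batch-size 16 --max-seqlen 64, where benchmark_fa2.py is simply whatever name you save the script under; it assumes a CUDA GPU, the flash-attn package, and access to the gated Llama-2 checkpoint.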
import torch
import os
import argparse
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import seaborn as sns
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-batches",
        type=int,
        default=10,
        help="Number of batches to average the timings over.",
    )
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=16,
        help="Largest batch size to benchmark; the sweep uses max/4, max/2 and max.",
    )
    parser.add_argument(
        "--max-seqlen",
        type=int,
        default=64,
        help="Sequence length of the random input_ids.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=64,
        help="Number of tokens to generate when --bench-generate is set.",
    )
    parser.add_argument(
        "--bench-backward",
        action="store_true",
        help="Also benchmark the backward pass.",
    )
    parser.add_argument(
        "--bench-generate",
        action="store_true",
        help="Benchmark model.generate instead of a single forward pass.",
    )
    parser.add_argument(
        "--use-padding",
        action="store_true",
        help="Attach an attention mask that masks out the second half of the sequence.",
    )
    return parser
model_id = "meta-llama/Llama-2-7b-hf"
# NOTE: no @torch.no_grad() decorator here; it would break loss.backward()
# in the --bench-backward path. The forward benchmark below is explicitly
# wrapped in torch.no_grad() instead.
def warmup_and_benchmark(
    model,
    batch_size,
    max_seq_len,
    use_padding,
    num_batches,
    bench_generate,
    bench_backward,
    max_new_tokens,
):
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, max_seq_len)).to(0)
    inputs = {"input_ids": input_ids}

    if use_padding:
        attention_mask = torch.zeros_like(input_ids)
        attention_mask[:, :max_seq_len // 2] = 1
        inputs["attention_mask"] = attention_mask

    # warmup
    _ = model.generate(**inputs, max_new_tokens=20, eos_token_id=-1, use_cache=False)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    with torch.no_grad():
        start_event.record()
        for _ in range(num_batches):
            if bench_generate:
                _ = model.generate(**inputs, max_new_tokens=max_new_tokens, eos_token_id=-1, use_cache=False)
            else:
                _ = model(**inputs)  # pass the attention mask as well when --use-padding is set
        end_event.record()
        torch.cuda.synchronize()

    forward_timing = (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches

    backward_timing = 0
    if bench_backward:
        for _ in range(num_batches):
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

            logits = model(input_ids).logits
            loss = logits.mean()

            start_event.record()
            loss.backward()
            end_event.record()
            torch.cuda.synchronize()

            backward_timing += start_event.elapsed_time(end_event) * 1.0e-3

    return forward_timing, backward_timing / num_batches
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    num_batches = args.num_batches
    max_seq_len = args.max_seqlen
    max_batch_size = args.max_batch_size
    max_new_tokens = args.max_new_tokens
    bench_generate = args.bench_generate
    bench_backward = args.bench_backward
    use_padding = args.use_padding

    # TODO: change this
    BATCH_SIZE = [max_batch_size // 4, max_batch_size // 2, max_batch_size]

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
    ).to(0)
    model_fa = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map={"": 0},
        torch_dtype=torch.float16,
        use_flash_attention_2=True,
    )

    print("native", model)
    print("FA2", model_fa)

    native_total_time_dict = {}
    fa2_total_time_dict = {}
    forward_speedups = {}
    backward_speedups = {}
    for batch_size in tqdm(BATCH_SIZE):
        # warmup
        native_timing, native_backward_timing = warmup_and_benchmark(
            model,
            batch_size,
            max_seq_len,
            use_padding,
            num_batches,
            bench_generate,
            bench_backward,
            max_new_tokens,
        )
        native_total_time_dict[f"{batch_size}"] = native_timing

        fa2_timing, fa2_backward_timing = warmup_and_benchmark(
            model_fa,
            batch_size,
            max_seq_len,
            use_padding,
            num_batches,
            bench_generate,
            bench_backward,
            max_new_tokens,
        )
        fa2_total_time_dict[f"{batch_size}"] = fa2_timing

        forward_speedups[f"{batch_size}"] = native_timing / fa2_timing
        if bench_backward:
            backward_speedups[f"{batch_size}"] = native_backward_timing / fa2_backward_timing
        else:
            backward_speedups[f"{batch_size}"] = 0
    dir_name = f"flash-attn-2-benchmarks/{model_id}/seq_len_{max_seq_len}_padding_{use_padding}_generate_{bench_generate}_max_batch_size_{max_batch_size}/"
    os.makedirs(dir_name, exist_ok=True)

    sns.set(style="darkgrid")
    # plot both lines
    sns.lineplot(data=native_total_time_dict, color="blue", label="llama2-native")
    sns.lineplot(data=fa2_total_time_dict, color="orange", label="llama2-FA2")

    plt.ylabel("Average inference time (s)")
    plt.xlabel("Batch size")
    plt.title("Comparing average inference time between native model vs Flash Attention-2 model - ran on NVIDIA A100", fontsize=8)
    plt.suptitle(f"Sequence length {max_seq_len} | Use generate {bench_generate} | Use padding {use_padding} - ", fontsize=8)
    plt.legend()

    # save plot
    plt.savefig(os.path.join(dir_name, "timing_plot.jpg"), dpi=300)

    plt.figure()
    sns.set(style="darkgrid")
    # plot both lines
    sns.lineplot(data=forward_speedups, color="orange", label="forward-speedup")
    if bench_backward:
        sns.lineplot(data=backward_speedups, color="blue", label="backward-speedup")

    plt.ylabel("Speedup (x)")
    plt.xlabel("Batch size")
    plt.title("Comparing forward/backward speedup between native model vs Flash Attention-2 model - ran on NVIDIA A100", fontsize=8)
    plt.suptitle(f"Sequence length {max_seq_len} | Use generate {bench_generate} | Use padding {use_padding} - ", fontsize=8)
    plt.legend()

    # save plot
    plt.savefig(os.path.join(dir_name, "speedup_plot.jpg"), dpi=300)
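A note on the loading API: recent transformers releases deprecate the use_flash_attention_2 kwarg used above in favor of attn_implementation, which is what the modified script in the comments below uses. A minimal sketch of the equivalent loading call, assuming transformers >= 4.36 and the flash-attn package installed:

# Equivalent FlashAttention-2 loading on recent transformers versions
model_fa = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={"": 0},
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)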
@njbrake

njbrake commented Sep 4, 2024

[screenshot: benchmark plot of average inference time vs. sequence length, native vs. FA2]

Hi! I thought I would pitch in here with a comment: I made a few modifications to the script in order to measure how performance changes as the sequence length increases. Somewhat confusingly, I'm not seeing much of a difference between Flash Attention and native, and I'm curious if you've seen similar results?

import torch
import os
import argparse
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import seaborn as sns


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-batches",
        type=int,
        default=1,
        help="",
    )
    parser.add_argument(
        "--max-seqlen",
        type=int,
        default=512,
        help="",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="",
    )
    return parser


model_id = "meta-llama/Meta-Llama-3.1-8B"


@torch.no_grad()
def warmup_and_benchmark(
    model,
    tokenizer,
    max_seq_len,
    num_batches,
    max_new_tokens,
):
    inputs = tokenizer("Hi" * max_seq_len, return_tensors="pt").to("cuda")

    # warmup
    _ = model.generate(
        **inputs,
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
    )

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    with torch.no_grad():
        start_event.record()
        for _ in range(num_batches):
            _ = model.generate(
                **inputs,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                use_cache=False,
            )
        end_event.record()
        torch.cuda.synchronize()

    forward_timing = (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches

    return forward_timing


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    num_batches = args.num_batches
    max_seq_len = args.max_seqlen
    max_new_tokens = args.max_new_tokens

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to("cuda")

    model_fa = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map={"": 0},
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    native_total_time_dict = {}
    fa2_total_time_dict = {}
    forward_speedups = {}
    for max_seq_len in [256, 1024, 2048, 4096, 8192]:
        print(f"Running for sequence length {max_seq_len}")
        native_timing = warmup_and_benchmark(
            model,
            tokenizer,
            max_seq_len,
            num_batches,
            max_new_tokens,
        )
        native_total_time_dict[f"{max_seq_len}"] = native_timing

        fa2_timing = warmup_and_benchmark(
            model_fa,
            tokenizer,
            max_seq_len,
            num_batches,
            max_new_tokens,
        )
        fa2_total_time_dict[f"{max_seq_len}"] = fa2_timing

        forward_speedups[f"{max_seq_len}"] = native_timing / fa2_timing

    dir_name = f"flash-attn-2-benchmarks/{model_id}/seq_len_{max_seq_len}/"
    os.makedirs(dir_name, exist_ok=True)

    sns.set(style="darkgrid")
    # plot both lines
    sns.lineplot(data=native_total_time_dict, color="blue", label=f"{model_id}-native")
    sns.lineplot(data=fa2_total_time_dict, color="orange", label=f"{model_id}-FA2")

    plt.ylabel("Average inference time (s)")
    plt.xlabel("Seq Length")
    plt.title(
        "Comparing average inference time between native model vs Flash Attention-2 model",
        fontsize=8,
    )

    plt.legend()

    # save plot
    plt.savefig(os.path.join(dir_name, "timing_plot.jpg"), dpi=300)

    # plt.figure()
    # sns.set(style="darkgrid")
    # # plot both lines
    # sns.lineplot(data=forward_speedups, color="orange", label="forward-speedup")

@njbrake

njbrake commented Sep 4, 2024

Here's an interesting one: if I change the code to set use_cache=True, the difference between FA and native becomes clear as the sequence length increases. I'm using a single NVIDIA L40S.

[screenshot: average inference time vs. sequence length with use_cache=True; FA2 pulls ahead at longer sequence lengths]
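For reference, the change described here amounts to enabling the KV cache in the benchmarked generate call; a minimal sketch of the modified call, reusing the variable names from the script above:

# inside the timed loop of warmup_and_benchmark, with the KV cache enabled
_ = model.generate(
    **inputs,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=max_new_tokens,
    use_cache=True,  # was use_cache=False in the script above
)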

accelerate==0.34.0
bitsandbytes==0.43.3
flash-attn==2.6.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.6.68
nvidia-nvtx-cu12==12.1.105
peft==0.12.0
tokenizers==0.19.1
torch==2.2.0
transformers==4.44.2
triton==2.2.0
