GRPO benchmarking
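A small script for benchmarking the memory and latency profile of the GRPO generation step: it loads Qwen2.5-1.5B in bfloat16 with FlashAttention-2, generates 8 completions of a single math prompt with greedy decoding, runs one forward pass over the generated sequences (as the training step would), and reports wall-clock time plus VRAM usage at each stage.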
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, set_seed

set_seed(0)

device = "cuda"

# Load the model in bfloat16 with FlashAttention-2
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B",
    attn_implementation="flash_attention_2",
    torch_dtype="bfloat16",
).to(device)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")

# Greedy decoding, up to 1024 new tokens per completion
generation_config = GenerationConfig(
    max_new_tokens=1024,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
)

# Single prompt; it will be replicated across the batch below
prompt = r"Let \( a, b, c \) be positive real numbers. Prove that $$ \frac{1}{a(1+b)}+\frac{1}{b(1+c)}+\frac{1}{c(1+a)} \geq \frac{3}{1+abc}, $$ and that equality occurs if and only if \( a = b = c = 1 \)."
processed = tokenizer(prompt, return_tensors="pt")
input_ids = processed["input_ids"].to(device)
attention_mask = processed["attention_mask"].to(device)
# Function to get current VRAM usage
def get_vram_usage():
    return torch.cuda.memory_allocated(device) / 1024**2  # Convert to MB

# Record baseline VRAM usage
torch.cuda.reset_peak_memory_stats(device)
baseline_vram = get_vram_usage()

start = time.perf_counter()

# __ Critical section __
# Replicate the prompt 8 times
input_ids = input_ids.expand(8, -1)
attention_mask = attention_mask.expand(8, -1)

# Record VRAM before generation
vram_before = get_vram_usage()

# Generate completions
output = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config)

# Record VRAM after generation
vram_after_generate = get_vram_usage()

# Forward pass over the full sequences (builds the graph for the gradients, as a training step would)
attention_mask = (output != tokenizer.pad_token_id).int()  # attend only to non-padding tokens
_ = model(output, attention_mask=attention_mask)
# __ End critical section __

perf = time.perf_counter() - start

# Get peak memory usage
peak_vram = torch.cuda.max_memory_allocated(device) / 1024**2  # Convert to MB

# Print results
print(tokenizer.decode(output[0], skip_special_tokens=True))
print(f"Execution time: {perf:.2f} seconds")
print(f"Baseline VRAM: {baseline_vram:.2f} MB")
print(f"VRAM before generation: {vram_before:.2f} MB")
print(f"VRAM after generation: {vram_after_generate:.2f} MB")
print(f"VRAM after forward pass: {get_vram_usage():.2f} MB")
print(f"Peak VRAM usage: {peak_vram:.2f} MB")