Compute memory usage for various LLMs like Gemma, Llama, etc.
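The script loads a Hugging Face causal LM in bfloat16, runs a short 50-token generation, and prints allocated, peak, and reserved CUDA memory before loading the model, after loading it, and after generation. To compare against a Gemma model, a second invocation might look like the following (the hf_... token is a placeholder and the model identifier is an assumption; any gemma-3 checkpoint you have access to should work):

python llama-vs-gemma.py \
  --auth_token hf_... \
  --model_name google/gemma-3-1b-it \
  --prompt medium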
# Sebastian Raschka 2025
#
#
# Usage:
# python llama-vs-gemma.py \
#   --auth_token hf_... \
#   --model_name meta-llama/Llama-3.2-1B \
#   --prompt medium
import argparse

from packaging import version
import requests
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM


def load_text():
    # Download a public-domain short story to use as a long, repeatable prompt.
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch02/01_main-chapter-code/the-verdict.txt"
    response = requests.get(url)
    response.raise_for_status()
    return response.text


def print_gpu_memory_usage(stage):
    # Report currently allocated, peak allocated, and reserved CUDA memory in MB.
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(0) / (1024 ** 2)
        reserved = torch.cuda.memory_reserved(0) / (1024 ** 2)
        peak_memory = torch.cuda.max_memory_allocated(0) / (1024 ** 2)
        print(f"{stage} - GPU allocated: {allocated:.2f} MB, "
              f"GPU max allocated: {peak_memory:.2f} MB, reserved: {reserved:.2f} MB")
    else:
        print("CUDA is not available")


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print_gpu_memory_usage("Before loading model")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, token=args.auth_token)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        token=args.auth_token,
        torch_dtype=torch.bfloat16
    ).to(device)
    print_gpu_memory_usage("After loading model")

    # Preset prompts: "medium" and "long" repeat the downloaded short story
    # to produce progressively longer contexts.
    if args.prompt == "short":
        prompt = "Hello world"
    elif args.prompt == "medium":
        prompt = load_text() * 3
    elif args.prompt == "long":
        prompt = load_text() * 8
    else:
        prompt = args.prompt

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    print("Prompt length:", len(inputs.input_ids[0]))

    # Gemma 3 support requires a sufficiently recent transformers release.
    if "gemma-3" in args.model_name:
        if version.parse(transformers.__version__) < version.parse("4.50.0.dev0"):
            raise RuntimeError(
                f"Detected transformers {transformers.__version__} "
                f"but requires transformers >= 4.50.0"
            )

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=50)
    print_gpu_memory_usage("After generation")

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Generated output:", result[-100:])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Load a Hugging Face causal LM, generate text, and measure GPU memory usage."
    )
    parser.add_argument("--auth_token", type=str, required=True,
                        help="Hugging Face access token")
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-1B",
                        help="Model name to load from Hugging Face")
    parser.add_argument("--prompt", type=str, default="Hello world",
                        help="Prompt text, or one of the presets 'short', 'medium', 'long'")
    args = parser.parse_args()

    main(args)
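As a rough sanity check on the "After loading model" numbers, the bfloat16 weights alone take about 2 bytes per parameter; the KV cache and activations during generation come on top of that. A minimal sketch of that back-of-envelope estimate (the parameter counts below are approximations, not values read from the model configs):

def weight_memory_mb(num_params, bytes_per_param=2):
    # bfloat16 stores each weight in 2 bytes.
    return num_params * bytes_per_param / (1024 ** 2)

for name, params in [
    ("meta-llama/Llama-3.2-1B", 1.24e9),  # roughly 1.24B parameters
    ("google/gemma-3-1b-it", 1.0e9),      # roughly 1B parameters
]:
    print(f"{name}: ~{weight_memory_mb(params):,.0f} MB of bfloat16 weights")

Note that the reserved figure reported by torch.cuda.memory_reserved is usually higher than the allocated figure, because PyTorch's caching allocator holds on to freed blocks for reuse.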