@rasbt
Created March 13, 2025 15:51
Compute memory usage for various LLMs like Gemma, Llama, etc.
# Sebastian Raschka 2025
#
#
# Usage:
# python llama-vs-gemma.py \
#   --auth_token hf_... \
#   --model_name meta-llama/Llama-3.2-1B \
#   --prompt medium
import argparse
from packaging import version
import requests
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM


def load_text():
    # Download a public-domain short story ("The Verdict") to use as prompt text
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch02/01_main-chapter-code/the-verdict.txt"
    response = requests.get(url)
    text = response.text
    return text


def print_gpu_memory_usage(stage):
    # Report currently allocated, peak allocated, and reserved GPU memory in MB
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(0) / (1024 ** 2)
        reserved = torch.cuda.memory_reserved(0) / (1024 ** 2)
        peak_memory = torch.cuda.max_memory_allocated(0) / (1024 ** 2)
        print(f"{stage} - GPU allocated: {allocated:.2f} MB, GPU Max allocated: {peak_memory:.2f} MB, reserved: {reserved:.2f} MB")
    else:
        print("CUDA is not available")


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print_gpu_memory_usage("Before loading model")

    # Load the tokenizer and model weights in bfloat16 (2 bytes per parameter)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, token=args.auth_token)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        token=args.auth_token,
        torch_dtype=torch.bfloat16
    ).to(device)
    print_gpu_memory_usage("After loading model")

    # Build a short, medium, or long prompt, or use the user-provided text directly
    if args.prompt == "short":
        prompt = "Hello world"
    elif args.prompt == "medium":
        prompt = load_text() * 3
    elif args.prompt == "long":
        prompt = load_text() * 8
    else:
        prompt = args.prompt

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    print("Prompt length:", len(inputs.input_ids[0]))

    # Gemma 3 support was added in transformers 4.50, so fail early on older versions
    if "gemma-3" in args.model_name:
        if version.parse(transformers.__version__) < version.parse("4.50.0.dev0"):
            raise RuntimeError(
                f"Detected transformers {transformers.__version__} "
                f"but requires transformers >= 4.50.0"
            )

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=50)
    print_gpu_memory_usage("After generation")

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Generated output:", result[-100:])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Load an LLM, generate text, and measure GPU memory usage."
    )
    parser.add_argument("--auth_token", type=str, required=True,
                        help="Hugging Face access token")
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-1B",
                        help="Model name to load from Hugging Face")
    parser.add_argument("--prompt", type=str, default="Hello world",
                        help="Prompt to generate from: 'short', 'medium', 'long', or custom text")
    args = parser.parse_args()
    main(args)
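
Example usage (a sketch: the Gemma 3 model ID below is an assumption, and it requires transformers >= 4.50 plus approved access to the gated repositories on Hugging Face):

    python llama-vs-gemma.py --auth_token hf_... --model_name meta-llama/Llama-3.2-1B --prompt medium
    python llama-vs-gemma.py --auth_token hf_... --model_name google/gemma-3-1b-it --prompt medium

Each run prints the allocated, peak, and reserved GPU memory before loading the model, after loading it, and after generating 50 new tokens, so the memory footprints of the two models can be compared side by side.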