single model train benchmark
import time
from collections import defaultdict
import os, os.path as osp
import json

import torch
from torch import nn

import transformers
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
from datasets import load_dataset
from accelerate import init_empty_weights
from tqdm import tqdm
def print_gpu_utilization():
    # Sum the memory currently in use across all visible GPUs via NVML.
    from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit

    nvmlInit()
    n = torch.cuda.device_count()
    used = 0
    for index in range(n):
        handle = nvmlDeviceGetHandleByIndex(index)
        info = nvmlDeviceGetMemoryInfo(handle)
        used += info.used
        # print(index, used)
    print(f"GPU memory occupied: {used // 1024**2} MB.")
    return used // 1024**2
def benchmark_v2(data, model: nn.Module, warmups=1, totalruns=5, fn=lambda x: x):
    # Stage 1: forward only, without gradient tracking.
    with torch.no_grad():
        for i in range(warmups):
            out = model(data)
        torch.cuda.synchronize()
        s_time = time.time()
        for i in range(totalruns):
            out = model(data)
        torch.cuda.synchronize()
        e_time = time.time()
    mem = print_gpu_utilization()
    print(f"fwd time: {((e_time - s_time) / totalruns * 1000):.2f} ms")

    # Stage 2: forward + backward.
    for i in range(warmups):
        out = model(data)
        fn(out).sum().backward()
    torch.cuda.synchronize()
    s_time = time.time()
    for i in range(totalruns):
        out = model(data)
        fn(out).sum().backward()
    torch.cuda.synchronize()
    e_time = time.time()
    mem = print_gpu_utilization()
    print(f"fwd + bwd time: {((e_time - s_time) / totalruns * 1000):.2f} ms")

    # Stage 3: forward + backward + AdamW step over the trainable parameters.
    optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad])
    for i in range(warmups):
        optim.zero_grad()
        out = model(data)
        fn(out).sum().backward()
        optim.step()
    torch.cuda.synchronize()
    s_time = time.time()
    for i in range(totalruns):
        optim.zero_grad()
        out = model(data)
        fn(out).sum().backward()
        optim.step()
    torch.cuda.synchronize()
    e_time = time.time()
    mem = print_gpu_utilization()
    print(f"fwd + bwd + optim time: {((e_time - s_time) / totalruns * 1000):.2f} ms")
    return
def benchmark(data, model: nn.Module, warmups=10, totalruns=20, no_optim=False):
    # With no_optim=True, skip the optimizer entirely and time fwd + bwd only;
    # constructing AdamW without a parameter list (as the original did) would raise.
    if no_optim:
        optim = None
    else:
        optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad])
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    for i in range(warmups):
        if optim is not None:
            optim.zero_grad()
        out = model(data)
        out.logits.sum().backward()
        if optim is not None:
            optim.step()
    torch.cuda.synchronize()

    s_time = time.time()
    for i in range(totalruns):
        if optim is not None:
            optim.zero_grad()
        out = model(data)
        out.logits.sum().backward()
        if optim is not None:
            optim.step()
    torch.cuda.synchronize()
    e_time = time.time()

    latency = (e_time - s_time) / totalruns
    print(f"time: {(latency * 1000):.2f} ms")
    mem = print_gpu_utilization()
    print("Optim params (B):", total_params / (1000**3))
    return mem / 1024, latency * 1000, total_params
def main(
    use_flash=True, model_name="NousResearch/Llama-2-7b-hf", bs=1, seqlen=512, log_wandb=False
):
    device = "cuda"
    print("loading model", model_name, f"{bs}x{seqlen}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="cuda",
        torch_dtype=torch.bfloat16,
        # honor the use_flash flag (it was previously ignored)
        attn_implementation="flash_attention_2" if use_flash else "eager",
    )  # .to(device)
    model.config.use_cache = False  # disable the KV cache for training
    print(bs, seqlen)
    data = torch.randint(0, 5, (bs, seqlen)).to(device)
    # print("--" * 20, "LoRA", "--" * 20)
    # for n, p in model.named_parameters():
    #     p.requires_grad = False
    #     if "k_proj" in n or "v_proj" in n:
    #         p.requires_grad = True
    benchmark_v2(data, model, fn=lambda x: x.logits)
if __name__ == "__main__":
    import fire

    fire.Fire(main)
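
As a quick usage sketch (the module name train_benchmark is hypothetical; the gist itself does not name the file): fire.Fire(main) exposes main's keyword arguments as command-line flags when the script is run directly, and the benchmark can also be driven from another Python script:

from train_benchmark import main  # hypothetical filename for this gist

# Benchmark Llama-2-7B with flash attention at batch size 2 and sequence length 1024.
main(use_flash=True, model_name="NousResearch/Llama-2-7b-hf", bs=2, seqlen=1024)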