memory allocations breakdown
{
"cells": [
{
"cell_type": "raw",
"id": "8e3cebc7-369d-4cf0-b9bf-63555f042bb2",
"metadata": {},
"source": [
"pip install transformers nvidia-ml-py3 einops ipyexperiments"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "43d0272a-78b0-48ac-b6d1-e7b57dc01650",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import pynvml\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from ipyexperiments import IPyExperimentsPytorch\n",
"import gc\n",
"import os\n",
"os.environ['CUDA_MODULE_LOADING'] = 'EAGER' # force kernel preloading"
]
},
{
"cell_type": "markdown",
"id": "ef691ce6-276d-4d79-855a-58ad721b5af0",
"metadata": {},
"source": [
"# Run parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "00cb2ead-719b-4e17-8558-3c1ae4bb0d3f",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\")\n",
"model_name_or_path = \"microsoft/phi-1_5\" # microsoft/phi-1_5, microsoft/phi-2, NousResearch/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1, gpt2, gpt2-medium, gpt2-large, gpt2-xl\n",
"dtype = torch.float32\n",
"mixed_precision_training = True\n",
"bs = 2\n",
"seq_length = 128\n",
"get_optimizer = lambda parameters: torch.optim.SGD(parameters, lr=0.1, momentum=0.9) # SGD(parameters, lr=0.1), SGD(parameters, lr=0.1, momentum=0.9), AdamW(parameters, lr=0.1)\n",
"\n",
"if mixed_precision_training:\n",
"    assert dtype == torch.float32"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3d48830c-cd59-489b-b500-459eb647c1cd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA kernels VRAM: 955 MiB\n",
"\n",
"*** Experiment started with the Pytorch backend\n",
"Device: ID 0, NVIDIA A100 80GB PCIe (81920 RAM)\n",
"\n",
"\n",
"*** Current state:\n",
"RAM: Used Free Total Util\n",
"CPU: 3,106 85,241 128,649 MB 2.41% \n",
"GPU: 1,885 80,034 81,920 MB 2.30% \n",
"\n",
"\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n",
"・ CPU: 0 0 3,106 MB |\n",
"・ GPU: 0 0 1,885 MB |\n"
]
}
],
"source": [
"n_bytes_per_param = 2 if dtype in (torch.float16, torch.bfloat16) else 4\n",
"\n",
"pynvml.nvmlInit()\n",
"handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
"get_vram = lambda: pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20 # MiB\n",
"\n",
"start_vram = get_vram()\n",
"\n",
"# Initializing CUDA kernels\n",
"a = torch.ones((1,1)).to(device); del a; torch.cuda.empty_cache()\n",
"cuda_kernels_vram = get_vram() - start_vram\n",
"print(f\"CUDA kernels VRAM: {cuda_kernels_vram:.0f} MiB\")\n",
"\n",
"exp = IPyExperimentsPytorch()"
]
},
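{
"cell_type": "markdown",
"id": "1f2e3d4c-0a1b-4c2d-8e3f-a0b1c2d3e4f5",
"metadata": {},
"source": [
"Side note: `pynvml` reports the whole process footprint (CUDA context + preloaded kernels + whatever the PyTorch caching allocator holds), while `torch.cuda.memory_allocated()` counts only live PyTorch tensors. Comparing the two is a quick way to isolate the fixed CUDA overhead measured above. A minimal sketch, not run as part of this experiment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a3b4c5d-6e7f-4a8b-9c0d-1e2f3a4b5c6d",
"metadata": {},
"outputs": [],
"source": [
"# process-wide VRAM (nvml) vs tensors tracked by the PyTorch caching allocator\n",
"nvml_used_mib = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20\n",
"torch_tensors_mib = torch.cuda.memory_allocated() / 2**20\n",
"print(f\"nvml: {nvml_used_mib:.0f} MiB, torch tensors: {torch_tensors_mib:.0f} MiB, fixed overhead ~{nvml_used_mib - torch_tensors_mib:.0f} MiB\")"
]
},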
{
"cell_type": "markdown",
"id": "1689d757-e854-45b8-a35d-3e6e31994b83",
"metadata": {},
"source": [
"# Loading model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "56bcc214-a1f6-43f5-836c-157be2afd2de",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.304\n",
"・ CPU: 36 9 3,143 MB |\n",
"・ GPU: 0 0 1,885 MB |\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
"if tokenizer.pad_token is None:\n",
"    tokenizer.pad_token = tokenizer.eos_token"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "90d8cc03-2889-4ee0-9869-9d932bd86ac1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PhiConfig {\n",
"  \"_name_or_path\": \"microsoft/phi-1_5\",\n",
"  \"activation_function\": \"gelu_new\",\n",
"  \"architectures\": [\n",
"    \"PhiForCausalLM\"\n",
"  ],\n",
"  \"attn_pdrop\": 0.0,\n",
"  \"auto_map\": {\n",
"    \"AutoConfig\": \"microsoft/phi-1_5--configuration_phi.PhiConfig\",\n",
"    \"AutoModelForCausalLM\": \"microsoft/phi-1_5--modeling_phi.PhiForCausalLM\"\n",
"  },\n",
"  \"embd_pdrop\": 0.0,\n",
"  \"flash_attn\": false,\n",
"  \"flash_rotary\": false,\n",
"  \"fused_dense\": false,\n",
"  \"initializer_range\": 0.02,\n",
"  \"layer_norm_epsilon\": 1e-05,\n",
"  \"model_type\": \"phi-msft\",\n",
"  \"n_embd\": 2048,\n",
"  \"n_head\": 32,\n",
"  \"n_head_kv\": null,\n",
"  \"n_inner\": null,\n",
"  \"n_layer\": 24,\n",
"  \"n_positions\": 2048,\n",
"  \"resid_pdrop\": 0.0,\n",
"  \"rotary_dim\": 32,\n",
"  \"tie_word_embeddings\": false,\n",
"  \"torch_dtype\": \"float32\",\n",
"  \"transformers_version\": \"4.37.0.dev0\",\n",
"  \"use_cache\": false,\n",
"  \"vocab_size\": 51200\n",
"}\n",
"\n",
"===========================================================================\n",
"PhiForCausalLM(\n",
"  (transformer): PhiModel(\n",
"    (embd): Embedding(\n",
"      (wte): Embedding(51200, 2048)\n",
"      (drop): Dropout(p=0.0, inplace=False)\n",
"    )\n",
"    (h): ModuleList(\n",
"      (0-23): 24 x ParallelBlock(\n",
"        (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
"        (resid_dropout): Dropout(p=0.0, inplace=False)\n",
"        (mixer): MHA(\n",
"          (rotary_emb): RotaryEmbedding()\n",
"          (Wqkv): Linear(in_features=2048, out_features=6144, bias=True)\n",
"          (out_proj): Linear(in_features=2048, out_features=2048, bias=True)\n",
"          (inner_attn): SelfAttention(\n",
"            (drop): Dropout(p=0.0, inplace=False)\n",
"          )\n",
"          (inner_cross_attn): CrossAttention(\n",
"            (drop): Dropout(p=0.0, inplace=False)\n",
"          )\n",
"        )\n",
"        (mlp): MLP(\n",
"          (fc1): Linear(in_features=2048, out_features=8192, bias=True)\n",
"          (fc2): Linear(in_features=8192, out_features=2048, bias=True)\n",
"          (act): NewGELUActivation()\n",
"        )\n",
"      )\n",
"    )\n",
"  )\n",
"  (lm_head): CausalLMHead(\n",
"    (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
"    (linear): Linear(in_features=2048, out_features=51200, bias=True)\n",
"  )\n",
"  (loss): CausalLMLoss(\n",
"    (loss_fct): CrossEntropyLoss()\n",
"  )\n",
")\n",
"===========================================================================\n",
"Number of parameters: 1.418 B (1418271104)\n",
"Model VRAM usage: 5496 MiB (expected 5410 MiB, error 1.6 %)\n",
"===========================================================================\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:04.714\n",
"・ CPU: 333 7,823 3,476 MB |\n",
"・ GPU: 5,496 0 7,381 MB |\n"
]
}
],
"source": [
"model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype, trust_remote_code=True).to(device)\n",
"model.config.use_cache = False\n",
"\n",
"n_parameters = sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in model.buffers())\n",
"model_estimated_vram = n_parameters * n_bytes_per_param / 2**20\n",
"model_actual_vram = get_vram() - cuda_kernels_vram - start_vram\n",
"\n",
"#n_buffers = sum(p.numel() for p in model.buffers())\n",
"\n",
"print(model.config)\n",
"print(\"=\" * 75)\n",
"print(model)\n",
"print(\"=\" * 75)\n",
"print(f\"Number of parameters: {(n_parameters / 1e9):.3f} B ({n_parameters})\")\n",
"print(f\"Model VRAM usage: {(model_actual_vram):.0f} MiB (expected {(model_estimated_vram):.0f} MiB, error {((model_actual_vram - model_estimated_vram) * 100 / model_actual_vram):.1f} %)\")\n",
"print(\"=\" * 75)"
]
},
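{
"cell_type": "markdown",
"id": "3b4c5d6e-7f8a-4b9c-a0d1-2e3f4a5b6c7d",
"metadata": {},
"source": [
"Sanity check of the estimate above: 1,418,271,104 parameters at 4 bytes each (fp32) gives\n",
"\n",
"$$1418271104 \\times 4 \\,/\\, 2^{20} \\approx 5410 \\text{ MiB}$$\n",
"\n",
"which is within 1.6 % of the 5496 MiB actually measured; the gap is presumably allocator rounding and fragmentation."
]
},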
{
"cell_type": "code",
"execution_count": 6,
"id": "90db98fc-8de2-4c23-97b1-3bdbf780b7c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For batch of 4 items with a sequence length of 512 it will consume 0.046875 MiB VRAM\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n",
"・ CPU: 0 0 3,477 MB |\n",
"・ GPU: 0 0 7,381 MB |\n"
]
}
],
"source": [
"bs = 4\n",
"seq_length = 512\n",
"\n",
"batch_vram = 3 * bs * seq_length * 8 # 3 for input_ids, attention_masks and labels; 8 for each i64\n", | |
"print(f\"For batch of {bs} items with a sequence length of {seq_length} it will consume {batch_vram / 2**20} MiB VRAM\")" | |
] | |
}, | |
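{
"cell_type": "markdown",
"id": "4c5d6e7f-8a9b-4c0d-b1e2-3f4a5b6c7d8e",
"metadata": {},
"source": [
"Worked out explicitly: $3 \\times 4 \\times 512 \\times 8 = 49152$ bytes $= 0.046875$ MiB, so the int64 `input_ids` / `attention_mask` / `labels` tensors are negligible next to weights and activations."
]
},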
{
"cell_type": "code",
"execution_count": 7,
"id": "5762f438-77e1-4c6d-a3c9-7eaeab21dd85",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[28640, 46785, 7134, ..., 32685, 9896, 1042],\n",
"        [28685, 26733, 956, ..., 28010, 12865, 29406],\n",
"        [19038, 45183, 9541, ..., 14378, 25289, 32570],\n",
"        [45000, 35482, 9371, ..., 11262, 33852, 2560]], device='cuda:0')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"tensor([[1., 1., 1., ..., 1., 1., 1.],\n",
"        [1., 1., 1., ..., 1., 1., 1.],\n",
"        [1., 1., 1., ..., 1., 1., 1.],\n",
"        [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.036\n",
"・ CPU: 13 0 3,490 MB |\n",
"・ GPU: 0 0 7,381 MB |\n"
]
}
],
"source": [
"input_ids = torch.randint(0, len(tokenizer), (bs, seq_length)).to(device)\n",
"attention_mask = torch.ones((bs, seq_length)).to(device)\n",
"input_ids\n",
"attention_mask"
]
},
{
"cell_type": "markdown",
"id": "0b004db1-7540-47c7-a498-aa8f43b910b6",
"metadata": {},
"source": [
"# Warmup Inference forward pass"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "41403d8a-7934-4294-b5ba-a50a6d5b69f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:01.220\n",
"・ CPU: 1,865 0 5,355 MB |\n",
"・ GPU: 394 0 7,775 MB |\n"
]
}
],
"source": [
"# warmup - could possibly load some modules / allocate structures - this is your missing eps_ram\n", | |
"_ = model.eval()\n", | |
"input_ids_1 = torch.randint(0, len(tokenizer), (1, 1)).to(device)\n", | |
"attention_mask_1 = torch.ones((1, 1)).to(device)\n", | |
"with torch.no_grad():\n", | |
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n", | |
" # probs = F.softmax(out.logits[:, -1, :], dim=-1) # for inference we need probabilities only over the last token; omit this as it is very small\n", | |
" del out" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8dc419ea-cacd-4709-8a05-942c06d4887f", | |
"metadata": {}, | |
"source": [ | |
"# Real Inference forward pass" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "c701ccd7-3b5a-42e1-a7e3-2de4a78eeb66", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.063\n", | |
"・ CPU: 0 0 5,355 MB |\n", | |
"・ GPU: 400 566 8,175 MB |\n" | |
] | |
} | |
], | |
"source": [ | |
"# real run\n", | |
"_ = model.eval()\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" out = model(input_ids=input_ids, attention_mask=attention_mask)\n", | |
" # probs = F.softmax(out.logits[:, -1, :], dim=-1) # for inference we need probabilities only over the last token; omit this as it is very small" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "75bbcef2-4b7b-4a21-b6c3-fe5206ee3ec5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Out tensor dtype: torch.float32\n", | |
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n", | |
"・ CPU: 0 0 5,355 MB |\n", | |
"・ GPU: -400 0 7,775 MB |\n" | |
] | |
} | |
], | |
"source": [ | |
"out_bs, out_sequence_length, out_embedding_size = out.logits.shape\n", | |
"n_bytes_per_param_out = 2 if out.logits.dtype in (torch.float16, torch.bfloat16) else 4\n", | |
"output_estimated_vram = out_bs * out_sequence_length * out_embedding_size * n_bytes_per_param_out / 2**20\n", | |
"print(f\"Out tensor dtype: {out.logits.dtype}\")\n", | |
"del out; torch.cuda.empty_cache() # calling `free` on allocated memory for `out` tensor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "2481fa80-b294-4c48-906b-5b4631670ff2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total forward pass VRAM usage: 394 MiB\n", | |
"Output tensor with bs 4, seq length 512 and emb size 51200 VRAM usage: 0 MiB (expected 400 MiB)\n", | |
"Activations VRAM usage: 0 MiB\n", | |
"Random eps VRAM: 394 MiB\n", | |
"===========================================================================\n", | |
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.009\n", | |
"・ CPU: 0 0 5,355 MB |\n", | |
"・ GPU: 0 0 7,775 MB |\n" | |
] | |
} | |
], | |
"source": [ | |
"total_forward_pass_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram\n", | |
"torch.cuda.empty_cache() # calling `free` on allocated memory for forward pass\n", | |
"output_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram\n", | |
"\n", | |
"eps_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram # idk what is that, but it is small\n", | |
"\n", | |
"output_actual_vram = output_vram - eps_vram\n", | |
"activations_actual_vram = total_forward_pass_vram - output_actual_vram - eps_vram\n", | |
"\n", | |
"print(f\"Total forward pass VRAM usage: {total_forward_pass_vram:.0f} MiB\")\n", | |
"print(f\"Output tensor with bs {out_bs}, seq length {out_sequence_length} and emb size {out_embedding_size} VRAM usage: {output_actual_vram:.0f} MiB (expected {output_estimated_vram:.0f} MiB)\")\n", | |
"print(f\"Activations VRAM usage: {activations_actual_vram:.0f} MiB\")\n", | |
"print(f\"Random eps VRAM: {eps_vram:.0f} MiB\")\n", | |
"#print(torch.cuda.memory_summary())\n", | |
"print(\"=\" * 75)" | |
] | |
}, | |
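{
"cell_type": "markdown",
"id": "5d6e7f8a-9b0c-4d1e-a2f3-4a5b6c7d8e9f",
"metadata": {},
"source": [
"Where the expected 400 MiB for the output tensor comes from: the logits have shape `(bs, seq_length, vocab_size)` in fp32, so\n",
"\n",
"$$4 \\times 512 \\times 51200 \\times 4 = 419430400 \\text{ bytes} = 400 \\text{ MiB}$$"
]
},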
{
"cell_type": "markdown",
"id": "fd1f8bac-c9a4-42b5-bad8-3f5e41fba771",
"metadata": {},
"source": [
"# Warm up Training step" | |
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "43fe4164-5c23-43e9-93e2-df08e6c09927",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.013\n",
"・ CPU: 0 0 5,355 MB |\n",
"・ GPU: 0 0 7,775 MB |\n"
]
}
],
"source": [
"# warmup\n",
"_ = model.train()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "df65b3fe-4ffb-4def-a139-a21bf65dcf43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.065\n",
"・ CPU: 2 0 5,358 MB |\n",
"・ GPU: 2 2,556 7,777 MB |\n"
]
}
],
"source": [
"# check forward - we already run fwd during inference - so expecting no additional memory allocated \n", | |
"with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=mixed_precision_training):\n", | |
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n", | |
" probs = F.softmax(out.logits, dim=-1)\n", | |
" loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n", | |
"\n", | |
"del out\n", | |
"del probs\n", | |
"del loss\n", | |
"\n", | |
"# no leaks here" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "36c4e5a6-f830-4a20-b43b-5f634a061b52", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n", | |
"・ CPU: 0 0 5,358 MB |\n", | |
"・ GPU: 0 0 7,777 MB |\n" | |
] | |
} | |
], | |
"source": [ | |
"optimizer = get_optimizer(model.parameters())\n", | |
"scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision_training)\n", | |
"del scaler\n", | |
"del optimizer\n", | |
"\n", | |
"# no leaks here\n", | |
"# and it didn't even allocate any memory for either object" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "a72252fb-b2ef-4e57-affd-c8e99aa57272", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"grads memory: 5410.275 MB\n", | |
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.130\n", | |
"・ CPU: 3 0 5,362 MB |\n", | |
"・ GPU: 5,636 1,848 13,413 MB |\n" | |
] | |
} | |
], | |
"source": [ | |
"# now running backward for the first time - so expects the grads memory allocation to occur\n", | |
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n", | |
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n", | |
" probs = F.softmax(out.logits, dim=-1)\n", | |
" loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n", | |
"loss.backward()\n", | |
"\n", | |
"del out\n", | |
"del probs\n", | |
"del loss\n", | |
"\n", | |
"# backward manifested grads here \n", | |
"grads_estimated_vram = n_parameters * n_bytes_per_param / 2**20\n", | |
"\n", | |
"print(f\"grads memory: {grads_estimated_vram:.3f} MB\")" | |
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "33c391dc-cfa8-4c20-b0cc-49692f7c9073",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.098\n",
"・ CPU: 0 0 5,362 MB |\n",
"・ GPU: 2 2,948 13,415 MB |\n"
]
}
],
"source": [
"# now we expect the optimizer states to be manifested\n",
"optimizer = get_optimizer(model.parameters())\n",
"\n",
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n",
"    out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n",
"    probs = F.softmax(out.logits, dim=-1)\n",
"    loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n",
"loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "a4e449a0-7ea0-488e-965d-faaa5e947844",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"optim_states estimated memory: 5410.275 MB\n", | |
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.076\n", | |
"・ CPU: 0 0 5,363 MB |\n", | |
"・ GPU: 5,374 2 18,789 MB |\n" | |
] | |
} | |
], | |
"source": [ | |
"optimizer.step() # we can see here the optim states get allocated only on the first step()\n", | |
"\n", | |
"del out\n", | |
"del probs\n", | |
"del loss\n", | |
"\n", | |
"optim_states_estimated_vram = n_parameters * n_bytes_per_param / 2**20\n", | |
"\n", | |
"print(f\"optim_states estimated memory: {optim_states_estimated_vram:.3f} MB\")" | |
]
},
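{
"cell_type": "markdown",
"id": "6e7f8a9b-0c1d-4e2f-a3b4-5c6d7e8f9a0b",
"metadata": {},
"source": [
"How big the optimizer states are depends on the optimizer picked in the Run parameters cell. In stock PyTorch: plain SGD keeps no per-parameter state, SGD with momentum keeps one `momentum_buffer` per parameter (same dtype as the parameter, so fp32 here), and AdamW keeps two fp32 buffers (`exp_avg` and `exp_avg_sq`). A minimal sketch of the expected sizes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f8a9b0c-1d2e-4f3a-b4c5-6d7e8f9a0b1c",
"metadata": {},
"outputs": [],
"source": [
"# expected optimizer-state VRAM for the optimizers listed in the Run parameters cell\n",
"state_bytes_per_param = {\"SGD\": 0, \"SGD + momentum\": 4, \"AdamW\": 8}\n",
"for name, b in state_bytes_per_param.items():\n",
"    print(f\"{name}: {n_parameters * b / 2**20:,.0f} MiB\")"
]
},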
{
"cell_type": "code",
"execution_count": 18,
"id": "68ef439d-6f4b-42df-ac87-8537b850cbe0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: -5,404 0 13,385 MB |\n"
]
}
],
"source": [
"# free grads\n",
"optimizer.zero_grad(set_to_none=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3675c67e-1411-4fb8-a361-28c61923eedb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: -5,578 0 7,807 MB |\n"
]
}
],
"source": [
"# free optim states\n",
"del optimizer"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "75894651-a847-4fc5-84de-f0a871e28e30",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"grads + optim_states estimated_vram: 10820.5 MB\n", | |
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n", | |
"・ CPU: 0 0 5,363 MB |\n", | |
"・ GPU: 0 0 7,807 MB |\n" | |
] | |
} | |
], | |
"source": [ | |
"# optim states + grads memory\n", | |
"grads_n_optim_states_vram = grads_estimated_vram + optim_states_estimated_vram \n", | |
"print(f\"grads + optim_states estimated_vram: {grads_n_optim_states_vram:.1f} MB\")" | |
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "d4237cf4-1a48-4e43-be99-c3018f98429c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.192\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 12,478 7,807 MB |\n"
]
}
],
"source": [
"# full warmup train step with reset - check that allocated memory before and after is the same\n",
"_ = model.train()\n",
"optimizer = get_optimizer(model.parameters())\n",
"\n",
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n",
"    out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n",
"    probs = F.softmax(out.logits, dim=-1)\n",
"    loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n",
"loss.backward()\n",
"optimizer.step()\n",
"optimizer.zero_grad(set_to_none=True)\n",
"\n",
"del out\n",
"del probs\n",
"del loss\n",
"del optimizer\n",
"\n",
"# from peak memory delta we can see that optim states + grads that were allocated and then freed - do checkout - the rest of the peak memory is activations memory" | |
]
},
{
"cell_type": "markdown",
"id": "3cacc100-5376-43e5-85ea-ef938f600c8f",
"metadata": {},
"source": [
"# Real Training step"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "ca1837d7-5e00-480a-99f0-0fa0cd88e700",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model gradients type: torch.float32\n",
"Total train forward pass VRAM usage (activations + output tensor): 15930 MiB (expect 2705 MiB of these to be for fp16 weights copy)\n",
"Activations VRAM usage: 12825 MiB\n",
"Gradients VRAM usage: 4864 MiB (model weights were 5496 MiB)\n",
"Actual optimizer states VRAM usage: 6312 MiB\n",
"Random eps VRAM usage: 426 MiB\n",
"===========================================================================\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.601\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 17,898 7,807 MB |\n"
]
}
],
"source": [
"_ = model.train()\n",
"optimizer = get_optimizer(model.parameters())\n",
"\n",
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n",
"    out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
"    total_train_forward_pass_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"\n",
"    probs = F.softmax(out.logits, dim=-1)\n",
"    probs_vram = get_vram() - total_train_forward_pass_vram - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"\n",
"    loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids) # mapping tokens into themselves\n",
"    loss_calculation_vram = get_vram() - probs_vram - total_train_forward_pass_vram - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"loss.backward()\n",
"optimizer.step()\n",
"\n",
"backward_vram = get_vram() - loss_calculation_vram - probs_vram - total_train_forward_pass_vram - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"\n",
"print(f\"Model gradients type: {next(model.parameters()).grad.dtype}\")\n",
"print(f\"Total train forward pass VRAM usage (activations + output tensor): {total_train_forward_pass_vram:.0f} MiB\" + (f\" (expect {(n_parameters * 2 / 2**20):.0f} MiB of these to be for fp16 weights copy)\" if mixed_precision_training else \"\"))\n",
"print(f\"Activations VRAM usage: {(total_train_forward_pass_vram - (n_parameters * 2 / 2**20 if mixed_precision_training else 0) - output_estimated_vram):.0f} MiB\")\n",
"#print(f\"Actual probs tensor VRAM usage: {probs_vram:.0f} MiB\")\n",
"#print(f\"Loss calculation VRAM usage: {loss_calculation_vram:.0f} MiB\")\n",
"#print(f\"Backward calculation VRAM usage: {backward_vram:.0f} MiB\")\n",
"\n",
"del out\n",
"del probs\n",
"del loss\n",
"torch.cuda.empty_cache() # calling `free` on allocated memory for activations and outputs\n",
"\n",
"gradients_optimizer_total_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"optimizer.zero_grad(set_to_none=True); torch.cuda.empty_cache()\n",
"optimizer_total_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"del optimizer; torch.cuda.empty_cache()\n",
"eps_2_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"\n",
"gradients_actual_vram = gradients_optimizer_total_vram - optimizer_total_vram\n",
"optimizer_actual_vram = optimizer_total_vram - eps_2_vram\n",
"print(f\"Gradients VRAM usage: {gradients_actual_vram:.0f} MiB (model weights were {model_actual_vram:.0f} MiB)\")\n",
"print(f\"Actual optimizer states VRAM usage: {optimizer_actual_vram:.0f} MiB\")\n",
"\n",
"eps_vram += eps_2_vram\n",
"print(f\"Random eps VRAM usage: {eps_vram:.0f} MiB\")\n",
"print(\"=\" * 75)"
]
},
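{
"cell_type": "markdown",
"id": "8a9b0c1d-2e3f-4a4b-a5d6-7e8f9a0b1c2d",
"metadata": {},
"source": [
"Putting the measured pieces together for this run (phi-1_5 in fp32, bs 4, seq length 512, SGD with momentum): weights 5496 MiB + gradients 4864 MiB + optimizer states 6312 MiB of persistent state, plus a transient ~15930 MiB forward pass (activations, logits, and the autocast weight-cast cache). Everything except the weights is freed again by the end of the cell, which is why the final GPU delta is 0 while the peaked delta is large."
]
},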
{
"cell_type": "markdown",
"id": "cfd62ab9-22e3-4e84-943e-82f4b63762a1",
"metadata": {},
"source": [
"# Estimation activations" | |
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ca072747-a715-4827-a4a1-3335b9c844a5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calculating size of activation for single block with:\n",
"hidden size 2048\n",
"num attention heads 32\n",
"num key value heads 32\n",
"intermediate size 8192\n",
"head dim 64\n",
"num hidden layers 24\n",
"===========================================================================\n",
"Single layer (out of 24) estimated activations VRAM usage: 296 MiB\n",
"All layers estimated activations VRAM usage: 7104 MiB\n",
"Estimated activations on inference forward pass VRAM usage (softmax output + v): 72 MiB\n",
"===========================================================================\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 0 7,807 MB |\n"
]
}
],
"source": [
"n_bytes_per_param = 2 if mixed_precision_training or dtype in (torch.float16, torch.bfloat16) else 4\n",
"\n",
"hidden_size = model.config.hidden_size\n",
"num_attention_heads = model.config.num_attention_heads\n",
"num_key_value_heads = model.config.num_key_value_heads if hasattr(model.config, \"num_key_value_heads\") else model.config.num_attention_heads # different from num_attention_heads in case of GQA\n",
"intermediate_size = model.config.intermediate_size if hasattr(model.config, \"intermediate_size\") else 4 * model.config.hidden_size # MLP projection\n",
"num_hidden_layers = model.config.num_hidden_layers\n",
"head_dim = hidden_size // num_attention_heads\n",
"print(f\"Calculating size of activation for single block with:\\nhidden size {hidden_size}\\nnum attention heads {num_attention_heads}\\nnum key value heads {num_key_value_heads}\\nintermediate size {intermediate_size}\\nhead dim {head_dim}\\nnum hidden layers {num_hidden_layers}\")\n",
"print(\"=\" * 75)\n",
"\n",
"attention_input = n_bytes_per_param * bs * seq_length * hidden_size\n",
"q = n_bytes_per_param * bs * seq_length * head_dim * num_attention_heads # for Q @ K.T\n",
"k = n_bytes_per_param * bs * seq_length * head_dim * num_key_value_heads # num_key_value_heads might be different from num_attention_heads in case of GQA\n",
"softmax_output = n_bytes_per_param * bs * num_attention_heads * seq_length ** 2 # to multiply with V\n",
"softmax_dropout_mask = 1 * bs * num_attention_heads * seq_length ** 2 # single byte per elem\n",
"dropout_output = n_bytes_per_param * bs * num_attention_heads * seq_length ** 2\n",
"v = n_bytes_per_param * bs * seq_length * head_dim * num_key_value_heads\n",
"out_proj_input = n_bytes_per_param * bs * seq_length * num_attention_heads * head_dim\n",
"attention_dropout = 1 * bs * seq_length * hidden_size\n",
"#attention_block = attention_input + q + k + softmax_output + v + out_proj_input\n",
"attention_block = attention_input + q + k + softmax_output + v + out_proj_input + softmax_dropout_mask + dropout_output + attention_dropout\n",
"\n",
"mlp_input = n_bytes_per_param * bs * seq_length * hidden_size\n",
"activation_input = n_bytes_per_param * bs * seq_length * intermediate_size # SiLU\n", | |
"down_proj_input = n_bytes_per_param * bs * seq_length * intermediate_size\n", | |
"dropout_mask = 1 * bs * seq_length * hidden_size # single byte per elem\n", | |
"#mlp_block = mlp_input + activation_input + down_proj_input\n", | |
"mlp_block = mlp_input + activation_input + down_proj_input + dropout_mask\n", | |
"\n", | |
"layer_norms = n_bytes_per_param * bs * seq_length * hidden_size * 2 # 2 layer norms\n", | |
"\n", | |
"layer = attention_block + mlp_block + layer_norms\n", | |
"print(f\"Single layer (out of {num_hidden_layers}) estimated activations VRAM usage: {layer // 2**20} MiB\")\n", | |
"print(f\"All layers estimated activations VRAM usage: {layer * num_hidden_layers // 2**20} MiB\")\n", | |
"print(f\"Estimated activations on inference forward pass VRAM usage (softmax output + v): {(softmax_output + v) // 2**20} MiB\")\n", | |
"print(\"=\" * 75)" | |
] | |
}, | |
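{
"cell_type": "markdown",
"id": "9b0c1d2e-3f4a-4b5c-a6e7-8f9a0b1c2d3e",
"metadata": {},
"source": [
"The next cell cross-checks this hand count against the closed-form estimate from Korthikanti et al., \"Reducing Activation Recomputation in Large Transformer Models\" (arXiv:2205.05198), which puts the activation memory of one transformer layer (in bytes, assuming 2-byte activations, 1-byte dropout masks, and no parallelism) at\n",
"\n",
"$$s b h \\left( 34 + 5 \\frac{a s}{h} \\right)$$\n",
"\n",
"with $s$ = sequence length, $b$ = batch size, $h$ = hidden size, $a$ = number of attention heads."
]
},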
{
"cell_type": "code",
"execution_count": 24,
"id": "7d70deac-a86a-403d-a21b-097e77e932fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.007\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 0 7,807 MB |\n"
]
}
],
"source": [
"# https://arxiv.org/pdf/2205.05198.pdf\n",
"\n",
"def calculate_attention_block():\n",
"    return 11 * seq_length * bs * hidden_size + 5 * num_attention_heads * seq_length ** 2 * bs\n",
"\n",
"def calculate_mlp_block():\n",
"    return 19 * seq_length * bs * hidden_size\n",
"\n",
"def calculate_layernorms():\n",
"    return 4 * seq_length * bs * hidden_size\n",
"\n",
"def calculate_per_layer():\n",
"    return seq_length * bs * hidden_size * (34 + 5 * num_attention_heads * seq_length / hidden_size)\n",
"\n",
"assert calculate_attention_block() + calculate_mlp_block() + calculate_layernorms() == calculate_per_layer()"
]
},
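{
"cell_type": "markdown",
"id": "0c1d2e3f-4a5b-4c6d-a7f8-9a0b1c2d3e4f",
"metadata": {},
"source": [
"Plugging in this notebook's values ($s = 512$, $b = 4$, $h = 2048$, $a = 32$): $34 + 5 \\cdot 32 \\cdot 512 / 2048 = 74$, so one layer is $512 \\cdot 4 \\cdot 2048 \\cdot 74 = 310378496$ bytes $\\approx 296$ MiB, and all 24 layers come to $\\approx 7104$ MiB, matching the hand count above."
]
},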
{
"cell_type": "raw",
"id": "23628603-07f8-4f3f-9731-018939acf519",
"metadata": {},
"source": [
"from torch.profiler import profile, record_function, ProfilerActivity\n",
"\n",
"with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:\n",
"    with torch.no_grad():\n",
"        out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
"\n",
"prof.key_averages().table(sort_by=\"self_cuda_memory_usage\", row_limit=10)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}