Model memory stuff
import math

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification

def get_model_memory(model: torch.nn.Module):
    """
    Returns the memory usage (in bytes) of the given model's parameters.
    """
    total_memory = 0
    for param in model.parameters():
        total_memory += param.numel() * param.element_size()
    return total_memory
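
# Quick sanity check (illustrative): an fp32 Linear(1024, 1024) layer holds
# 1024 * 1024 weights plus 1024 biases, so get_model_memory() should return
# (1024 * 1024 + 1024) * 4 = 4_198_400 bytes:
#
#     assert get_model_memory(torch.nn.Linear(1024, 1024)) == 4_198_400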

class ActivationCounter:
    """Helper class to count the number of activation bytes in a model."""

    def __init__(self):
        self.activation_bytes = 0

    def add_activations(self, tensor):
        self.activation_bytes += tensor.numel() * tensor.element_size()

    def add_activation_bytes(self, bytes):
        self.activation_bytes += bytes

def activation_counter_hook(counter: ActivationCounter):
    """Returns a forward hook that adds a module's output sizes to `counter`."""
    def hook(self, _, output):
        if self.__class__.__name__ == "Dropout":
            # For dropout layers we only need to store the mask,
            # which costs one byte per element instead of element_size().
            counter.add_activation_bytes(output.data.numel())
        else:
            if isinstance(output, tuple):
                for o in output:
                    if isinstance(o, torch.Tensor):
                        counter.add_activations(o.data)
                    elif isinstance(o, tuple):
                        for o2 in o:
                            if isinstance(o2, torch.Tensor):
                                counter.add_activations(o2.data)
            elif isinstance(output, torch.Tensor):
                counter.add_activations(output.data)
    return hook
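
# The inner `hook` follows PyTorch's forward-hook signature (module, input,
# output); the module argument is named `self` so the hook can inspect the
# class of the layer it is attached to (e.g. Dropout above).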

def register_hooks_recursive(model: torch.nn.Module, counter: ActivationCounter):
    """Recursively registers activation-counting hooks on the given model's submodules."""
    for module in model.children():
        module.register_forward_hook(activation_counter_hook(counter))
        register_hooks_recursive(module, counter)
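
# Note: hooks are attached to every descendant module, but not to the root
# module itself, so the root's own output object is never counted directly.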

def format_size(size_bytes):
    """
    Converts the given size in bytes to a human-readable format.
    Reference: https://stackoverflow.com/questions/5194057/better-way-to-convert-file-sizes-in-python
    """
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])

def get_current_memory_allocation():
    """Returns the number of bytes currently allocated by tensors on the CUDA device."""
    return torch.cuda.memory_allocated()

def get_optimizer_memory(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    """
    Returns the estimated memory usage (in bytes) of the given optimizer's
    state for the given model.
    Note: Currently only supports SGD, Adam, and AdamW.
    """
    model_parameters = sum(param.numel() for param in model.parameters())
    bytes_per_param = 0
    if type(optimizer) == torch.optim.SGD:
        has_momentum = any(param_group.get('momentum', 0) != 0
                           for param_group in optimizer.param_groups)
        if has_momentum:
            # SGD with momentum keeps one fp32 buffer per parameter.
            bytes_per_param = 4
    elif type(optimizer) in (torch.optim.Adam, torch.optim.AdamW):
        # Adam/AdamW keep two fp32 buffers per parameter (exp_avg, exp_avg_sq).
        bytes_per_param = 8
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer}")
    return model_parameters * bytes_per_param
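
# For scale (rough, fp32): a ~110M-parameter model such as bert-base-cased
# carries about 110e6 * 8 bytes ~= 880 MB of Adam state, twice its weight memory.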

def project_transformer_memory(
        model, layers, hidden_size, num_attention_heads,
        batch_size, sequence_length, optimizer):
    """Projects the total training memory (in bytes) of a transformer model."""
    # Each transformer layer holds roughly 12 * hidden_size^2 weights
    # (attention + MLP) plus 13 * hidden_size biases/LayerNorm parameters,
    # stored in fp32 (4 bytes each); embeddings are not included.
    model_memory = 4 * layers * hidden_size * (13 + 12 * hidden_size)
    # Gradients are fp32 too, one per parameter.
    gradient_memory = model_memory
    # Activation memory formula from: https://arxiv.org/pdf/2205.05198.pdf
    activation_memory = layers * batch_size * sequence_length * hidden_size * (
        67 + (9 * num_attention_heads * sequence_length) / hidden_size
    )
    optimizer_memory = get_optimizer_memory(model, optimizer)
    return model_memory + gradient_memory + activation_memory + optimizer_memory
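
# Worked example (illustrative), for bert-base-cased (layers=12, hidden_size=768):
# the weight term is 4 * 12 * 768 * (13 + 12 * 768) ~= 340.2e6 bytes (~324 MB),
# gradients add the same again, and Adam state adds 8 bytes per parameter.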

if __name__ == "__main__":
    batch_size = 1
    model_name = "bert-base-cased"
    # model_name = "bert-base-uncased"
    # model_name = "albert-base-v2"
    # model_name = "distilbert-base-uncased"
    # model_name = "gpt2"
    # model_name = "roberta-base"
    config = AutoConfig.from_pretrained(model_name)
    config.return_dict = True
    model = AutoModelForSequenceClassification.from_config(config)
    print(f"Model used: {model_name}")

    # ALBERT shares parameters across layer groups, so prefer num_hidden_groups
    # when it exists; bert-base-cased has 12 hidden layers.
    projected_total_memory = format_size(project_transformer_memory(
        model,
        config.num_hidden_groups if hasattr(config, "num_hidden_groups") else config.num_hidden_layers,
        config.hidden_size,
        config.num_attention_heads, batch_size,
        config.max_position_embeddings, torch.optim.Adam(model.parameters()))
    )
    print(f"Projected total memory usage: {projected_total_memory}")
    print("-" * 80)

    ############################################################
    ## Measure model memory
    ############################################################
    device = "cuda"
    model.to(device)
    memory_allocation_with_model = get_current_memory_allocation()
    estimated_model_memory = get_model_memory(model)
    print(f"Measured Model Memory: {format_size(memory_allocation_with_model)}")
    print(f"Estimated Model Memory: {format_size(estimated_model_memory)}")
    print(f"Percent difference: {abs(memory_allocation_with_model - estimated_model_memory) / estimated_model_memory * 100:.2f}%")
    print("-" * 80)

    ############################################################
    ## Measure activation memory
    ############################################################
    activation_counter = ActivationCounter()
    register_hooks_recursive(model, activation_counter)
    batch = {
        "labels": torch.tensor([0]).to(device),
        # Token ids must lie in [0, vocab_size).
        "input_ids": torch.randint(0, model.config.vocab_size, (batch_size, 64)).to(device),
    }
    outputs = model(batch["input_ids"])
    # The input tensor itself also occupies GPU memory during the forward pass.
    activation_counter.add_activations(batch["input_ids"])
    memory_allocation_forward_pass = get_current_memory_allocation() - memory_allocation_with_model
    print(f"Consumed Activation Memory: {format_size(memory_allocation_forward_pass)}")
    print(f"Estimated Activation Memory: {format_size(activation_counter.activation_bytes)}")
    print(f"Percent difference: {abs(memory_allocation_forward_pass - activation_counter.activation_bytes) / activation_counter.activation_bytes * 100:.2f}%")
    print("-" * 80)

    ############################################################
    ## Measure gradient memory
    ############################################################
    loss_fn = torch.nn.MSELoss()
    labels = torch.randn_like(outputs.logits).to(device)
    labels_size = labels.numel() * labels.element_size()
    outputs_size = outputs.logits.numel() * outputs.logits.element_size()
    loss = loss_fn(outputs.logits, labels)
    loss.backward(retain_graph=True)
    memory_allocation_with_gradients = get_current_memory_allocation() - memory_allocation_with_model - memory_allocation_forward_pass
    # Gradients are fp32, one per parameter, so they should occupy the same
    # space as the model weights themselves.
    estimated_gradient_memory = estimated_model_memory
    print(f"Consumed Gradient Memory: {format_size(memory_allocation_with_gradients)}")
    print(f"Estimated Gradient Memory: {format_size(estimated_gradient_memory)}")
    print(f"Percent difference: {abs(memory_allocation_with_gradients - estimated_gradient_memory) / estimated_gradient_memory * 100:.2f}%")
    print("-" * 80)

    ############################################################
    ## Measure optimizer memory
    ############################################################
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    post_step_memory = get_current_memory_allocation() - memory_allocation_with_model - memory_allocation_forward_pass - memory_allocation_with_gradients
    estimated_optimizer_memory = get_optimizer_memory(model, optimizer)
    print(f"Consumed Optimizer + Gradient Memory: {format_size(post_step_memory)}")
    print(f"Estimated Optimizer + Gradient Memory: {format_size(estimated_optimizer_memory)}")
    print(f"Percent difference: {abs(post_step_memory - estimated_optimizer_memory) / estimated_optimizer_memory * 100:.2f}%")
    print("-" * 80)
print(f"Actual total memory usage: {format_size(get_current_memory_allocation())}") |