Skip to content

Instantly share code, notes, and snippets.

@alex0dd
Last active April 20, 2023 16:59
Show Gist options
  • Select an option

  • Save alex0dd/192efe6d67c34c37c8c494c661170a1f to your computer and use it in GitHub Desktop.

Select an option

Save alex0dd/192efe6d67c34c37c8c494c661170a1f to your computer and use it in GitHub Desktop.
Visualizing memory consumption of a PyTorch model.
import torch
from pytorch_memory_profiling import memory_logging
from torchvision.models import resnet18
def training_loop(model, criterion, optimizer, dataloader):
    """Train ``model`` for a single pass over ``dataloader``.

    Args:
        model: Module to optimize; it is put into training mode first.
        criterion: Callable invoked as ``criterion(predictions, targets)``
            whose result supports ``.backward()``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        dataloader: Iterable yielding ``(batch_x, batch_y)`` pairs.
    """
    model.train()
    for batch_x, batch_y in dataloader:
        # backward() adds into existing .grad buffers, so gradients must be
        # cleared every iteration or each step would apply the running sum
        # of all previous batches.
        optimizer.zero_grad()
        pred_y = model(batch_x)
        loss = criterion(pred_y, batch_y)
        loss.backward()
        optimizer.step()
@memory_logging("resnet18_snapshot.pkl")
def main():
    """Run a short synthetic ResNet-18 training job on the GPU.

    One random image batch and an all-zero label batch are created once and
    reused for every step, so the run needs no real dataset.
    """
    device = 'cuda'
    batch_size, n_batches = 8, 16

    # The model is created before the data so the allocation order stays
    # the same as in the recorded trace.
    net = resnet18().cuda()
    loss_fn = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    images = torch.rand(batch_size, 3, 224, 224, device=device)
    targets = torch.zeros(batch_size, dtype=torch.long, device=device)

    # The same (images, targets) pair is repeated n_batches times.
    batches = [(images, targets) for _ in range(n_batches)]
    training_loop(net, loss_fn, sgd, batches)
# Only run the example when executed as a script, not when imported.
if __name__ == "__main__":
    main()
"""
Inspired by the following blog post:
https://zdevito.github.io/2022/12/09/memory-traces.html
This module provides a decorator that can be added to any function (e.g., a training loop or a model's main function) to dump statistics about its GPU memory usage.
The dumped stats can later on be interpreted and visualized via the memory_viz tool provided within PyTorch:
https://raw.githubusercontent.com/pytorch/pytorch/main/torch/cuda/_memory_viz.py
"""
import functools
import torch
from pickle import dump
def begin_memory_logging(trace_alloc_max_entries):
    """Turn on PyTorch's CUDA memory-history recording.

    Args:
        trace_alloc_max_entries: Forwarded unchanged to
            ``torch.cuda.memory._record_memory_history`` (presumably the cap
            on retained trace events -- confirm against the installed
            PyTorch version).
    """
    start_recording = torch.cuda.memory._record_memory_history
    start_recording(
        True,
        trace_alloc_max_entries=trace_alloc_max_entries,
        # Also capture stack information for each trace event.
        trace_alloc_record_context=True,
    )
def end_memory_logging(snapshot_dump_path):
    """Pickle the current CUDA memory snapshot to ``snapshot_dump_path``.

    Args:
        snapshot_dump_path: Destination file; opened in binary write mode,
            so an existing file is overwritten.
    """
    # Take the snapshot before opening the file so that a failure here does
    # not leave an empty file behind.
    memory_state = torch.cuda.memory._snapshot()
    with open(snapshot_dump_path, 'wb') as out_file:
        dump(memory_state, out_file)
def memory_logging(path, trace_alloc_max_entries=1000000):
    """Build a decorator that records memory history around a function call.

    Recording is started just before the wrapped function runs, and a
    snapshot is written to ``path`` when it finishes. The snapshot is written
    even if the function raises (for example on an out-of-memory error).

    Args:
        path: File the snapshot is dumped to after each call.
        trace_alloc_max_entries: Passed to ``begin_memory_logging``.

    Returns:
        A decorator that returns a wrapper with the same signature and
        return value as the function it wraps.
    """
    def decorator_memory_logging(func):
        @functools.wraps(func)
        def wrapper_decorator(*args, **kwargs):
            begin_memory_logging(trace_alloc_max_entries)
            try:
                # Run the function
                return func(*args, **kwargs)
            finally:
                # Dump on failure too; the original exception (if any)
                # still propagates to the caller afterwards.
                end_memory_logging(path)
        return wrapper_decorator
    return decorator_memory_logging
@alex0dd
Copy link
Author

alex0dd commented Jan 16, 2023

Instructions:

  1. Run the example file to generate a pickle file containing the memory trace:
python example.py
  2. Download the visualization tool from the PyTorch repository:
wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/cuda/_memory_viz.py
  3. Use the visualization tool to generate the usage chart:
python _memory_viz.py trace_plot resnet18_snapshot.pkl -o resnet18_trace.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment