Visualizing memory consumption of a PyTorch model.
The example script below trains ResNet-18 on synthetic data, with the decorator applied to `main`:

import torch
from torchvision.models import resnet18

from pytorch_memory_profiling import memory_logging


def training_loop(model, criterion, optimizer, dataloader):
    model.train()
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()  # clear gradients from the previous step
        pred_y = model(batch_x)
        loss = criterion(pred_y, batch_y)
        loss.backward()
        optimizer.step()


@memory_logging("resnet18_snapshot.pkl")
def main():
    model = resnet18().cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Synthetic dataset: the same random batch repeated n_batches times,
    # which is enough to exercise the CUDA allocator.
    batch_size = 8
    n_batches = 16
    inputs = torch.rand(batch_size, 3, 224, 224, device='cuda')
    labels = torch.zeros(batch_size, dtype=torch.long, device='cuda')
    dataset = [(inputs, labels)] * n_batches

    training_loop(model, criterion, optimizer, dataset)


if __name__ == "__main__":
    main()
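After `main()` returns, the decorator writes `resnet18_snapshot.pkl`. Following the blog post referenced in the module below, the snapshot can then be rendered to an interactive HTML timeline with the `_memory_viz.py` script (assuming the script has been downloaded locally; the `trace_plot` subcommand and `-o` flag are the ones shown in that post):

    python _memory_viz.py trace_plot resnet18_snapshot.pkl -o resnet18_trace.html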
| """ | |
| Inspired by the following blog post: | |
| https://zdevito.github.io/2022/12/09/memory-traces.html | |
| This is a decorator that can be added to any function (e.g., training loop, model's main file) to dump stats about its memory usage. | |
| The dumped stats can later on be interpreted and visualized via the memory_viz tool provided within PyTorch: | |
| https://raw.githubusercontent.com/pytorch/pytorch/main/torch/cuda/_memory_viz.py | |
| """ | |
| import functools | |
| import torch | |
| from pickle import dump | |
| def begin_memory_logging(trace_alloc_max_entries): | |
| torch.cuda.memory._record_memory_history( | |
| True, | |
| trace_alloc_max_entries=trace_alloc_max_entries, | |
| # record stack information for the trace events | |
| trace_alloc_record_context=True | |
| ) | |
| def end_memory_logging(snapshot_dump_path): | |
| snapshot = torch.cuda.memory._snapshot() | |
| with open(snapshot_dump_path, 'wb') as f: | |
| dump(snapshot, f) | |
| def memory_logging(path, trace_alloc_max_entries=1000000): | |
| def decorator_memory_logging(func): | |
| @functools.wraps(func) | |
| def wrapper_decorator(*args, **kwargs): | |
| begin_memory_logging(trace_alloc_max_entries) | |
| # Run the function | |
| value = func(*args, **kwargs) | |
| end_memory_logging(path) | |
| return value | |
| return wrapper_decorator | |
| return decorator_memory_logging |
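For scripts where only one region should be traced rather than a whole function, the same begin/end helpers compose naturally into a context manager. A minimal sketch, reusing the module's own functions (the `memory_logging_region` name is hypothetical, not part of the module above):

import contextlib

@contextlib.contextmanager
def memory_logging_region(path, trace_alloc_max_entries=1000000):
    # Hypothetical context-manager counterpart to the decorator above;
    # reuses begin_memory_logging / end_memory_logging from this module.
    begin_memory_logging(trace_alloc_max_entries)
    try:
        yield
    finally:
        end_memory_logging(path)

Usage would look like:

    with memory_logging_region("loop_snapshot.pkl"):
        training_loop(model, criterion, optimizer, dataset)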