Skip to content

Instantly share code, notes, and snippets.

@usr-ein
Last active January 6, 2025 13:38
Show Gist options
  • Save usr-ein/f92a805d466a290ee32e9a0221cf526a to your computer and use it in GitHub Desktop.
Save usr-ein/f92a805d466a290ee32e9a0221cf526a to your computer and use it in GitHub Desktop.
GPUDiff: a treemap-based GPU memory profiler for PyTorch (portable to TensorFlow if needed)
from __future__ import annotations
import os
from typing import List, Optional, Dict, Union
import subprocess as sp
import plotly.express as px
import torch
class Singleton(type):
    """Metaclass that makes each class using it return one shared instance.

    The first ``Cls(...)`` call constructs the instance; later calls return that
    same object and ignore their arguments. ``Cls.clear()`` forgets the instance
    so the next ``Cls(...)`` call constructs a new one.
    """

    # Registry shared by every class using this metaclass, keyed by the class.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

    def clear(cls):
        """Forget ``cls``'s instance; instances of other classes are kept."""
        # Remove only this class's entry from the shared registry. Rebinding
        # ``cls._instances = {}`` would leave the stale instance referenced by
        # the metaclass registry and shadow it with a per-class dict.
        cls._instances.pop(cls, None)
class GPUDiffRoot(metaclass=Singleton):
    """Root of the GPUDiff call tree; tracks the currently open context manager.

    Whenever a GPUDiff context manager is entered, it is attached as a child of
    the current node and becomes the current node, so every context manager
    entered before it exits is nested under it. Whenever it is exited, its
    parent becomes the current node again.

    Because of the ``Singleton`` metaclass, ``GPUDiffRoot()`` returns the same
    instance until ``reset()`` forgets it.
    """

    def __init__(self):
        self.name = "(callstack root)"
        self.children = []
        # The root is its own parent; make_treemap relies on this to give the
        # top of the tree an empty parent id.
        self.parent = self
        self.current = self
        # Total VRAM across all GPUs in MiB (runs nvidia-smi). Not plotted: the
        # treemap uses usage_excl_children, which is 0 for the root.
        self.usage = self.total_memory_installed
        self.usage_excl_children = 0
        self.enabled = True

    def reset(self):
        """Drop the recorded tree and forget the singleton instance.

        The next ``GPUDiffRoot()`` call constructs a fresh instance.
        """
        self.children = []
        self.parent = self
        self.current = self
        self.__class__.clear()

    def enter(self, cm: GPUDiff):
        """Attach ``cm`` under the current node and make it the current node."""
        self.current.children.append(cm)
        cm.parent = self.current
        self.current = cm

    def exit(self, cm: GPUDiff):
        """Make ``cm``'s parent the current node again."""
        self.current = cm.parent

    @property
    def enabled(self) -> bool:
        """Whether GPUDiff context managers record anything."""
        return self._enabled

    @enabled.setter
    def enabled(self, v: bool):
        self._enabled = v

    @property
    def disabled(self) -> bool:
        """Negation of ``enabled``."""
        return not self.enabled

    @disabled.setter
    def disabled(self, v: bool):
        self.enabled = (not v)

    @property
    def total_memory_installed(self) -> int:
        """Total VRAM installed across all GPUs, in MiB, as reported by nvidia-smi.

        Raises:
            RuntimeError: if nvidia-smi cannot be launched or exits with an error.
        """
        import shutil

        # Use nvidia-smi from PATH when available; fall back to the historical path.
        executable = shutil.which("nvidia-smi") or "/usr/bin/nvidia-smi"
        command = [executable, "--query-gpu=memory.total", "--format=csv"]
        try:
            raw = sp.check_output(command, stderr=sp.STDOUT)
        except sp.CalledProcessError as e:
            raise RuntimeError(
                "command '{}' return with error (code {}): {}".format(e.cmd, e.returncode, e.output)
            ) from e
        except OSError as e:
            # e.g. FileNotFoundError when nvidia-smi is not installed.
            raise RuntimeError(f"could not run '{' '.join(command)}': {e}") from e
        # The first line is the CSV header; each following non-empty line starts
        # with the per-GPU amount as an integer.
        lines = raw.decode("ascii").splitlines()[1:]
        return sum(int(line.split()[0]) for line in lines if line.strip())

    @classmethod
    def _build_nodes_list(cls, root, all_nodes=None):
        """Return ``root`` followed by all of its descendants (each node once)."""
        if all_nodes is None:
            all_nodes = [root]
        all_nodes.extend(root.children)
        for child in root.children:
            cls._build_nodes_list(child, all_nodes)
        return all_nodes

    def make_treemap(self, path: Union[os.PathLike, str]):
        """Write the recorded call tree as an interactive Plotly treemap HTML file.

        Each tile's size and colour is the node's memory diff excluding its
        children, clamped at 0. Does nothing when disabled. The recorded tree
        (and the singleton) is reset after the file is written.

        Args:
            path: destination HTML file.

        Raises:
            AssertionError: if a GPUDiff context manager is still open.
        """
        if self.disabled:
            return
        # Explicit raise rather than ``assert`` so the check survives ``python -O``;
        # the exception type stays the same for existing callers.
        if self.current is not self:
            raise AssertionError("Only call this once all the GPUDiff context managers have finished wrapping !")
        all_nodes = self._build_nodes_list(self)
        names = [n.name for n in all_nodes]
        # Plotly links tiles by id; the root (its own parent) gets "" as parent.
        parents = [id(n.parent) if n.parent is not n else "" for n in all_nodes]
        ids = [id(n) for n in all_nodes]
        values = [max(0, n.usage_excl_children) for n in all_nodes]
        fig = px.treemap(names=names, values=values, parents=parents, ids=ids,
                         color_continuous_scale=px.colors.diverging.RdYlGn[::-1],
                         color_continuous_midpoint=(max(values) * 1.15) // 2, color=values,
                         title="GPU memory usage per module",
                         )
        fig.update_traces(root_color="lightgrey")
        fig.update_layout(uniformtext=dict(minsize=10, mode='hide'), margin=dict(t=50, l=25, r=25, b=25))
        fig.write_html(path, include_plotlyjs="cdn")
        self.reset()
class GPUDiff:
    """Context manager that measures the change in PyTorch-allocated GPU memory over its body.

    An instance entered while ``GPUDiffRoot()`` is enabled is attached to the
    root's call tree, nested under whichever GPUDiff is currently open.

    Args:
        name: label used in the treemap and in verbose output.
        verbose: print the diff on exit when it is non-zero.
    """

    parent: Optional[GPUDiff]
    children: List[GPUDiff]
    usage: Optional[float]

    def __init__(self, name: str, verbose: bool = False):
        self.name = name
        self.verbose = verbose
        self.parent = None
        self.children = []
        self.usage = None  # MiB diff over the block; set on exit when tracked
        self.usage_before = None
        self.usage_after = None
        # Whether __enter__ registered this instance with the root. __exit__
        # unwinds exactly what __enter__ did, even if the root was toggled inside.
        self._tracking = False

    def __enter__(self):
        if GPUDiffRoot().enabled:
            self.usage_before = self.total_memory_used
            GPUDiffRoot().enter(self)
            self._tracking = True
        # Returning self allows ``with GPUDiff(...) as d:`` and reading d.usage later.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Decide from what __enter__ did, not from the current enabled flag.
        # Re-checking the flag here would either read a missing usage_before or
        # leave this instance as the root's current node.
        if not self._tracking:
            return
        self._tracking = False
        self.usage_after = self.total_memory_used
        self.usage = self.usage_after - self.usage_before
        if self.usage != 0 and self.verbose:
            print(f"GPU mem usage diff of {self.name}: \t{self.usage:+} MiB")
        GPUDiffRoot().exit(self)
        # Returns None: exceptions raised in the block propagate.

    @property
    def usage_excl_children(self) -> float:
        """This block's diff minus its children's diffs, clamped at 0 (MiB).

        Requires this block and all its children to have been measured.
        """
        return max(0, self.usage - sum(c.usage for c in self.children))

    @property
    def total_memory_used(self) -> float:
        """GPU memory currently occupied by PyTorch tensors, in MiB.

        This is ``torch.cuda.memory_allocated()`` (current device). It is not
        the total VRAM used on the card.
        """
        # Wait for queued CUDA work so the allocator counters reflect it.
        torch.cuda.synchronize()
        return torch.cuda.memory_allocated() / 1024 ** 2
if __name__ == "__main__":
    # Demo: profile a nested call tree, then dump it as an HTML treemap.
    root = GPUDiffRoot()
    root.enabled = True
    with GPUDiff("Forward"):
        print("Run the model here")
        for label in ("Sub-Component 1", "Sub-Component 2"):
            with GPUDiff(label):
                print("Run a specific sub-component code here")
        with GPUDiff("Sub-Component 3"):
            print("Run a specific sub-component code here")
            with GPUDiff("Sub-sub-Component 3.1"):
                print("Run a specific sub-sub-component code here")
    # Writes an interactive HTML treemap of per-block GPU memory usage.
    root.make_treemap("treemap.html")
<html>
<head><meta charset="utf-8" /></head>
<body>
<div> <script type="text/javascript">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>
<script src="https://cdn.plot.ly/plotly-2.3.1.min.js"></script> <div id="8e253e1c-da8f-4115-915a-016f2312d2c3" class="plotly-graph-div" style="height:100%; width:100%;"></div> <script type="text/javascript"> window.PLOTLYENV=window.PLOTLYENV || {}; if (document.getElementById("8e253e1c-da8f-4115-915a-016f2312d2c3")) { Plotly.newPlot( "8e253e1c-da8f-4115-915a-016f2312d2c3", [{"domain":{"x":[0.0,1.0],"y":[0.0,1.0]},"hovertemplate":"label=%{label}<br>value=%{value}<br>parent=%{parent}<br>id=%{id}<br>color=%{color}<extra></extra>","ids":[140441866195872,140441865191040,140436710357696,140436710358368,140441866678032,140441867256448,140441863858880,140436722507632,140441866314896,140441865715232,140441869392192,140441864936848,140441865527840,140441868070816,140441861209968,140441868548224,140441869125568,140441868704544,140441869126912,140441869392912,140441864558624,140441868414256,140436890652976,140436958728000,140441864794016,140441864264432,140441865189360,140441865189696,140441864264528,140441868796880,140441865987696,140441865354592,140441866861344,140441863123440,140436875829600,140436901168416,140436901169712,140436958727328,140436901167456,140436901168608,140436958727472,140441865714656,140441869281408,140441865190992,140441863432080,140441864560208,140441865191328,140441869281456,140441861541696,140441869280880,140441861159664,140441863856720,140436890886000,140436710355104],"labels":["(callstack root)","Forward","Loss network","Backward","Init PConvs","Recurrent","Tail","UNet #0","UNet #1","UNet #2","UNet #3","UNet #4","UNet #5","Encoding","Low","KCA","Up 0","Up 1","Up 2","Encoding","Low","KCA","Up 0","Up 1","Up 2","Encoding","Low","KCA","Up 0","Up 1","Up 2","Encoding","Low","KCA","Up 0","Up 1","Up 2","Encoding","Low","KCA","Up 0","Up 1","Up 2","Encoding","Low","KCA","Up 0","Up 1","Up 2","Deconv","Old new mashup","PConv","Bottleneck","Last 
conv"],"marker":{"coloraxis":"coloraxis","colors":[0.0,0.0,111.953125,0.0,672.001953125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,776.01171875,208.015625,224.0625,224.00390625,480.001953125,576.0009765625,777.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,1032.0009765625,536.0,1064.0009765625,768.0029296875,536.0]},"name":"","parents":["",140441866195872,140441866195872,140441866195872,140441865191040,140441865191040,140441865191040,140441867256448,140441867256448,140441867256448,140441867256448,140441867256448,140441867256448,140436722507632,140436722507632,140436722507632,140436722507632,140436722507632,140436722507632,140441866314896,140441866314896,140441866314896,140441866314896,140441866314896,140441866314896,140441865715232,140441865715232,140441865715232,140441865715232,140441865715232,140441865715232,140441869392192,140441869392192,140441869392192,140441869392192,140441869392192,140441869392192,140441864936848,140441864936848,140441864936848,140441864936848,140441864936848,140441864936848,140441865527840,140441865527840,140441865527840,140441865527840,140441865527840,140441865527840,140441863858880,140441863858880,140441863858880,140441863858880,140441863858880],"root":{"color":"lightgrey"},"type":"treemap","values":[0.0,0.0,111.953125,0.0,672.001953125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,776.01171875,208.015625,224.0625,224.00390625,480.001953125,576.0009765625,777.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,22
4.00390625,480.001953125,576.0009765625,776.01171875,208.015625,256.09765625,224.00390625,480.001953125,576.0009765625,1032.0009765625,536.0,1064.0009765625,768.0029296875,536.0]}], {"coloraxis":{"cmid":611.0,"colorbar":{"title":{"text":"color"}},"colorscale":[[0.0,"rgb(0,104,55)"],[0.1,"rgb(26,152,80)"],[0.2,"rgb(102,189,99)"],[0.3,"rgb(166,217,106)"],[0.4,"rgb(217,239,139)"],[0.5,"rgb(255,255,191)"],[0.6,"rgb(254,224,139)"],[0.7,"rgb(253,174,97)"],[0.8,"rgb(244,109,67)"],[0.9,"rgb(215,48,39)"],[1.0,"rgb(165,0,38)"]]},"legend":{"tracegroupgap":0},"margin":{"b":25,"l":25,"r":25,"t":50},"template":{"data":{"bar":[{"error_x":{"color":"#2a3f5f"},"error_y":{"color":"#2a3f5f"},"marker":{"line":{"color":"#E5ECF6","width":0.5},"pattern":{"fillmode":"overlay","size":10,"solidity":0.2}},"type":"bar"}],"barpolar":[{"marker":{"line":{"color":"#E5ECF6","width":0.5},"pattern":{"fillmode":"overlay","size":10,"solidity":0.2}},"type":"barpolar"}],"carpet":[{"aaxis":{"endlinecolor":"#2a3f5f","gridcolor":"white","linecolor":"white","minorgridcolor":"white","startlinecolor":"#2a3f5f"},"baxis":{"endlinecolor":"#2a3f5f","gridcolor":"white","linecolor":"white","minorgridcolor":"white","startlinecolor":"#2a3f5f"},"type":"carpet"}],"choropleth":[{"colorbar":{"outlinewidth":0,"ticks":""},"type":"choropleth"}],"contour":[{"colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]],"type":"contour"}],"contourcarpet":[{"colorbar":{"outlinewidth":0,"ticks":""},"type":"contourcarpet"}],"heatmap":[{"colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.66
66666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]],"type":"heatmap"}],"heatmapgl":[{"colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]],"type":"heatmapgl"}],"histogram":[{"marker":{"pattern":{"fillmode":"overlay","size":10,"solidity":0.2}},"type":"histogram"}],"histogram2d":[{"colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]],"type":"histogram2d"}],"histogram2dcontour":[{"colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]],"type":"histogram2dcontour"}],"mesh3d":[{"colorbar":{"outlinewidth":0,"ticks":""},"type":"mesh3d"}],"parcoords":[{"line":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"parcoords"}],"pie":[{"automargin":true,"type":"pie"}],"scatter":[{"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scatter"}],"scatter3d":[{"line":{"colorbar":{"outlinewidth":0,"ticks":""}},"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scatter3d"}],"scattercarpet":[{"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scattercarpet"}],"scattergeo":[{"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scattergeo"}],"scattergl":[{"marker":{"colorbar":{"
outlinewidth":0,"ticks":""}},"type":"scattergl"}],"scattermapbox":[{"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scattermapbox"}],"scatterpolar":[{"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scatterpolar"}],"scatterpolargl":[{"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scatterpolargl"}],"scatterternary":[{"marker":{"colorbar":{"outlinewidth":0,"ticks":""}},"type":"scatterternary"}],"surface":[{"colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]],"type":"surface"}],"table":[{"cells":{"fill":{"color":"#EBF0F8"},"line":{"color":"white"}},"header":{"fill":{"color":"#C8D4E3"},"line":{"color":"white"}},"type":"table"}]},"layout":{"annotationdefaults":{"arrowcolor":"#2a3f5f","arrowhead":0,"arrowwidth":1},"autotypenumbers":"strict","coloraxis":{"colorbar":{"outlinewidth":0,"ticks":""}},"colorscale":{"diverging":[[0,"#8e0152"],[0.1,"#c51b7d"],[0.2,"#de77ae"],[0.3,"#f1b6da"],[0.4,"#fde0ef"],[0.5,"#f7f7f7"],[0.6,"#e6f5d0"],[0.7,"#b8e186"],[0.8,"#7fbc41"],[0.9,"#4d9221"],[1,"#276419"]],"sequential":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]],"sequentialminus":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]]},"colorway":["#636efa","#EF553B","#00cc96","#ab63fa","#FFA15A","#19d3f3","#FF6692","#B
6E880","#FF97FF","#FECB52"],"font":{"color":"#2a3f5f"},"geo":{"bgcolor":"white","lakecolor":"white","landcolor":"#E5ECF6","showlakes":true,"showland":true,"subunitcolor":"white"},"hoverlabel":{"align":"left"},"hovermode":"closest","mapbox":{"style":"light"},"paper_bgcolor":"white","plot_bgcolor":"#E5ECF6","polar":{"angularaxis":{"gridcolor":"white","linecolor":"white","ticks":""},"bgcolor":"#E5ECF6","radialaxis":{"gridcolor":"white","linecolor":"white","ticks":""}},"scene":{"xaxis":{"backgroundcolor":"#E5ECF6","gridcolor":"white","gridwidth":2,"linecolor":"white","showbackground":true,"ticks":"","zerolinecolor":"white"},"yaxis":{"backgroundcolor":"#E5ECF6","gridcolor":"white","gridwidth":2,"linecolor":"white","showbackground":true,"ticks":"","zerolinecolor":"white"},"zaxis":{"backgroundcolor":"#E5ECF6","gridcolor":"white","gridwidth":2,"linecolor":"white","showbackground":true,"ticks":"","zerolinecolor":"white"}},"shapedefaults":{"line":{"color":"#2a3f5f"}},"ternary":{"aaxis":{"gridcolor":"white","linecolor":"white","ticks":""},"baxis":{"gridcolor":"white","linecolor":"white","ticks":""},"bgcolor":"#E5ECF6","caxis":{"gridcolor":"white","linecolor":"white","ticks":""}},"title":{"x":0.05},"xaxis":{"automargin":true,"gridcolor":"white","linecolor":"white","ticks":"","title":{"standoff":15},"zerolinecolor":"white","zerolinewidth":2},"yaxis":{"automargin":true,"gridcolor":"white","linecolor":"white","ticks":"","title":{"standoff":15},"zerolinecolor":"white","zerolinewidth":2}}},"title":{"text":"GPU memory usage per module"},"uniformtext":{"minsize":10,"mode":"hide"}}, {"responsive": true} ) }; </script> </div>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment