Skip to content

Instantly share code, notes, and snippets.

@praateekmahajan
Last active May 8, 2025 23:03
Show Gist options
  • Save praateekmahajan/896cf3ae7b0bae380fe750b2f5a85171 to your computer and use it in GitHub Desktop.
Save praateekmahajan/896cf3ae7b0bae380fe750b2f5a85171 to your computer and use it in GitHub Desktop.
Python GPU Monitor (using gpustat + context)
"""
You need gpustat for this to work.
Optionally install pandas / seaborn / matplotlib if you want plotting features.

```python
from gpu_monitor import GPUUtilizationManager

gpu_monitor = GPUUtilizationManager(gpu_ids=[0, 1], interval=0.1)
with gpu_monitor.monitor():
    gpu_monitor.add_event("A variable")
    A = torch.randn(...)
    time.sleep(1)
    gpu_monitor.add_event("B variable")
    B = torch.randn(...)
    time.sleep(1)
    gpu_monitor.add_event("Mat Mul")
    C = A.matmul(B)
    time.sleep(1)

# Plot after the `with` block so the last event's duration has been stamped.
gpu_monitor.plot(figsize=(12, 6))
```
"""
import datetime
import logging
import threading
import time

import gpustat
class GPUUtilizationManager:
def __init__(self, gpu_ids: list[int] | None = None, interval: float = 0.1):
self.gpu_ids = gpu_ids if gpu_ids is not None else [0, 1, 2, 3]
self.interval = interval
self.data = []
self._monitoring = False
self._thread = None
self.events = []
def _format_duration(self, seconds: float) -> str:
"""Convert a float number of seconds into 'Xm Ys' or 'Zs'."""
total = int(round(seconds))
if total == 0:
return f"0s"
minutes, secs = divmod(total, 60)
if minutes:
return f"{minutes}m" + (f" {secs}s" if secs > 0 else "")
return f"{secs}s"
class _MonitorContextManager:
def __init__(self, manager):
self.manager = manager
def __enter__(self):
self.manager._start_monitoring()
return self.manager
def __exit__(self, exc_type, exc_val, exc_tb):
self.manager._stop_monitoring()
self.manager._stamp_last_event_duration()
def monitor(self):
"""Return a context manager for monitoring GPU utilization."""
return self._MonitorContextManager(self)
def _start_monitoring(self):
"""Start the monitoring thread."""
self._monitoring = True
self._thread = threading.Thread(target=self._monitor_loop)
self._thread.start()
def _stop_monitoring(self):
"""Stop the monitoring thread."""
self._monitoring = False
if self._thread:
self._thread.join()
def _monitor_loop(self):
"""Monitor loop that collects GPU statistics."""
while self._monitoring:
try:
stats = gpustat.new_query()
timestamp = datetime.datetime.now()
gpu_stats = []
for gpu in stats.gpus:
if self.gpu_ids is None or gpu.index in self.gpu_ids:
gpu_stats.append(
{
"gpu_id": gpu.index,
"utilization": gpu.utilization,
"memory_used": gpu.memory_used / 1024,
"memory_total": gpu.memory_total,
"temperature": getattr(gpu, "temperature", None),
"timestamp": timestamp,
}
)
self.data.extend(gpu_stats)
except Exception as e:
# Handle errors silently or log if needed
pass
time.sleep(self.interval)
def add_event(self, event_name: str):
now = datetime.datetime.now()
# stamp previous event with elapsed time
if self.events:
prev = self.events[-1]
delta = now - prev["timestamp"]
prev["duration_s"] = self._format_duration(delta.total_seconds())
self.events.append(
{"timestamp": now, "event_name": event_name, "duration": None}
)
def _stamp_last_event_duration(self):
"""When no further events will occur, stamp the last event’s duration."""
if not self.events:
return
now = datetime.datetime.now()
last = self.events[-1]
delta = now - last["timestamp"]
last["duration_s"] = self._format_duration(delta.total_seconds())
def plot(self, figsize=(14, 10), title: str = "GPU Over time"):
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Convert collected data to DataFrame
df = pd.DataFrame(self.data)
if df.empty:
print("No data to plot.")
return
# Convert events to DataFrame
events_df = pd.DataFrame(self.events)
# Use a clean grid style
sns.set(style="whitegrid")
# Create two subplots: utilization and memory used
fig, axes = plt.subplots(2, 1, figsize=figsize, sharex=True)
# Plot GPU utilization over time
avg_df = df.groupby("timestamp")["memory_used"].sum().reset_index()
for i in range(2):
data, kwargs = (
(df, dict(hue="gpu_id", palette="tab10"))
if i == 0
else (
df.groupby("timestamp")["utilization"].mean().reset_index(),
dict(
label="Avg Utilization",
color="gray",
linestyle="--",
linewidth=2,
),
)
)
sns.lineplot(
data=data, x="timestamp", y="utilization", ax=axes[0], **kwargs
)
axes[0].set_title("GPU Utilization Over Time")
axes[0].set_ylabel("Utilization (%)")
# Plot GPU memory used over time
for i in range(2):
# If i==1 then plot sum
data, kwargs = (
(df, dict(hue="gpu_id", palette="tab10"))
if i == 0
else (
df.groupby("timestamp")["memory_used"].sum().reset_index(),
dict(
label="Total Memory Used",
color="gray",
linestyle="--",
linewidth=2,
),
)
)
sns.lineplot(
data=data, x="timestamp", y="memory_used", ax=axes[1], **kwargs
)
axes[1].set_title("GPU Memory Used Over Time")
axes[1].set_ylabel("Memory Used (GB)")
axes[1].set_xlabel("Timestamp")
# Annotate events with vertical dashed lines and labels
for ax in axes:
for _, event in events_df.iterrows():
ax.axvline(event["timestamp"], color="gray", linestyle="dotted")
# Place label at 90% of the y-axis height
ylim = ax.get_ylim()
ax.text(
event["timestamp"],
ylim[0] + 0.9 * (ylim[1] - ylim[0]),
event["event_name"] + f" ({event['duration_s']})",
rotation=90,
verticalalignment="top",
fontsize=9,
color="gray",
)
plt.tight_layout()
plt.show()
plt.suptitle(title)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment