Last active
May 8, 2025 23:03
-
-
Save praateekmahajan/896cf3ae7b0bae380fe750b2f5a85171 to your computer and use it in GitHub Desktop.
Python GPU Monitor (using gpustat + context)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Monitor GPU utilization and memory with gpustat.

Requires ``gpustat``. Install pandas, seaborn and matplotlib as well if you
want to use the plotting features.

```python
import time

import torch

from gpu_monitor import GPUUtilizationManager

gpu_monitor = GPUUtilizationManager(gpu_ids=[0, 1], interval=0.1)
with gpu_monitor.monitor():
    gpu_monitor.add_event("A variable")
    A = torch.randn(4096, 4096, device="cuda")
    time.sleep(1)
    gpu_monitor.add_event("B variable")
    B = torch.randn(4096, 4096, device="cuda")
    time.sleep(1)
    gpu_monitor.add_event("Mat Mul")
    C = A.matmul(B)
    time.sleep(1)
gpu_monitor.plot(figsize=(12, 6))
```
""" | |
import datetime
import logging
import threading
import time

import gpustat
class GPUUtilizationManager: | |
def __init__(self, gpu_ids: list[int] | None = None, interval: float = 0.1): | |
self.gpu_ids = gpu_ids if gpu_ids is not None else [0, 1, 2, 3] | |
self.interval = interval | |
self.data = [] | |
self._monitoring = False | |
self._thread = None | |
self.events = [] | |
def _format_duration(self, seconds: float) -> str: | |
"""Convert a float number of seconds into 'Xm Ys' or 'Zs'.""" | |
total = int(round(seconds)) | |
if total == 0: | |
return f"0s" | |
minutes, secs = divmod(total, 60) | |
if minutes: | |
return f"{minutes}m" + (f" {secs}s" if secs > 0 else "") | |
return f"{secs}s" | |
class _MonitorContextManager: | |
def __init__(self, manager): | |
self.manager = manager | |
def __enter__(self): | |
self.manager._start_monitoring() | |
return self.manager | |
def __exit__(self, exc_type, exc_val, exc_tb): | |
self.manager._stop_monitoring() | |
self.manager._stamp_last_event_duration() | |
def monitor(self): | |
"""Return a context manager for monitoring GPU utilization.""" | |
return self._MonitorContextManager(self) | |
def _start_monitoring(self): | |
"""Start the monitoring thread.""" | |
self._monitoring = True | |
self._thread = threading.Thread(target=self._monitor_loop) | |
self._thread.start() | |
def _stop_monitoring(self): | |
"""Stop the monitoring thread.""" | |
self._monitoring = False | |
if self._thread: | |
self._thread.join() | |
def _monitor_loop(self): | |
"""Monitor loop that collects GPU statistics.""" | |
while self._monitoring: | |
try: | |
stats = gpustat.new_query() | |
timestamp = datetime.datetime.now() | |
gpu_stats = [] | |
for gpu in stats.gpus: | |
if self.gpu_ids is None or gpu.index in self.gpu_ids: | |
gpu_stats.append( | |
{ | |
"gpu_id": gpu.index, | |
"utilization": gpu.utilization, | |
"memory_used": gpu.memory_used / 1024, | |
"memory_total": gpu.memory_total, | |
"temperature": getattr(gpu, "temperature", None), | |
"timestamp": timestamp, | |
} | |
) | |
self.data.extend(gpu_stats) | |
except Exception as e: | |
# Handle errors silently or log if needed | |
pass | |
time.sleep(self.interval) | |
def add_event(self, event_name: str): | |
now = datetime.datetime.now() | |
# stamp previous event with elapsed time | |
if self.events: | |
prev = self.events[-1] | |
delta = now - prev["timestamp"] | |
prev["duration_s"] = self._format_duration(delta.total_seconds()) | |
self.events.append( | |
{"timestamp": now, "event_name": event_name, "duration": None} | |
) | |
def _stamp_last_event_duration(self): | |
"""When no further events will occur, stamp the last event’s duration.""" | |
if not self.events: | |
return | |
now = datetime.datetime.now() | |
last = self.events[-1] | |
delta = now - last["timestamp"] | |
last["duration_s"] = self._format_duration(delta.total_seconds()) | |
def plot(self, figsize=(14, 10), title: str = "GPU Over time"): | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# Convert collected data to DataFrame | |
df = pd.DataFrame(self.data) | |
if df.empty: | |
print("No data to plot.") | |
return | |
# Convert events to DataFrame | |
events_df = pd.DataFrame(self.events) | |
# Use a clean grid style | |
sns.set(style="whitegrid") | |
# Create two subplots: utilization and memory used | |
fig, axes = plt.subplots(2, 1, figsize=figsize, sharex=True) | |
# Plot GPU utilization over time | |
avg_df = df.groupby("timestamp")["memory_used"].sum().reset_index() | |
for i in range(2): | |
data, kwargs = ( | |
(df, dict(hue="gpu_id", palette="tab10")) | |
if i == 0 | |
else ( | |
df.groupby("timestamp")["utilization"].mean().reset_index(), | |
dict( | |
label="Avg Utilization", | |
color="gray", | |
linestyle="--", | |
linewidth=2, | |
), | |
) | |
) | |
sns.lineplot( | |
data=data, x="timestamp", y="utilization", ax=axes[0], **kwargs | |
) | |
axes[0].set_title("GPU Utilization Over Time") | |
axes[0].set_ylabel("Utilization (%)") | |
# Plot GPU memory used over time | |
for i in range(2): | |
# If i==1 then plot sum | |
data, kwargs = ( | |
(df, dict(hue="gpu_id", palette="tab10")) | |
if i == 0 | |
else ( | |
df.groupby("timestamp")["memory_used"].sum().reset_index(), | |
dict( | |
label="Total Memory Used", | |
color="gray", | |
linestyle="--", | |
linewidth=2, | |
), | |
) | |
) | |
sns.lineplot( | |
data=data, x="timestamp", y="memory_used", ax=axes[1], **kwargs | |
) | |
axes[1].set_title("GPU Memory Used Over Time") | |
axes[1].set_ylabel("Memory Used (GB)") | |
axes[1].set_xlabel("Timestamp") | |
# Annotate events with vertical dashed lines and labels | |
for ax in axes: | |
for _, event in events_df.iterrows(): | |
ax.axvline(event["timestamp"], color="gray", linestyle="dotted") | |
# Place label at 90% of the y-axis height | |
ylim = ax.get_ylim() | |
ax.text( | |
event["timestamp"], | |
ylim[0] + 0.9 * (ylim[1] - ylim[0]), | |
event["event_name"] + f" ({event['duration_s']})", | |
rotation=90, | |
verticalalignment="top", | |
fontsize=9, | |
color="gray", | |
) | |
plt.tight_layout() | |
plt.show() | |
plt.suptitle(title) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment