

@LouisFaure
Created August 3, 2024 22:31
Get GPU memory used by the current JAX script
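```python
import os
import subprocess as subp

import jax

# nvidia-smi only makes sense when JAX is running on an NVIDIA GPU
if jax.default_backend() == 'gpu':
    # Step 1: Call nvidia-smi for the PID and used memory (MiB) of each compute process
    result = subp.run([
        'nvidia-smi',
        '--query-compute-apps=pid,used_memory',
        '--format=csv,noheader,nounits'
    ], stdout=subp.PIPE, text=True, check=True)
    # Step 2: Get the output
    output = result.stdout.strip()
    # Step 3: Process the output; a process using several GPUs is listed
    # once per GPU, so sum its rows instead of overwriting them
    memory_usage = {}
    for line in output.splitlines():
        if not line.strip():
            continue
        pid, mem = line.split(',')
        pid = int(pid.strip())
        memory_usage[pid] = memory_usage.get(pid, 0) + int(mem.strip())
    # Step 4: Print the memory usage for the current PID
    pid = os.getpid()
    if pid in memory_usage:
        print(f"GPU memory: {memory_usage[pid]} MiB")
    else:
        # e.g. nothing allocated on the device yet, or PIDs differ inside a container
        print("Current process not listed by nvidia-smi")
```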
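If you want to call this from several places in a script, the same steps can be split into a parsing function and a thin wrapper around `nvidia-smi`. This is a minimal sketch, not part of the original gist: the names `parse_compute_apps` and `gpu_memory_mib` are hypothetical, and it assumes `nvidia-smi` is on `PATH`. Keeping the CSV parsing separate means it can be checked without a GPU.

```python
import os
import subprocess


def parse_compute_apps(csv_text):
    """Hypothetical helper: sum used GPU memory (MiB) per PID.

    Expects "pid, used_memory" lines as produced by
    nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits.
    A process using several GPUs appears once per GPU, so its rows are summed.
    """
    usage = {}
    for line in csv_text.splitlines():
        line = line.strip()
        if not line:
            continue
        pid_str, mem_str = (field.strip() for field in line.split(','))
        if not mem_str.isdigit():
            # Skip rows nvidia-smi cannot report (e.g. "[N/A]")
            continue
        pid = int(pid_str)
        usage[pid] = usage.get(pid, 0) + int(mem_str)
    return usage


def gpu_memory_mib(pid=None):
    """Hypothetical helper: MiB attributed to `pid` (default: this process), or None."""
    pid = os.getpid() if pid is None else pid
    result = subprocess.run(
        ['nvidia-smi',
         '--query-compute-apps=pid,used_memory',
         '--format=csv,noheader,nounits'],
        capture_output=True, text=True, check=True,
    )
    return parse_compute_apps(result.stdout).get(pid)
```

Call `gpu_memory_mib()` after JAX has placed arrays on the device. Keep in mind that `nvidia-smi` reports what the process has reserved, and by default JAX preallocates a large share (75%) of GPU memory on first use, so the number may be much larger than what your arrays actually occupy.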