Created
August 3, 2024 22:31
-
-
Save LouisFaure/438529d036d7adb27809d5ee854078cc to your computer and use it in GitHub Desktop.
Get the GPU memory used by the current JAX script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import re
import subprocess as subp

import jax


def parse_gpu_memory(output):
    """Parse ``nvidia-smi --query-compute-apps=pid,used_memory`` CSV output.

    Args:
        output: Text produced with ``--format=csv,noheader,nounits``; each
            row is expected to look like ``"<pid>, <used_memory>"``.

    Returns:
        A dict mapping each PID (int) to its used memory (int, in the unit
        nvidia-smi reports -- MiB). Blank rows and rows whose fields are
        not integers (e.g. ``[N/A]``) are skipped instead of raising.
    """
    memory_usage = {}
    for line in output.splitlines():
        pid_field, sep, mem_field = line.partition(',')
        if not sep:
            # Blank line (nvidia-smi listed no compute process) or a row
            # that is not in "pid, memory" form.
            continue
        try:
            pid = int(pid_field.strip())
            mem = int(mem_field.strip())
        except ValueError:
            # Non-numeric field such as "[N/A]"; nothing usable to record.
            continue
        # NOTE(review): nvidia-smi appears to emit one row per GPU a process
        # has a context on, so rows for the same PID are summed rather than
        # overwritten -- confirm against the nvidia-smi documentation.
        memory_usage[pid] = memory_usage.get(pid, 0) + mem
    return memory_usage


def current_gpu_memory():
    """Return the GPU memory (MiB) nvidia-smi reports for this process.

    Returns:
        The used memory as an int, or ``None`` when the current PID is not
        present in the nvidia-smi output.
    """
    result = subp.run([
        'nvidia-smi',
        '--query-compute-apps=pid,used_memory',
        '--format=csv,noheader,nounits'
    ], stdout=subp.PIPE, text=True)
    return parse_gpu_memory(result.stdout).get(os.getpid())


if jax.default_backend() == 'gpu':
    used_mib = current_gpu_memory()
    if used_mib is None:
        # The PID may be absent from the listing (presumably e.g. when the
        # driver sees different PIDs than this process does); report that
        # instead of failing with a KeyError.
        print(f"GPU memory: not reported by nvidia-smi for PID {os.getpid()}")
    else:
        print(f"GPU memory: {used_mib} MiB")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment