Track TPU memory usage while running a training script. The script starts a daemon thread that dumps a device memory profile to /dev/shm once a second, using an atomic rename so the monitoring process never reads a partially written file; the main loop then runs dummy pmap computations with randomly sized inputs to generate changing memory traffic. See https://twitter.com/ayaka14732/status/1565016471323156481 for more details.
# monitor the memory profile with `watch --color -n1 go tool pprof -tags /dev/shm/memory.prof`

import functools
import jax
import jax.numpy as np
import random
import threading

devices = jax.devices()
n_devices = jax.device_count()
def initialise_memory_tracking():
    def inner():
        import posix
        import time
        while True:
            # write to a temporary file first, then rename: rename() on the same
            # filesystem is atomic, so the pprof watcher never sees a partial file
            jax.profiler.save_device_memory_profile('/dev/shm/memory.prof.new')
            posix.rename('/dev/shm/memory.prof.new', '/dev/shm/memory.prof')
            time.sleep(1.)
    # daemon thread, so it does not keep the process alive after main() returns
    thread = threading.Thread(target=inner, daemon=True)
    thread.start()
@functools.partial(jax.pmap, axis_name='n_devices')
def some_heavy_computation(a, b):
    # contract over the trailing axes, then average the result across devices
    c = np.einsum('abcd,ebcd->ae', a, b)
    d = jax.lax.pmean(c, axis_name='n_devices')
    return d
def main():
    initialise_memory_tracking()
    for i in range(1000):
        print(i)
        # random leading dimensions, so each step allocates buffers of a
        # different size and the memory profile keeps changing
        x = random.randrange(100, 32000)
        y = random.randrange(100, 32000)
        a = np.zeros((x, 11, 4, 2), dtype=np.float32)
        b = np.zeros((y, 11, 4, 2), dtype=np.float32)
        # replicate the arrays onto every device, as pmap expects
        a = jax.device_put_replicated(a, devices=devices)
        b = jax.device_put_replicated(b, devices=devices)
        some_heavy_computation(a, b)
    print('Done')

if __name__ == '__main__':
    main()
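
While the script runs, view the live profile from a second terminal with the command in the header comment, `watch --color -n1 go tool pprof -tags /dev/shm/memory.prof` (`go tool pprof` requires the Go toolchain to be installed). To reuse the tracker in your own code, call initialise_memory_tracking() once before the workload starts; a minimal sketch, where train() is a hypothetical stand-in for your actual training loop:

# hypothetical reuse sketch: train() stands in for your real training loop
def train():
    ...  # model setup and training steps

if __name__ == '__main__':
    initialise_memory_tracking()  # start the profiling thread first
    train()                       # the profile keeps updating once a second while this runs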
Python package: https://github.com/ayaka14732/jax-smi
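
The same approach is packaged as jax-smi. A minimal sketch, assuming the initialise_tracking() entry point and the jax-smi CLI described in the package's README; check the repository for the current interface:

# pip install jax-smi
from jax_smi import initialise_tracking

initialise_tracking()  # spawns a background thread that dumps the profile periodically

# ... run your JAX computation as usual ...
# then, in another terminal:
#     jax-smi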