Skip to content

Instantly share code, notes, and snippets.

@stas00
Last active January 24, 2021 17:39
Show Gist options
  • Save stas00/d29711a4b594b8335c9053a6624444cb to your computer and use it in GitHub Desktop.
Save stas00/d29711a4b594b8335c9053a6624444cb to your computer and use it in GitHub Desktop.
# Same as the other script, but this time each thread allocates on a different GPU
# (thread i uses device i); torch.cuda.max_memory_allocated() still reports correctly.
import threading
import time
import torch
def print_mem_usage(prefix):
    """Print torch.cuda.max_memory_allocated() for every visible GPU.

    One line is printed per device, formatted as ``"<prefix>: <device>: <MB>MB"``.
    The MB figure is the byte count shifted right by 20 bits, i.e. whole MiB,
    rounded down. Nothing is printed when no CUDA device is visible.

    Args:
        prefix: label identifying the caller (e.g. ``"main"`` or a thread
            number); right-aligned to a width of 4 characters.
    """
    n_gpus = torch.cuda.device_count()
    # `device` rather than `id`, so the builtin id() is not shadowed.
    for device in range(n_gpus):
        # Make `device` current so the no-argument query below reports on it.
        with torch.cuda.device(device):
            peak_mb = torch.cuda.max_memory_allocated() >> 20
            print(f"{prefix:>4}: {device}: {peak_mb:2d}MB")
def thread_function(index):
    """Worker: after a staggered delay, allocate a tensor on GPU ``index`` and report.

    Thread ``index`` does the following:

    - sleeps ``3 * (index + 1)`` seconds;
    - allocates ``10 * (index + 1)`` MB of ones on CUDA device ``index``
      (10MB for thread 0, 20MB for thread 1);
    - prints per-device peak memory via ``print_mem_usage``, using the 1-based
      thread number as the prefix;
    - keeps the tensor alive for another 6 seconds.

    Args:
        index: 0-based thread number; also used as the CUDA device ordinal.
    """
    device = index  # GPU this thread allocates on (named to avoid shadowing id())
    index += 1  # 1-based thread number: scales both the delay and the size
    time.sleep(index * 3)
    # 10 * index * 2**18 elements * 4 bytes (default float32) = 10 * index MB:
    # 10MB for thread 0, 20MB for thread 1.
    n_elements = 10 * index * 2**18
    with torch.cuda.device(device):
        # Allocate directly on the target GPU instead of creating the tensor on
        # the CPU and copying it with .cuda()/.to(); .contiguous() was a no-op.
        x = torch.ones(n_elements, device=device)
    print_mem_usage(index)
    # Keep the allocation live while the main thread keeps sampling.
    time.sleep(6)
    del x
if __name__ == "__main__":
    # Thread i allocates on GPU i, so at least this many GPUs are required;
    # otherwise a worker would fail with an invalid device ordinal.
    N_THREADS = 2
    if torch.cuda.device_count() < N_THREADS:
        raise SystemExit(
            f"this script needs at least {N_THREADS} GPUs, "
            f"found {torch.cuda.device_count()}"
        )

    threads = []
    for index in range(N_THREADS):
        thread = threading.Thread(target=thread_function, args=(index,))
        threads.append(thread)
        thread.start()

    # Sample from the main thread while the workers run. Sleeps of 0, 2, 4, 6
    # and 8 seconds put the samples at roughly t = 0, 2, 6, 12 and 20 s.
    for i in range(5):
        time.sleep(i * 2)
        print_mem_usage("main")

    for thread in threads:
        thread.join()
    # Final report once every worker has finished.
    print_mem_usage("main")
# Check that torch.cuda.max_memory_allocated() correctly reports the peak memory of the
# whole process when Python threads are used.
# Result: it does.
import threading
import time
import torch
def print_mem_usage(prefix):
    """Print torch.cuda.max_memory_allocated() for every visible GPU.

    One line is printed per device, formatted as ``"<prefix>: <device>: <MB>MB"``.
    The MB figure is the byte count shifted right by 20 bits, i.e. whole MiB,
    rounded down. Nothing is printed when no CUDA device is visible.

    Args:
        prefix: label identifying the caller (e.g. ``"main"`` or a thread
            number); right-aligned to a width of 4 characters.
    """
    n_gpus = torch.cuda.device_count()
    # `device` rather than `id`, so the builtin id() is not shadowed.
    for device in range(n_gpus):
        # Make `device` current so the no-argument query below reports on it.
        with torch.cuda.device(device):
            peak_mb = torch.cuda.max_memory_allocated() >> 20
            print(f"{prefix:>4}: {device}: {peak_mb:2d}MB")
def thread_function(index):
    """Worker: after a staggered delay, allocate a tensor on the current GPU and report.

    Thread ``index`` does the following:

    - sleeps ``3 * (index + 1)`` seconds;
    - allocates ``10 * (index + 1)`` MB of ones on the thread's current CUDA
      device (10MB for thread 0, 20MB for thread 1);
    - prints per-device peak memory via ``print_mem_usage``, using the 1-based
      thread number as the prefix;
    - keeps the tensor alive for another 6 seconds.

    The current device is not changed here, so presumably every thread
    allocates on GPU 0. TODO confirm.

    Args:
        index: 0-based thread number.
    """
    index += 1  # 1-based thread number: scales both the delay and the size
    time.sleep(index * 3)
    # 10 * index * 2**18 elements * 4 bytes (default float32) = 10 * index MB:
    # 10MB for thread 0, 20MB for thread 1.
    n_elements = 10 * index * 2**18
    # Allocate straight onto the current CUDA device rather than creating the
    # tensor on the CPU and copying it with .cuda(); .contiguous() was a no-op
    # on a freshly created tensor.
    x = torch.ones(n_elements, device="cuda")
    print_mem_usage(index)
    # Keep the allocation live while the main thread keeps sampling.
    time.sleep(6)
    del x
if __name__ == "__main__":
    # Every worker allocates on a CUDA device; fail early with a clear message
    # instead of letting each thread crash separately.
    if not torch.cuda.is_available():
        raise SystemExit("this script needs a CUDA GPU")

    threads = []
    for index in range(2):
        thread = threading.Thread(target=thread_function, args=(index,))
        threads.append(thread)
        thread.start()

    # Sample from the main thread while the workers run. Sleeps of 0, 2, 4, 6
    # and 8 seconds put the samples at roughly t = 0, 2, 6, 12 and 20 s.
    for i in range(5):
        time.sleep(i * 2)
        print_mem_usage("main")

    for thread in threads:
        thread.join()
    # Final report once every worker has finished.
    print_mem_usage("main")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment