A simple PyTorch memory usage profiler
import gc

import torch

## MEM utils ##

def mem_report():
    '''Report the memory usage of tensor.storage in PyTorch.
    Both CPU and GPU tensors are reported.'''

    def _mem_report(tensors, mem_type):
        '''Print the selected tensors of one storage type.

        There are two major storage types of concern:
            - GPU: tensors transferred to CUDA devices
            - CPU: tensors remaining in system memory (usually unimportant)

        Args:
            - tensors: the tensors of the specified type
            - mem_type: 'CPU' or 'GPU' in the current implementation
        '''
        print('Storage on %s' % (mem_type))
        print('-' * LEN)
        total_numel = 0
        total_mem = 0
        visited_data = []
        for tensor in tensors:
            if tensor.is_sparse:
                continue
            # a data_ptr identifies an allocated memory block; skip storages
            # already counted (e.g. views sharing the same storage)
            data_ptr = tensor.storage().data_ptr()
            if data_ptr in visited_data:
                continue
            visited_data.append(data_ptr)

            numel = tensor.storage().size()
            total_numel += numel
            element_size = tensor.storage().element_size()
            mem = numel * element_size / 1024 / 1024  # 32bit=4Byte, MByte
            total_mem += mem
            element_type = type(tensor).__name__
            size = tuple(tensor.size())

            print('%s\t\t%s\t\t%.2f' % (
                element_type,
                size,
                mem))
        print('-' * LEN)
        print('Total Tensors: %d \tUsed Memory Space: %.2f MBytes' % (total_numel, total_mem))
        print('-' * LEN)

    LEN = 65
    print('=' * LEN)
    objects = gc.get_objects()
    print('%s\t%s\t\t\t%s' % ('Element type', 'Size', 'Used MEM(MBytes)'))
    tensors = [obj for obj in objects if torch.is_tensor(obj)]
    cuda_tensors = [t for t in tensors if t.is_cuda]
    host_tensors = [t for t in tensors if not t.is_cuda]
    _mem_report(cuda_tensors, 'GPU')
    _mem_report(host_tensors, 'CPU')
    print('=' * LEN)
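
For reference, a minimal usage sketch (the tensor shapes below are purely illustrative and not part of the gist):

    import torch

    # create a couple of tensors just so the report has something to show
    a = torch.randn(1024, 1024)      # stays on the CPU
    b = torch.randn(512, 512)
    if torch.cuda.is_available():
        b = b.cuda()                 # moved to the GPU

    mem_report()  # prints one table for GPU storages and one for CPU storages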
@kkonevets, thanks, I didn't pay much attention to the scenario where the number of tensors gets large. But I think the in operation won't take much time in this script, what do you think?
I recently wrote a more powerful, pip-installable tool; you can check it out:
Hey man, use

    visited_data = set()

and

    visited_data.update([data_ptr])

for a much faster loop.
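
For context, a minimal sketch of that change inside _mem_report (keeping the rest of the loop as in the gist):

    visited_data = set()
    for tensor in tensors:
        if tensor.is_sparse:
            continue
        data_ptr = tensor.storage().data_ptr()
        if data_ptr in visited_data:   # O(1) membership test on a set
            continue
        visited_data.add(data_ptr)     # equivalent to update([data_ptr])
        # ... accounting as before ...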