import sys
import logging

import torch
from transformers import XLMRobertaForSequenceClassification
print('cuda available? ', torch.cuda.is_available())
print('how many gpus?', torch.cuda.device_count())
logging.root.handlers = []
logging.basicConfig(level="INFO", format='%(asctime)s:%(levelname)s: %(message)s', stream=sys.stdout)
logger = logging.getLogger(__name__)
logger.info('hello')
def check_memory():
    # Currently allocated GPU memory, in MiB
    logger.info('GPU memory: %.1f MiB' % (torch.cuda.memory_allocated() // 1024 ** 2))
device = torch.device('cuda')
torch.cuda.empty_cache()
check_memory()
model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-large')
logger.info('moving model to GPU')
gpu_model = model.to(device)
print('-' * 50)
print('single model memory usage')
check_memory() # the model is 2.135 Gb
print('-' * 50)
def run_transformers(x):
    # x.requires_grad = False
    check_memory()
    logger.info('moving tensors to GPU')
    x = x.to(device)
    check_memory()
    logger.info('running XLM-R forward pass on x')
    yhat = gpu_model(x)
    check_memory()
    logger.info(f'yhat[0].requires_grad = {yhat[0].requires_grad}. Detaching yhat')
    yhat = yhat[0].detach()
    logger.info(f'x shape = {x.shape}, yhat.shape = {yhat.shape}')
    check_memory()
for b in [1, 2, 4, 8, 16, 32]:
    print('-' * 50)
    torch.cuda.empty_cache()
    check_memory()
    print('batch size {} analysis'.format(b))
    x = torch.randint(low=1000, high=30000, size=(b, 512))
    run_transformers(x)
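
# A possible follow-up sketch (not part of the original gist): the forward pass
# above runs with gradients enabled, so activation memory grows with batch size
# even though yhat is detached afterwards. Comparing peak allocated memory with
# and without torch.no_grad() shows how much of that growth is activation
# storage. The helper names peak_memory_mb and forward_no_grad, and the batch
# sizes below, are illustrative choices, not from the original.

def peak_memory_mb(fn, *args):
    # Reset the peak-memory counter, run fn, and return the peak in MiB
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() // 1024 ** 2

def forward_no_grad(x):
    # Same forward pass, but without building the autograd graph
    with torch.no_grad():
        gpu_model(x.to(device))

for b in [1, 8]:
    x = torch.randint(low=1000, high=30000, size=(b, 512))
    torch.cuda.empty_cache()
    logger.info('batch %d, peak with grad: %d MiB' % (b, peak_memory_mb(run_transformers, x)))
    torch.cuda.empty_cache()
    logger.info('batch %d, peak under no_grad: %d MiB' % (b, peak_memory_mb(forward_no_grad, x)))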