@MLWhiz
Created September 7, 2020 15:18
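import torch
from torch import nn

# Whether to train on a GPU
train_on_gpu = torch.cuda.is_available()
print(f'Train on GPU: {train_on_gpu}')

# Number of GPUs; default to single-device training so that
# multi_gpu is always defined, even on a CPU-only machine
multi_gpu = False
if train_on_gpu:
    gpu_count = torch.cuda.device_count()
    print(f'{gpu_count} GPUs detected.')
    multi_gpu = gpu_count > 1

# `model` is assumed to be an nn.Module defined earlier.
# Move it to the GPU and, if several GPUs are visible, wrap it in
# nn.DataParallel so each batch is split across them.
if train_on_gpu:
    model = model.to('cuda')
if multi_gpu:
    model = nn.DataParallel(model)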
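A minimal, device-agnostic sketch of the same idea follows, assuming PyTorch is installed. The helper names `prepare_model` and `unwrap` and the small `nn.Sequential` model are hypothetical, for illustration only. It uses `torch.device` so the same code also runs on a CPU-only machine, and it moves the input batch to the model's device, a step that every forward pass needs.

```python
from typing import Tuple

import torch
from torch import nn


def prepare_model(model: nn.Module) -> Tuple[nn.Module, torch.device]:
    """Move `model` to the best available device and wrap it in
    nn.DataParallel when more than one GPU is visible (hypothetical helper)."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)  # nn.Module.to moves parameters in place and returns self
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        # DataParallel splits each input batch along dim 0 across the GPUs
        # and gathers the outputs back on the default device.
        model = nn.DataParallel(model)
    return model, device


def unwrap(model: nn.Module) -> nn.Module:
    """Return the underlying module, e.g. before saving a state_dict."""
    return model.module if isinstance(model, nn.DataParallel) else model


# Usage with a hypothetical toy model standing in for the real one
toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
model, device = prepare_model(toy)

# Inputs must live on the same device as the model's parameters
batch = torch.randn(4, 8, device=device)
logits = model(batch)
print(logits.shape)  # torch.Size([4, 3])
```

Note that `nn.DataParallel` runs in a single process; PyTorch's own documentation recommends `DistributedDataParallel` for multi-GPU training when performance matters.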