Skip to content

Instantly share code, notes, and snippets.

@fpaupier
Created February 11, 2024 09:40
Show Gist options
Save fpaupier/5eee2eb9279137b59144f4c0d2b511f4 to your computer and use it in GitHub Desktop.
Get torch device
import torch
import subprocess
def get_device() -> str:
    """Return the name of the best available torch device.

    Preference order is CUDA, then Apple's Metal Performance Shaders (MPS),
    then CPU. Availability is asked of torch itself rather than inferred
    from the host CPU architecture. The answer therefore reflects whether
    *this* PyTorch build on *this* OS can actually use the device. A build
    without MPS support, or one too old to have ``torch.backends.mps``,
    falls back to ``'cpu'`` instead of returning a device that fails later.

    Returns:
        One of ``'cuda'``, ``'mps'`` or ``'cpu'``; each is accepted by
        ``torch.device``.
    """
    if torch.cuda.is_available():
        return 'cuda'
    # ``torch.backends.mps`` only exists in newer PyTorch releases; look it
    # up defensively so older builds fall through to CPU instead of raising.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    # No usable accelerator: default to CPU.
    return 'cpu'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment