# Simple script to detect the platform and GPU architecture and return the best
# available PyTorch device. Useful when running on macOS with Apple Silicon, where
# hardcoded "cuda:0" device strings need to be swapped out for "mps".
"""Uses some sensible logic to determine platform and best available device for pytorch.
Assumed combinations (in order of preference):
* CUDA (nvidia GPU) / AMD (ROCm)
* MPS (Apple Silicon)
* CPU
Fallback:
* CPU
"""
import logging
import os
from enum import Enum
from typing import Optional
import torch
# Logging
# NOTE: basicConfig configures the root logger as a side effect of importing
# this module (only if the root logger has no handlers yet).
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT)
log = logging.getLogger("pytorch device")


def resolve_log_level(value: Optional[str], default: int = logging.DEBUG) -> int:
    """Translate a ``LOG_LEVEL``-style string into a numeric logging level.

    Accepts level names in any case (``"debug"``, ``"INFO"``) and plain
    integers (``"10"``). Missing, empty, or unrecognised values yield
    *default* instead of raising, so a typo in the environment cannot make
    importing this module fail.

    Args:
        value: Raw level string, typically ``os.environ.get("LOG_LEVEL")``.
        default: Level returned when *value* is missing or not understood.

    Returns:
        The numeric logging level.
    """
    if value is None:
        return default
    text = value.strip()
    if not text:
        return default
    if text.isdigit():
        return int(text)
    # getLevelName maps a known name to its int and returns a "Level X"
    # string for unknown names, so anything non-int means "unrecognised".
    level = logging.getLevelName(text.upper())
    return level if isinstance(level, int) else default


log.setLevel(resolve_log_level(os.environ.get("LOG_LEVEL")))
class DEVICE_TYPE(Enum):
    """Device-type strings that get_best_pytorch_device passes to ``torch.device``.

    ``ROCm`` has the same value (``"cuda"``) as ``CUDA``. ``Enum`` treats a
    repeated value as an alias of the first member defined with it, so
    ``DEVICE_TYPE.ROCm is DEVICE_TYPE.CUDA``. Keep ``CUDA`` declared before
    ``ROCm`` so that ``CUDA`` remains the canonical member name.
    """

    CUDA = "cuda"
    MPS = "mps"
    # Alias of CUDA: ROCm builds of PyTorch expose the GPU under the "cuda" type.
    ROCm = "cuda"  # ROCm acts as a fake CUDA device
    CPU = "cpu"
def get_best_pytorch_device(
    device_type: Optional[DEVICE_TYPE] = None, device_number: int = 0
) -> torch.device:
    """Return the best available PyTorch device, or an explicitly requested one.

    Without ``device_type``, backends are tried in this order: CUDA (which
    also covers ROCm builds, since they use the ``cuda`` device type), then
    MPS, then CPU.

    Args:
        device_type: Optional device type that overrides auto-detection. No
            availability check is performed for an explicit override.
        device_number: Device index. With an explicit ``device_type`` it is
            appended as ``"<type>:<index>"`` only when non-zero. For
            auto-detected CUDA it is always appended (``"cuda:<index>"``).
            It is ignored for auto-detected MPS and CPU.

    Returns:
        The selected ``torch.device``.

    Side effects:
        When MPS is auto-selected, sets ``PYTORCH_ENABLE_MPS_FALLBACK=1`` in
        ``os.environ`` so that ops not implemented for MPS can run on the CPU.
    """
    if device_type is not None:
        # Explicit override. Omit the index for 0 so e.g. "cuda" / "mps" are used.
        suffix = f":{device_number}" if device_number else ""
        dev = torch.device(f"{device_type.value}{suffix}")
    # Detect CUDA and the number of devices.
    elif torch.cuda.is_available():
        count = torch.cuda.device_count()
        log.debug(f"CUDA detected. Found {count} device{'s' if count > 1 else ''}.")
        # `or 0` keeps accepting device_number=None, as the original did.
        dev = torch.device(f"{DEVICE_TYPE.CUDA.value}:{device_number or 0}")
    # Detect the MPS backend for Apple Silicon.
    elif torch.backends.mps.is_available():
        dev = torch.device(DEVICE_TYPE.MPS.value)
        log.debug("MPS device found. Assuming Apple Silicon.")
        # Some ops (e.g. 'aten::linalg_cholesky_ex.L') are not implemented for
        # MPS and raise NotImplementedError. PyTorch's suggested workaround
        # (https://github.com/pytorch/pytorch/issues/77764) is to set
        # PYTORCH_ENABLE_MPS_FALLBACK=1 so those ops run on the CPU (slower).
        # NOTE(review): this is set after `import torch`. Confirm the installed
        # PyTorch version still honours the variable when it is set this late.
        log.debug("Setting environment `PYTORCH_ENABLE_MPS_FALLBACK=1`.")
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    # Fall back to the CPU.
    else:
        dev = torch.device(DEVICE_TYPE.CPU.value)
        log.debug("No GPU devices found. Defaulting to CPU.")
    log.info(f"Set device to: {dev}")
    return dev
if __name__ == "__main__":
    # Print the selected device so the script's result is usable from a shell
    # (e.g. DEVICE=$(python this_script.py)). Log records go to stderr via the
    # default basicConfig handler, so they do not pollute stdout.
    print(get_best_pytorch_device())