Skip to content

Instantly share code, notes, and snippets.

@darth-veitcher
Last active April 11, 2023 07:44
Show Gist options
  • Save darth-veitcher/0461ce520ddaf2ffa6183dc2f79b985f to your computer and use it in GitHub Desktop.
Save darth-veitcher/0461ce520ddaf2ffa6183dc2f79b985f to your computer and use it in GitHub Desktop.
Set Optimum CUDA device

Optimum CUDA Device

Simple script that detects which PyTorch backend is available (CUDA/ROCm, Apple Silicon MPS, or CPU) and returns the best device to use. Useful when running on macOS with Apple Silicon, where hard-coded `cuda:0`-style device strings need to be swapped for `mps`.

"""Uses some sensible logic to determine platform and best available device for pytorch.

Assumed combinations (in order of preference):

* CUDA (nvidia GPU) / AMD (ROCm)
* MPS (Apple Silicon)
* CPU

Fallback:
* CPU
"""
import logging
import os
from enum import Enum
from typing import Optional

import torch

# Logging: configure the root handler format at import time and expose a
# module logger. The level comes from the LOG_LEVEL environment variable
# (e.g. "INFO", "debug", or a numeric value such as "20"), defaulting to DEBUG.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=LOG_FORMAT)
log = logging.getLogger("pytorch device")
# Normalize the raw value: Logger.setLevel only accepts ints or the
# upper-case level names, so "debug" or "20" would otherwise raise ValueError.
_log_level = os.environ.get("LOG_LEVEL", "DEBUG").strip().upper()
log.setLevel(int(_log_level) if _log_level.isdigit() else _log_level)


class DEVICE_TYPE(Enum):
    """Device-type strings passed to ``torch.device``.

    ``ROCm`` deliberately reuses the ``"cuda"`` value because ROCm presents
    AMD GPUs as a CUDA device. Enum members sharing a value collapse into
    aliases, so ``DEVICE_TYPE.ROCm is DEVICE_TYPE.CUDA`` and iterating the
    enum yields only CUDA, MPS and CPU.
    """

    CUDA = "cuda"  # NVIDIA GPUs
    MPS = "mps"  # Apple Silicon
    ROCm = "cuda"  # alias of CUDA: ROCm acts as a fake CUDA device
    CPU = "cpu"  # always-available fallback


def _is_mps_available() -> bool:
    """Return True when this torch build has an MPS backend that is usable."""
    # NOTE(review): ``torch.backends.mps`` is missing from older PyTorch
    # releases (it appeared around 1.12); probing via getattr lets those
    # builds fall through to the CPU branch instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and mps_backend.is_available()


def get_best_pytorch_device(
    device_type: Optional[DEVICE_TYPE] = None, device_number: int = 0
) -> torch.device:
    """Return the preferred ``torch.device`` for this machine.

    When ``device_type`` is not given, backends are tried in this order:
    CUDA (which also covers ROCm builds), then Apple Silicon MPS, then CPU.

    Args:
        device_type: Force a specific backend instead of auto-detecting it.
            The device string is ``"<type>"`` when ``device_number`` is 0 and
            ``"<type>:<device_number>"`` otherwise. Availability is not checked.
        device_number: Device index. Used with an explicit ``device_type``
            (only when non-zero) and for an auto-detected CUDA device, where a
            falsy value means index 0. Ignored for auto-detected MPS and CPU.

    Returns:
        The selected ``torch.device``.

    Side effects:
        When MPS is auto-detected, sets ``PYTORCH_ENABLE_MPS_FALLBACK=1`` in
        ``os.environ``.
    """
    if device_type is not None:
        # Explicit override: build "<type>" or "<type>:<n>" as requested.
        index_suffix = f":{device_number}" if device_number else ""
        dev = torch.device(f"{device_type.value}{index_suffix}")
    elif torch.cuda.is_available():
        # Also true for ROCm, which acts as a CUDA device.
        count = torch.cuda.device_count()
        plural = "s" if count != 1 else ""
        log.debug(f"CUDA detected. Found {count} device{plural}.")
        dev = torch.device(f"{DEVICE_TYPE.CUDA.value}:{device_number or 0}")
    elif _is_mps_available():
        dev = torch.device(DEVICE_TYPE.MPS.value)
        log.debug("MPS device found. Assuming Apple Silicon.")
        # Some operators (e.g. 'aten::linalg_cholesky_ex.L') are not implemented
        # for MPS and raise NotImplementedError; PyTorch suggests setting
        # PYTORCH_ENABLE_MPS_FALLBACK=1 so they run on the CPU instead, at
        # reduced speed. See https://github.com/pytorch/pytorch/issues/77764.
        # NOTE(review): PyTorch may read this variable only when torch is first
        # loaded; if so, setting it here (after `import torch`) has no effect.
        # Confirm, and prefer exporting it before starting Python.
        log.debug("Setting environment `PYTORCH_ENABLE_MPS_FALLBACK=1`.")
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    else:
        dev = torch.device(DEVICE_TYPE.CPU.value)
        log.debug("No GPU devices found. Defaulting to CPU.")
    log.info(f"Set device to: {dev}")
    return dev


if __name__ == "__main__":
    # Script entry point: run device detection once. The result is not
    # printed here; get_best_pytorch_device reports it through its logger.
    selected_device = get_best_pytorch_device()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment