@pszemraj
Created March 7, 2023 22:22
basic rough fn
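import subprocess
import torch

def check_ampere_gpu():
    """Check if the GPU supports NVIDIA Ampere or later and enable TF32 in PyTorch if it does.

    Rough heuristic: only the GPU names listed below are recognized, so other
    Ampere or newer GPUs (e.g. A10, A40, H100, RTX 40-series) are not detected.
    """
    cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]
    try:
        output = subprocess.check_output(cmd, universal_newlines=True)
    except (OSError, subprocess.CalledProcessError):
        print("Could not query the GPU with nvidia-smi.")
        return
    gpu_name = output.strip()  # one name per line if several GPUs are present
    if "A100" in gpu_name or "A6000" in gpu_name or "RTX 30" in gpu_name:
        # Allow TensorFloat-32 tensor cores for float32 matrix multiplications
        torch.backends.cuda.matmul.allow_tf32 = True
        print("GPU supports NVIDIA Ampere or later, enabled TF32 in PyTorch.")
    else:
        print("GPU not recognized as NVIDIA Ampere or later, TF32 not enabled.")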
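A sturdier variant, sketched below and assuming PyTorch can see the GPU through CUDA, asks for the device's compute capability instead of matching names. Ampere GPUs report 8.x (8.0 for A100, 8.6 for RTX 30-series and RTX A6000), and later architectures such as Ada (8.9) and Hopper (9.0) report higher values, so a major version of 8 or more means "Ampere or later". The helper names `is_ampere_or_later` and `enable_tf32_if_supported` are hypothetical and not part of the original gist.

```python
from typing import Tuple

# Ampere GPUs report CUDA compute capability 8.x; newer architectures report higher majors.
AMPERE_MAJOR = 8


def is_ampere_or_later(capability: Tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability is Ampere (8.x) or newer."""
    major, _minor = capability
    return major >= AMPERE_MAJOR


def enable_tf32_if_supported(device: int = 0) -> bool:
    """Enable TF32 for matmuls and cuDNN when the given GPU is Ampere or later.

    Hypothetical helper (not from the original gist). Returns True if TF32 was enabled.
    """
    import torch  # imported lazily so is_ampere_or_later() works without PyTorch

    if not torch.cuda.is_available():
        return False
    # get_device_capability returns a (major, minor) tuple, e.g. (8, 6)
    if not is_ampere_or_later(torch.cuda.get_device_capability(device)):
        return False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    return True
```

Calling `enable_tf32_if_supported()` once at startup plays the same role as `check_ampere_gpu()`, but returns a bool instead of printing and needs no `nvidia-smi` subprocess. Recent PyTorch versions also provide `torch.set_float32_matmul_precision("high")`, which allows TF32 for float32 matrix multiplications where the hardware supports it.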