Skip to content

Instantly share code, notes, and snippets.

@chn-lee-yumi
Created April 29, 2025 07:10
Show Gist options
  • Select an option

  • Save chn-lee-yumi/0ab9ed85599069d21f619b540ea7e1f1 to your computer and use it in GitHub Desktop.

Select an option

Save chn-lee-yumi/0ab9ed85599069d21f619b540ea7e1f1 to your computer and use it in GitHub Desktop.
Torch dtype Check
import importlib
import importlib.util

import torch
def test_supported_dtypes(device="cpu", dtypes=None):
    """Probe which torch dtypes can be allocated on *device*.

    For each dtype, a one-element tensor is created directly on the device;
    the dtype is reported as supported when that allocation succeeds and as
    unsupported when it raises. Only allocation is exercised -- arithmetic
    kernels are not run, so ``True`` does not guarantee every operation works
    for that dtype on that device.

    The ``test_`` prefix matches pytest's default discovery pattern; giving
    ``device`` a default keeps the function runnable (rather than a
    "fixture not found" error) if pytest ever collects it.

    Args:
        device: Anything accepted by ``torch.tensor(..., device=...)``, e.g.
            a string such as ``"cuda"`` or a ``torch.device``. Defaults to
            ``"cpu"``.
        dtypes: Iterable of dtypes to probe. ``None`` (the default) probes
            the twelve standard dtypes listed below.

    Returns:
        dict mapping each probed dtype, in iteration order, to ``True`` if
        allocation succeeded or ``False`` if it raised.
    """
    if dtypes is None:
        dtypes = (
            torch.bool,
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bfloat16,
            torch.complex64,
            torch.complex128,
        )
    support_results = {}
    for dtype in dtypes:
        try:
            # Create a tiny tensor of this dtype directly on the device.
            torch.tensor([1], dtype=dtype, device=device)
        except Exception:
            # Deliberately broad: any allocation failure means "unsupported".
            # (Presumably because backends raise varied exception types for
            # unsupported dtypes -- keep this best-effort, never crash.)
            support_results[dtype] = False
        else:
            support_results[dtype] = True
    return support_results
# Candidate inference devices: cpu, cuda, mps. Devices this machine / torch
# build cannot use are filtered out by the probe in the loop below.
device_list = ["cpu", "cuda", "mps"]

# If the optional torch_directml package is installed and reports at least
# one device, add a DirectML device as well.
if importlib.util.find_spec("torch_directml") is not None:
    import torch_directml

    if torch_directml.device_count() > 0:
        device_list.append(torch_directml.device())

for device in device_list:
    try:
        # Both the device construction and a tiny allocation can fail for a
        # backend that is unavailable; such devices are skipped silently.
        # AssertionError / RuntimeError are the types this script has always
        # treated as "device not usable".
        device = torch.device(device)
        torch.tensor([1], device=device)
    except (AssertionError, RuntimeError):
        continue

    results = test_supported_dtypes(device)
    print(f"Supported dtypes on device {device}:")
    for dtype, supported in results.items():
        print(f" {dtype}: {'Yes' if supported else 'No'}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment