Created
April 29, 2025 07:10
-
-
Save chn-lee-yumi/0ab9ed85599069d21f619b540ea7e1f1 to your computer and use it in GitHub Desktop.
Torch dtype Check
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import importlib
import importlib.util

import torch
def test_supported_dtypes(device, dtypes=None):
    """Probe which tensor dtypes can be created on *device*.

    For each dtype, try to build a one-element tensor directly on the
    device. A dtype is reported as supported when creation succeeds and as
    unsupported when it raises any exception.

    Args:
        device: Any value accepted by the ``device=`` argument of
            ``torch.tensor`` (e.g. a ``torch.device`` or ``"cpu"``).
        dtypes: Optional iterable of dtypes to probe. When ``None`` (the
            default), a fixed list of bool, integer, floating-point and
            complex dtypes is used.

    Returns:
        dict mapping each probed dtype to ``True`` (tensor created) or
        ``False`` (creation raised), in the order the dtypes were probed.
    """
    if dtypes is None:
        dtypes = (
            torch.bool,
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bfloat16,
            torch.complex64,
            torch.complex128,
        )

    support_results = {}
    for dtype in dtypes:
        try:
            # Create the tensor directly on the target device.
            torch.tensor([1], dtype=dtype, device=device)
        except Exception:
            # Deliberately broad: this is a capability probe, and any failure
            # to create the tensor means "not supported" rather than an error.
            support_results[dtype] = False
        else:
            support_results[dtype] = True
    return support_results


# The name starts with "test_" but this is a helper that needs a device
# argument, not a pytest test; stop pytest from collecting it.
test_supported_dtypes.__test__ = False
# Candidate inference devices (cpu, cuda, mps). Each one is probed below and
# skipped if it cannot be used in this build/environment.
device_list = ["cpu", "cuda", "mps"]

# If the torch_directml package is installed, also probe its devices.
if importlib.util.find_spec("torch_directml") is not None:
    import torch_directml

    # Probe every DirectML adapter reported by device_count(), not only the
    # default one returned by torch_directml.device() with no index.
    for _dml_index in range(torch_directml.device_count()):
        device_list.append(torch_directml.device(_dml_index))

for candidate in device_list:
    try:
        device = torch.device(candidate)
        # Smoke-test the device with a default-dtype allocation. Unavailable
        # backends are expected to raise AssertionError or RuntimeError here.
        torch.tensor([1], device=device)
    except (AssertionError, RuntimeError):
        # Device not usable here: skip it without printing anything.
        continue

    results = test_supported_dtypes(device)
    print(f"Supported dtypes on device {device}:")
    for dtype, supported in results.items():
        print(f" {dtype}: {'Yes' if supported else 'No'}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment