Created August 18, 2020 19:59
-
-
Save gmarkall/99d7a23934840bce3b0918368d9fd5e2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/numba/cuda/cudadrv/nvvm.py b/numba/cuda/cudadrv/nvvm.py | |
| index 24569f99b..505e6797f 100644 | |
| --- a/numba/cuda/cudadrv/nvvm.py | |
| +++ b/numba/cuda/cudadrv/nvvm.py | |
| @@ -272,29 +272,39 @@ data_layout = { | |
| default_data_layout = data_layout[tuple.__itemsize__ * 8] | |
| +_supported_cc = None | |
| -try: | |
| - from numba.cuda.cudadrv.runtime import runtime | |
| - cudart_version_major = runtime.get_version()[0] | |
| -except: | |
| - # The CUDA Runtime may not be present | |
| - cudart_version_major = 0 | |
| - | |
| -# List of supported compute capability in sorted order | |
| -if cudart_version_major == 0: | |
| - SUPPORTED_CC = (), | |
| -elif cudart_version_major < 9: | |
| - # CUDA 8.x | |
| - SUPPORTED_CC = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2) | |
| -elif cudart_version_major < 10: | |
| - # CUDA 9.x | |
| - SUPPORTED_CC = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0) | |
| -elif cudart_version_major < 11: | |
| - # CUDA 10.x | |
| - SUPPORTED_CC = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5) | |
| -else: | |
| - # CUDA 11.0 and later | |
| - SUPPORTED_CC = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0) | |
| + | |
| +def get_supported_ccs(): | |
| + global _supported_cc | |
| + | |
| + if _supported_cc: | |
| + return _supported_cc | |
| + | |
| + try: | |
| + from numba.cuda.cudadrv.runtime import runtime | |
| + cudart_version_major = runtime.get_version()[0] | |
| + except: | |
| + # The CUDA Runtime may not be present | |
| + cudart_version_major = 0 | |
| + | |
| + # List of supported compute capability in sorted order | |
| + if cudart_version_major == 0: | |
| + _supported_cc = (), | |
| + elif cudart_version_major < 9: | |
| + # CUDA 8.x | |
| + _supported_cc = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2) | |
| + elif cudart_version_major < 10: | |
| + # CUDA 9.x | |
| + _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0) | |
| + elif cudart_version_major < 11: | |
| + # CUDA 10.x | |
| + _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5) | |
| + else: | |
| + # CUDA 11.0 and later | |
| + _supported_cc = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0) | |
| + | |
| + return _supported_cc | |
| def find_closest_arch(mycc): | |
| @@ -305,7 +315,9 @@ def find_closest_arch(mycc): | |
| :param mycc: Compute capability as a tuple ``(MAJOR, MINOR)`` | |
| :return: Closest supported CC as a tuple ``(MAJOR, MINOR)`` | |
| """ | |
| - for i, cc in enumerate(SUPPORTED_CC): | |
| + supported_cc = get_supported_ccs() | |
| + | |
| + for i, cc in enumerate(supported_cc): | |
| if cc == mycc: | |
| # Matches | |
| return cc | |
| @@ -317,10 +329,10 @@ def find_closest_arch(mycc): | |
| "not supported (requires >=%d.%d)" % (mycc + cc)) | |
| else: | |
| # return the previous CC | |
| - return SUPPORTED_CC[i - 1] | |
| + return supported_cc[i - 1] | |
| # CC higher than supported | |
| - return SUPPORTED_CC[-1] # Choose the highest | |
| + return supported_cc[-1] # Choose the highest | |
| def get_arch_option(major, minor): | |
| diff --git a/numba/cuda/tests/cudadrv/test_nvvm_driver.py b/numba/cuda/tests/cudadrv/test_nvvm_driver.py | |
| index 0e40a906d..ec6d575b5 100644 | |
| --- a/numba/cuda/tests/cudadrv/test_nvvm_driver.py | |
| +++ b/numba/cuda/tests/cudadrv/test_nvvm_driver.py | |
| @@ -1,7 +1,7 @@ | |
| from llvmlite.llvmpy.core import Module, Type, Builder | |
| from numba.cuda.cudadrv.nvvm import (NVVM, CompilationUnit, llvm_to_ptx, | |
| set_cuda_kernel, fix_data_layout, | |
| - get_arch_option, SUPPORTED_CC) | |
| + get_arch_option, get_supported_ccs) | |
| from ctypes import c_size_t, c_uint64, sizeof | |
| from numba.cuda.testing import unittest | |
| from numba.cuda.cudadrv.nvvm import LibDevice, NvvmError | |
| @@ -54,7 +54,7 @@ class TestNvvmDriver(unittest.TestCase): | |
| def test_nvvm_support(self): | |
| """Test supported CC by NVVM | |
| """ | |
| - for arch in SUPPORTED_CC: | |
| + for arch in get_supported_ccs(): | |
| self._test_nvvm_support(arch=arch) | |
| @unittest.skipIf(True, "No new CC unknown to NVVM yet") | |
| @@ -80,10 +80,11 @@ class TestArchOption(unittest.TestCase): | |
| self.assertEqual(get_arch_option(5, 1), 'compute_50') | |
| self.assertEqual(get_arch_option(3, 7), 'compute_35') | |
| # Test known arch. | |
| - for arch in SUPPORTED_CC: | |
| + supported_cc = get_supported_ccs() | |
| + for arch in supported_cc: | |
| self.assertEqual(get_arch_option(*arch), 'compute_%d%d' % arch) | |
| self.assertEqual(get_arch_option(1000, 0), | |
| - 'compute_%d%d' % SUPPORTED_CC[-1]) | |
| + 'compute_%d%d' % supported_cc[-1]) | |
| @skip_on_cudasim('NVVM Driver unsupported in the simulator') |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.