Skip to content

Instantly share code, notes, and snippets.

@gmarkall
Created August 18, 2020 19:59
Show Gist options
  • Select an option

  • Save gmarkall/99d7a23934840bce3b0918368d9fd5e2 to your computer and use it in GitHub Desktop.

Select an option

Save gmarkall/99d7a23934840bce3b0918368d9fd5e2 to your computer and use it in GitHub Desktop.
diff --git a/numba/cuda/cudadrv/nvvm.py b/numba/cuda/cudadrv/nvvm.py
index 24569f99b..505e6797f 100644
--- a/numba/cuda/cudadrv/nvvm.py
+++ b/numba/cuda/cudadrv/nvvm.py
@@ -272,29 +272,39 @@ data_layout = {
default_data_layout = data_layout[tuple.__itemsize__ * 8]
+# Module-level cache for get_supported_ccs(); None until the first call fills it.
+_supported_cc = None
-try:
- from numba.cuda.cudadrv.runtime import runtime
- cudart_version_major = runtime.get_version()[0]
-except:
- # The CUDA Runtime may not be present
- cudart_version_major = 0
-
-# List of supported compute capability in sorted order
-if cudart_version_major == 0:
- SUPPORTED_CC = (),
-elif cudart_version_major < 9:
- # CUDA 8.x
- SUPPORTED_CC = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2)
-elif cudart_version_major < 10:
- # CUDA 9.x
- SUPPORTED_CC = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0)
-elif cudart_version_major < 11:
- # CUDA 10.x
- SUPPORTED_CC = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5)
-else:
- # CUDA 11.0 and later
- SUPPORTED_CC = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0)
+
+def get_supported_ccs():
+    """Return the compute capabilities (CCs) supported by this CUDA setup.
+
+    The CCs are chosen from the major version reported by the CUDA Runtime
+    (``runtime.get_version()``) and returned as a tuple of ``(MAJOR, MINOR)``
+    tuples in ascending order. If the runtime cannot be imported or queried,
+    an empty tuple is returned. The result is computed on the first call and
+    cached in the module-level ``_supported_cc`` for later calls.
+    """
+    global _supported_cc
+
+    # Compare with None rather than testing truthiness, so that an empty
+    # result (no CUDA Runtime) is cached as well instead of being recomputed.
+    if _supported_cc is not None:
+        return _supported_cc
+
+    try:
+        from numba.cuda.cudadrv.runtime import runtime
+        cudart_version_major = runtime.get_version()[0]
+    except Exception:
+        # The CUDA Runtime may not be present. Catch Exception rather than
+        # using a bare except so KeyboardInterrupt/SystemExit still propagate.
+        cudart_version_major = 0
+
+    # List of supported compute capability in sorted order
+    if cudart_version_major == 0:
+        # A real empty tuple. Note ``(),`` would be ``((),)``: a one-element
+        # tuple whose only "CC" is the empty tuple.
+        _supported_cc = ()
+    elif cudart_version_major < 9:
+        # CUDA 8.x
+        _supported_cc = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2)
+    elif cudart_version_major < 10:
+        # CUDA 9.x
+        _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0)
+    elif cudart_version_major < 11:
+        # CUDA 10.x
+        _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5)
+    else:
+        # CUDA 11.0 and later
+        _supported_cc = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0)
+
+    return _supported_cc
def find_closest_arch(mycc):
@@ -305,7 +315,9 @@ def find_closest_arch(mycc):
:param mycc: Compute capability as a tuple ``(MAJOR, MINOR)``
:return: Closest supported CC as a tuple ``(MAJOR, MINOR)``
"""
- for i, cc in enumerate(SUPPORTED_CC):
+ supported_cc = get_supported_ccs()
+
+ for i, cc in enumerate(supported_cc):
if cc == mycc:
# Matches
return cc
@@ -317,10 +329,10 @@ def find_closest_arch(mycc):
"not supported (requires >=%d.%d)" % (mycc + cc))
else:
# return the previous CC
- return SUPPORTED_CC[i - 1]
+ return supported_cc[i - 1]
# CC higher than supported
- return SUPPORTED_CC[-1] # Choose the highest
+ return supported_cc[-1] # Choose the highest
def get_arch_option(major, minor):
diff --git a/numba/cuda/tests/cudadrv/test_nvvm_driver.py b/numba/cuda/tests/cudadrv/test_nvvm_driver.py
index 0e40a906d..ec6d575b5 100644
--- a/numba/cuda/tests/cudadrv/test_nvvm_driver.py
+++ b/numba/cuda/tests/cudadrv/test_nvvm_driver.py
@@ -1,7 +1,7 @@
from llvmlite.llvmpy.core import Module, Type, Builder
from numba.cuda.cudadrv.nvvm import (NVVM, CompilationUnit, llvm_to_ptx,
set_cuda_kernel, fix_data_layout,
- get_arch_option, SUPPORTED_CC)
+ get_arch_option, get_supported_ccs)
from ctypes import c_size_t, c_uint64, sizeof
from numba.cuda.testing import unittest
from numba.cuda.cudadrv.nvvm import LibDevice, NvvmError
@@ -54,7 +54,7 @@ class TestNvvmDriver(unittest.TestCase):
def test_nvvm_support(self):
"""Test supported CC by NVVM
"""
- for arch in SUPPORTED_CC:
+ for arch in get_supported_ccs():
self._test_nvvm_support(arch=arch)
@unittest.skipIf(True, "No new CC unknown to NVVM yet")
@@ -80,10 +80,11 @@ class TestArchOption(unittest.TestCase):
self.assertEqual(get_arch_option(5, 1), 'compute_50')
self.assertEqual(get_arch_option(3, 7), 'compute_35')
# Test known arch.
- for arch in SUPPORTED_CC:
+ supported_cc = get_supported_ccs()
+ for arch in supported_cc:
self.assertEqual(get_arch_option(*arch), 'compute_%d%d' % arch)
self.assertEqual(get_arch_option(1000, 0),
- 'compute_%d%d' % SUPPORTED_CC[-1])
+ 'compute_%d%d' % supported_cc[-1])
@skip_on_cudasim('NVVM Driver unsupported in the simulator')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment