Skip to content

Instantly share code, notes, and snippets.

@ales-erjavec
Last active December 31, 2021 08:46
Show Gist options
  • Save ales-erjavec/d64f06e6e0b2097c3bf9be845c20b40f to your computer and use it in GitHub Desktop.
Save ales-erjavec/d64f06e6e0b2097c3bf9be845c20b40f to your computer and use it in GitHub Desktop.
Inspect/configure NumPy-linked OpenBLAS
"""
Inspect/config numpy-linked OpenBLAS
"""
import re
import ctypes
import numpy.core
class openblas_config:
    """Container for the OpenBLAS runtime-configuration entry points.

    The methods declared here only document the expected call signatures.
    ``numpy_openblas`` builds an instance and overwrites ``get_config``,
    ``get_num_threads`` and ``set_num_threads`` on that instance with the
    foreign functions it resolved through ctypes.
    """

    def get_config(self) -> bytes:
        """Return the OpenBLAS build/configuration description."""

    def get_num_threads(self) -> int:
        """Return the number of threads OpenBLAS is configured to use."""

    def set_num_threads(self, i: int) -> None:
        """Ask OpenBLAS to use ``i`` threads."""
def numpy_openblas():
    # type: () -> Optional[openblas_config]
    """Locate the OpenBLAS library used by NumPy and wrap its thread API.

    Each candidate NumPy C extension module is loaded with ``ctypes.CDLL``.
    The handle is then searched for ``openblas_get_config``,
    ``openblas_get_num_threads`` and ``openblas_set_num_threads``. The plain
    symbol names are tried first. After that, the ``64_``-suffixed and
    ``scipy_``-prefixed variants exported by some OpenBLAS builds are tried.

    Returns:
        An ``openblas_config`` instance whose ``get_config``,
        ``get_num_threads`` and ``set_num_threads`` attributes are the
        resolved foreign functions. These are plain instance attributes, so
        they are called without ``self``. Returns ``None`` when no candidate
        module can be loaded as a shared library, or when none of them
        exposes all three symbols.
    """
    for handle in _numpy_extension_handles():
        functions = _lookup_openblas_functions(handle)
        if functions is not None:
            break
    else:
        return None

    get_config, get_num_threads, set_num_threads = functions
    get_config.argtypes = []
    get_config.restype = ctypes.c_char_p  # pointer to static memory
    get_num_threads.argtypes = []
    get_num_threads.restype = ctypes.c_int
    set_num_threads.argtypes = [ctypes.c_int]
    set_num_threads.restype = None

    config = openblas_config()
    config.get_config = get_config
    config.set_num_threads = set_num_threads
    config.get_num_threads = get_num_threads
    return config


# Candidate NumPy extension modules, newest package layout first.
# NumPy 2.x moved the core package to ``numpy._core``, and its ``numpy.core.*``
# names are pure-Python shims. Older NumPy ships the extension as
# ``numpy.core._multiarray_umath``, or as ``numpy.core.multiarray`` before that.
_NUMPY_EXTENSION_MODULES = (
    "numpy._core._multiarray_umath",
    "numpy.core._multiarray_umath",
    "numpy.core.multiarray",
)

# (prefix, suffix) pairs applied around the ``openblas_*`` symbol names.
# The unadorned names come first, which matches the original lookup.
_OPENBLAS_SYMBOL_AFFIXES = (
    ("", ""),
    ("", "64_"),
    ("scipy_", ""),
    ("scipy_", "64_"),
)


def _numpy_extension_handles():
    """Yield a ``ctypes.CDLL`` handle for every loadable candidate module."""
    import importlib

    for module_name in _NUMPY_EXTENSION_MODULES:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        path = getattr(module, "__file__", None)
        if not path:
            continue
        try:
            yield ctypes.CDLL(path)
        except OSError:
            # Not a shared library (e.g. a .py compatibility shim); skip it.
            continue


def _lookup_openblas_functions(handle):
    """Return the (get_config, get_num_threads, set_num_threads) functions
    exported by ``handle`` under one consistent affix, or ``None``."""
    for prefix, suffix in _OPENBLAS_SYMBOL_AFFIXES:
        try:
            return tuple(
                handle[prefix + "openblas_" + name + suffix]
                for name in ("get_config", "get_num_threads", "set_num_threads")
            )
        except AttributeError:
            # ctypes raises AttributeError for a missing symbol.
            continue
    return None
def numpy_openblas_get_config_dict():
    # type: () -> Optional[Dict[str, Any]]
    """Return build options parsed from NumPy's OpenBLAS configuration string.

    Only ``MAX_THREADS`` is extracted. It is read from the bytes returned by
    the ``get_config`` function that ``numpy_openblas`` resolves.

    Returns:
        ``None`` if ``numpy_openblas`` could not locate OpenBLAS. Otherwise
        ``{"MAX_THREADS": n}``, where ``n`` is the integer following the
        first ``MAX_THREADS=`` token. It falls back to 1 when the token is
        absent.
    """
    cfg = numpy_openblas()
    if cfg is None:
        return None
    # ``c_char_p`` converts a NULL return to None; treat that as an empty
    # string so the regex search does not raise TypeError.
    config = cfg.get_config() or b""
    match = re.search(br"MAX_THREADS=(\d+)", config)
    max_threads = int(match.group(1)) if match is not None else 1
    return {"MAX_THREADS": max_threads}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment