Skip to content

Instantly share code, notes, and snippets.

@seberg
Last active April 9, 2020 16:02
Show Gist options
  • Save seberg/ce4563f3cb00e33997e4f80675b80953 to your computer and use it in GitHub Desktop.
Save seberg/ce4563f3cb00e33997e4f80675b80953 to your computer and use it in GitHub Desktop.
Snippet that loads numpy and tries to guess (and give a bit of information) the BLAS/LAPACK implementation loaded at runtime
"""
Simply run the script to try to guess some information about how numpy
is linked.
If there is something odd going on, run/import the script after
your real import of numpy.
All versions tested on Linux, MKL is confusing me a bit since both lower and
upper case versions exist.
Yes, multiple/different BLAS versions are possible if linking is odd.
Different versions can happen if the compile-time setup/distutils is confused;
multiple versions can happen, for example, if a `libblas` is already loaded
when numpy gets imported. In that case that BLAS may be used, while LAPACK
still loads another BLAS implementation (or something like it anyway).
On Linux, `print_ldd_info()` can additionally print the `ldd` output for both
files (the call at the bottom of the script is commented out by default).
"""
import numpy as np
import ctypes
import sys
import subprocess
# Locate the compiled NumPy extension modules that link against BLAS/LAPACK.
# Only AttributeError is caught: that is what a missing module attribute
# raises, and anything else should surface instead of being swallowed.
try:
    # NumPy >= 2.0 keeps the compiled core in the private `_core` package.
    multiarray = np._core._multiarray_umath.__file__
except AttributeError:
    try:
        multiarray = np.core._multiarray_umath.__file__
    except AttributeError:
        # Fallback used when `_multiarray_umath` does not exist.
        multiarray = np.core.multiarray.__file__

# `_umath_linalg` is a submodule of the `numpy.linalg` package itself, so it
# can be reached without going through the deprecated `numpy.linalg.linalg`.
linalg = np.linalg._umath_linalg.__file__
def print_info_inspecting_symbols():
    """Print a best guess of the BLAS/LAPACK implementation NumPy uses.

    The NumPy extension module is opened with ``ctypes`` and probed for one
    marker symbol per known implementation (OpenBLAS, ATLAS, BLIS, MKL,
    Accelerate).  For every implementation whose marker symbol resolves,
    the details that the library exposes through the calls below (thread
    count, version/build info) are printed.

    Only the multiarray module is probed; probing linalg is not usually
    necessary.  Nothing is returned, all results are written to stdout.
    """
    # Marker symbol -> name of the implementation that defines it.
    implementations = {
        "openblas_get_num_threads": "openblas",
        "ATL_buildinfo": "atlas",
        "bli_thread_get_num_threads": "blis",
        "MKL_Get_Max_Threads": "MKL",
        "_APL_dgemm": "accelerate",
    }

    for library in [multiarray]:  # , linalg (not usually necessary to probe)
        if library == multiarray:
            print('Probing Multiarray')
            print('------------------')
        else:
            print()
            print('Probing Linalg')
            print('--------------')

        dll = ctypes.CDLL(library)

        blas = []
        for func, implementation in implementations.items():
            try:
                getattr(dll, func)
            except AttributeError:
                # ctypes raises AttributeError when the symbol cannot be
                # resolved through this library handle.
                continue
            blas.append(implementation)

        if len(blas) > 1:
            print(" WARNING: Multiple BLAS/LAPACK libs loaded:", blas)
            print()

        if not blas:
            # Join the names so the message lists them plainly instead of
            # showing a ``dict_values([...])`` repr.
            print("Unable to guess BLAS implementation, it is not one of:",
                  ", ".join(implementations.values()))
            print(" or additional symbols are not loaded?!")

        for impl in blas:
            if impl == "openblas":
                dll.openblas_get_config.restype = ctypes.c_char_p
                dll.openblas_get_num_threads.restype = ctypes.c_int
                print("OpenBLAS:")
                print(" num threads:", dll.openblas_get_num_threads())
                print(" version info:",
                      dll.openblas_get_config().decode('utf8'))
            elif impl == "blis":
                dll.bli_thread_get_num_threads.restype = ctypes.c_int
                print("BLIS running with:",
                      dll.bli_thread_get_num_threads(), "threads.")
                print(" threads enabled?:",
                      dll.bli_info_get_enable_threading())
            elif impl == "MKL":
                try:
                    version_func = dll.mkl_get_version_string
                    version_func.argtypes = (ctypes.c_char_p, ctypes.c_int)
                    # The function fills the buffer and returns nothing.
                    version_func.restype = None
                    out_buf = ctypes.create_string_buffer(500)
                    version_func(out_buf, len(out_buf))
                    version_info = out_buf.value.decode("utf8")
                except Exception:
                    # Best effort only: a missing symbol or undecodable
                    # output must not abort the rest of the report.
                    version_info = "Fetching version info failed (probably can't happen?)"
                dll.MKL_Get_Max_Threads.restype = ctypes.c_int
                print("MKL: ")
                print(" max threads:", dll.MKL_Get_Max_Threads())
                print(" version info:", version_info)
            elif impl == "atlas":
                print("ATLAS:")
                print(" ATLAS is threadsafe, max number of threads are fixed at compile time")
                # ATL_buildinfo writes through the C library, bypassing
                # Python's stdout buffer; flush first so the header stays
                # in front of ATLAS's own output.
                print(" version info printed by ATLAS:", flush=True)
                dll.ATL_buildinfo()
            elif impl == "accelerate":
                print("Accelerate:")
                print(" Accelerate is buggy, please do not use it!")
            else:
                print("Found BLAS/LAPACK implementation:", impl)
def print_ldd_info():
    """
    Print the ``ldd`` output for the multiarray and linalg modules.

    Meant for Linux, where the listed shared-library dependencies may be
    useful information for some bug reports.  Nothing is returned; the
    combined stdout/stderr of each ``ldd`` run is printed.
    """
    print("LDD information:")
    print("----------------")
    for library in [multiarray, linalg]:
        print("running:", "ldd {}".format(library))
        try:
            # An argument list (no shell) hands the path to ldd unchanged,
            # even if it contains spaces or shell metacharacters.
            result = subprocess.run(
                ["ldd", library],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
            )
        except OSError as exc:
            # ldd is not installed or could not be executed.
            print("could not run ldd:", exc)
            continue
        print(result.stdout.rstrip("\n"))
# Running (or importing) this file performs the probe right away.
print_info_inspecting_symbols()

# The ldd report is switched off by default, so on Linux only a blank
# separator line is printed; uncomment the call below to enable it.
if "linux" == sys.platform:
    print()
    # print_ldd_info()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment