Last active
April 9, 2020 16:02
-
-
Save seberg/ce4563f3cb00e33997e4f80675b80953 to your computer and use it in GitHub Desktop.
Snippet that loads NumPy, tries to guess which BLAS/LAPACK implementation is loaded at runtime, and prints a bit of information about it
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Simply run this script to try to guess how NumPy is linked to BLAS/LAPACK.

If something odd is going on, run/import this script *after* your real
import of NumPy.

All versions were tested on Linux.  MKL is a bit confusing because both
lower- and upper-case versions of its function names exist.

Yes, multiple or different BLAS versions are possible if the linking is odd.
Different versions can show up if the compile-time setup/distutils is
confused.  Multiple versions can show up, for example, if a `libblas` is
already loaded when NumPy gets imported: that BLAS may then be used, while
LAPACK still loads another BLAS implementation (or something like it).

On Linux, `print_ldd_info()` can additionally print the `ldd` output for
both extension modules; it is not called by default.
"""
import ctypes
import importlib
import subprocess
import sys

import numpy as np


def _extension_file(*module_names):
    """Return ``__file__`` of the first importable *compiled* module in *module_names*.

    Newer NumPy releases moved ``numpy.core`` to ``numpy._core`` and keep
    pure-Python compatibility shims under the other name.  A shim's
    ``__file__`` is a ``.py`` file that ``ctypes.CDLL`` cannot load, so such
    modules are skipped.

    Raises
    ------
    ImportError
        If none of the candidates is an importable compiled module.
    """
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue
        path = getattr(module, "__file__", None)
        if path and not path.endswith((".py", ".pyc")):
            return path
    raise ImportError(
        "none of {} is a compiled extension module".format(module_names))


# Shared-object paths of the NumPy extension modules that call into BLAS.
multiarray = _extension_file(
    "numpy._core._multiarray_umath",  # NumPy >= 2.0
    "numpy.core._multiarray_umath",   # NumPy 1.x (merged multiarray/umath)
    "numpy.core.multiarray",          # older NumPy without the merged module
)
linalg = _extension_file("numpy.linalg._umath_linalg")
# Marker symbol -> implementation name.  If ctypes can resolve a marker
# through a NumPy extension module, that implementation is assumed to be
# loaded.  Markers are probed in this (insertion) order.
_BLAS_MARKER_SYMBOLS = {
    "openblas_get_num_threads": "openblas",
    "ATL_buildinfo": "atlas",
    "bli_thread_get_num_threads": "blis",
    "MKL_Get_Max_Threads": "MKL",
    "_APL_dgemm": "accelerate",
}


def _detect_blas(dll):
    """Return the implementation names whose marker symbol *dll* can resolve."""
    found = []
    for symbol, implementation in _BLAS_MARKER_SYMBOLS.items():
        try:
            getattr(dll, symbol)
        except AttributeError:
            # ctypes raises AttributeError for symbols it cannot resolve.
            continue
        found.append(implementation)
    return found


def _flush_c_stdio():
    """Best-effort ``fflush(NULL)`` so pending C-level output is written now.

    Failure (e.g. no process-wide handle on this platform) is ignored on
    purpose: flushing only affects output ordering, never correctness.
    """
    try:
        ctypes.CDLL(None).fflush(None)
    except (OSError, AttributeError, TypeError):
        pass


def _print_blas_details(dll, impl):
    """Print thread/version details for implementation *impl* using *dll*."""
    if impl == "openblas":
        dll.openblas_get_config.restype = ctypes.c_char_p
        dll.openblas_get_num_threads.restype = ctypes.c_int
        print("OpenBLAS:")
        print(" num threads:", dll.openblas_get_num_threads())
        print(" version info:", dll.openblas_get_config().decode('utf8'))
    elif impl == "blis":
        dll.bli_thread_get_num_threads.restype = ctypes.c_int
        print("BLIS running with:", dll.bli_thread_get_num_threads(), "threads.")
        print(" threads enabled?:", dll.bli_info_get_enable_threading())
    elif impl == "MKL":
        try:
            get_version = dll.mkl_get_version_string
        except AttributeError:
            version_info = "Fetching version info failed (probably can't happen?)"
        else:
            get_version.argtypes = (ctypes.c_char_p, ctypes.c_int)
            buf = ctypes.create_string_buffer(500)
            get_version(buf, len(buf))
            version_info = buf.value.decode("utf8", errors="replace")
        print("MKL: ")
        print(" max threads:", dll.MKL_Get_Max_Threads())
        print(" version info:", version_info)
    elif impl == "atlas":
        print("ATLAS:")
        print(" ATLAS is threadsafe, max number of threads are fixed at compile time")
        print(" version info printed by ATLAS:")
        # ATL_buildinfo prints from C (presumably via stdio), which buffers
        # independently of sys.stdout.  Flush both sides so its text lands
        # under this header even when stdout is redirected to a file/pipe.
        sys.stdout.flush()
        dll.ATL_buildinfo()
        _flush_c_stdio()
    elif impl == "accelerate":
        print("Accelerate:")
        print(" Accelerate is buggy, please do not use it!")
    else:
        print("Found BLAS/LAPACK implementation:", impl)


def print_info_inspecting_symbols(probe_linalg=False):
    """Print which BLAS/LAPACK implementation(s) NumPy's modules can reach.

    Each probed extension module is loaded with ``ctypes.CDLL``, and a set of
    implementation-specific marker symbols is looked up through it.  Details
    (threads, version) are printed for every implementation found.

    Parameters
    ----------
    probe_linalg : bool, optional
        Also probe the ``linalg`` extension module.  Off by default; probing
        ``multiarray`` is usually enough.
    """
    targets = [("Multiarray", multiarray)]
    if probe_linalg:
        targets.append(("Linalg", linalg))

    for index, (title, library) in enumerate(targets):
        if index:
            print()
        heading = "Probing " + title
        print(heading)
        print("-" * len(heading))

        dll = ctypes.CDLL(library)
        blas = _detect_blas(dll)

        if len(blas) > 1:
            print(" WARNING: Multiple BLAS/LAPACK libs loaded:", blas)
            print()
        if not blas:
            print("Unable to guess BLAS implementation, it is not one of:",
                  list(_BLAS_MARKER_SYMBOLS.values()))
            print(" or additional symbols are not loaded?!")

        for impl in blas:
            _print_blas_details(dll, impl)
def print_ldd_info(libraries=None):
    """Print the ``ldd`` output for NumPy's BLAS-using extension modules.

    Intended for Linux bug reports.

    Parameters
    ----------
    libraries : iterable of str, optional
        Paths of the shared objects to inspect.  Defaults to the module-level
        ``multiarray`` and ``linalg`` paths.
    """
    if libraries is None:
        libraries = [multiarray, linalg]
    print("LDD information:")
    print("----------------")
    for library in libraries:
        print("running:", "ldd {}".format(library))
        try:
            # Pass the path as a single argv entry (no shell), so paths with
            # spaces or shell metacharacters are handled correctly.
            result = subprocess.run(
                ["ldd", library],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
            )
        except OSError as exc:
            # e.g. `ldd` is not installed; report it instead of crashing.
            print("could not run ldd:", exc)
            continue
        # Mirror subprocess.getoutput: combined stdout/stderr with a single
        # trailing newline removed.
        output = result.stdout
        if output.endswith("\n"):
            output = output[:-1]
        print(output)
# Set to True to append the `ldd` output of NumPy's extension modules on
# Linux (useful for bug reports); disabled by default.
PRINT_LDD_INFO = False

print_info_inspecting_symbols()
if sys.platform == "linux":
    print()
    if PRINT_LDD_INFO:
        print_ldd_info()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment