Created
January 24, 2019 16:09
-
-
Save technic/80e8d95858b187cd8ff8677bd5cc0fbb to your computer and use it in GitHub Desktop.
MKL_NUM_THREADS Python context manager
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ctypes | |
class MKLThreads(object):
    """Context manager that temporarily changes the MKL thread limit.

    On entry the MKL runtime is told to use ``num_threads`` threads; on exit
    the limit that was in effect when the block was entered is restored.

    Usage::

        with MKLThreads(1):
            ...  # MKL-backed code runs with the reduced thread limit

    The MKL runtime library is loaded lazily through ``ctypes`` the first
    time it is needed, and the handle is cached on the class.
    """

    # Shared-library names tried in order.  The first two are the names the
    # original implementation used; the rest cover macOS and the versioned
    # file names shipped by newer oneMKL releases.
    _LIBRARY_NAMES = (
        'libmkl_rt.so',
        'mkl_rt.dll',
        'libmkl_rt.dylib',
        'libmkl_rt.so.2',
        'libmkl_rt.so.1',
        'mkl_rt.2.dll',
        'mkl_rt.1.dll',
        'libmkl_rt.2.dylib',
        'libmkl_rt.1.dylib',
    )

    # Cached ctypes handle to the MKL runtime, shared by every instance.
    _mkl_rt = None

    @classmethod
    def _mkl(cls):
        """Return the cached MKL runtime handle, loading it on first use.

        Raises:
            OSError: if none of the candidate library names can be loaded.
        """
        if cls._mkl_rt is None:
            last_error = None
            for name in cls._LIBRARY_NAMES:
                try:
                    cls._mkl_rt = ctypes.CDLL(name)
                except OSError as err:
                    last_error = err
                else:
                    break
            else:
                raise OSError(
                    'could not load the MKL runtime; tried %s (last error: %s)'
                    % (', '.join(cls._LIBRARY_NAMES), last_error)
                )
        return cls._mkl_rt

    @classmethod
    def get_max_threads(cls):
        """Return the value reported by the runtime's ``mkl_get_max_threads``."""
        return cls._mkl().mkl_get_max_threads()

    @classmethod
    def set_num_threads(cls, n):
        """Ask the MKL runtime to use ``n`` threads.

        Args:
            n: the requested thread count; must be a plain ``int``.

        Raises:
            TypeError: if ``n`` is not an ``int`` (``bool`` is rejected too).
        """
        # Explicit check instead of ``assert`` so that validation still runs
        # when Python is started with -O.
        if isinstance(n, bool) or not isinstance(n, int):
            raise TypeError(
                'n must be an int, got %s' % type(n).__name__
            )
        # The argument is passed by reference, as in the original code.
        # NOTE(review): presumably the lower-case symbol is the
        # pointer-taking entry point of the library -- confirm against the
        # MKL documentation before changing the calling convention.
        cls._mkl().mkl_set_num_threads(ctypes.byref(ctypes.c_int(n)))

    def __init__(self, num_threads):
        """Remember the thread count to apply inside the ``with`` block.

        Args:
            num_threads: thread count to set while the context is active.
        """
        self._n = num_threads
        # Queried here as well so a missing MKL library is reported at
        # construction time; the value is refreshed again in __enter__.
        self._saved_n = self.get_max_threads()

    def __enter__(self):
        """Record the current thread limit, apply the new one, return self."""
        # Re-read the limit at entry: it may have changed since __init__,
        # and the instance may be reused for several ``with`` blocks.
        self._saved_n = self.get_max_threads()
        self.set_num_threads(self._n)
        return self

    def __exit__(self, type, value, traceback):
        """Restore the thread limit recorded on entry.

        Returns ``None``, so an exception raised inside the block propagates.
        """
        self.set_num_threads(self._saved_n)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from unittest import TestCase | |
from mkl import MKLThreads | |
class TestMKLThreads(TestCase):
    """Checks that MKLThreads applies and then restores the thread limit."""

    def test_context(self):
        """Limit is 1 inside the ``with`` block and back to normal after it."""
        initial = MKLThreads.get_max_threads()
        # The check below is only meaningful when more than one thread is
        # available to begin with.
        self.assertTrue(initial > 1, "must run on multi core to test")

        with MKLThreads(1):
            inside = MKLThreads.get_max_threads()
            self.assertEqual(inside, 1)

        after = MKLThreads.get_max_threads()
        self.assertEqual(after, initial)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
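When using this together with the multiprocessing module, you should invoke `set_num_threads` within each child process, for example by passing it as the pool initializer: `Pool(initializer=MKLThreads.set_num_threads, initargs=(2,))`. This would create a pool with each sub-process consuming two cores.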