Last active August 8, 2016 16:31
-
-
Save brianthelion/d512de8493da9791c93752b07b3d6854 to your computer and use it in GitHub Desktop.
Decorators for numpy-to-matlab regression automation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import decorator


class RuntimeInsertionManager(object):
    """Factory for decorators whose behaviour can be injected at runtime.

    Functions decorated through :meth:`insert_with_args` simply call through
    to themselves until a ``callback`` is assigned on the manager.  Once a
    callback is set, every call of a decorated function is routed through it
    instead, so extra behaviour can be inserted without editing the decorated
    code.
    """

    # Optional hook with signature
    #     callback(wrapped, call_args, call_kwargs, dec_args, dec_kwargs)
    # When None, decorated functions run unmodified.
    callback = None

    def insert_with_args(self, *args, **dargs):
        """Return a decorator that remembers ``args``/``dargs`` for the callback.

        :param args: positional arguments given to the decorator factory;
            forwarded to the callback as its ``dec_args`` argument.
        :param dargs: keyword arguments given to the decorator factory;
            forwarded to the callback as its ``dec_dargs`` argument.
        :returns: a decorator built with ``decorator.decorator``.
        """
        @decorator.decorator
        def _wrapper(wrapped, *_args, **_dargs):
            # Look the callback up at call time rather than decoration time so
            # it can be installed after the decorated functions were imported.
            callback = getattr(self, 'callback', None)
            if callback is not None:
                return callback(wrapped, _args, _dargs, args, dargs)
            return wrapped(*_args, **_dargs)
        return _wrapper


# Shared manager instance; install behaviour with ``MANAGER.callback = ...``.
MANAGER = RuntimeInsertionManager()
mimic = MANAGER.insert_with_args
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

from rt_insert import MANAGER
from whatever_ported import whatever_ported
# Obtain a MATLAB (preferred) or Octave session that can evaluate the
# reference implementations.
try:
    import matlab.engine as _matlab_engine
except ImportError:
    _matlab_engine = None

if _matlab_engine is not None:
    # The ``matlab.engine`` module only provides start_matlab(); MATLAB
    # functions such as addpath() are called on the engine *instance*.
    engine = _matlab_engine.start_matlab()
else:
    try:
        import oct2py
    except ImportError:
        raise ImportError("Found neither 'matlab.engine' nor 'oct2py'")
    engine = oct2py.Oct2Py(convert_to_float=False)

# Make .m files in the current working directory (e.g. whatever.m) callable
# as ``engine.<name>(...)``.
engine.addpath('./')
def callback(func, func_args, func_dargs, dec_args, dec_dargs):
    """Run a Python port and its MATLAB/Octave reference; assert they agree.

    Intended to be installed as ``MANAGER.callback`` so that each call of a
    function decorated with ``@mimic('<mfunc>')`` is cross-checked against
    ``<mfunc>`` evaluated by ``engine``.

    :param func: the decorated Python function.
    :param func_args: positional arguments of the call; passed to both
        implementations.
    :param func_dargs: keyword arguments of the call; passed to ``func`` only.
    :param dec_args: decorator arguments; ``dec_args[0]`` is the name of the
        MATLAB/Octave function to compare against.
    :param dec_dargs: decorator keyword arguments (unused).
    :returns: the Python function's result.
    :raises AssertionError: if the results differ in type, dtype, shape or
        value.
    """
    mfunc_name = dec_args[0]
    mfunc_result = getattr(engine, mfunc_name)(*func_args)
    local_result = func(*func_args, **func_dargs)
    assert type(mfunc_result) == type(local_result), \
        "octave type={} != numpy type={}".format(type(mfunc_result), type(local_result))
    # isinstance (rather than an exact type match against ndarray) also covers
    # ndarray subclasses; checking both sides stays safe if asserts are
    # disabled with -O and the type check above is skipped.
    if isinstance(mfunc_result, np.ndarray) and isinstance(local_result, np.ndarray):
        assert mfunc_result.dtype == local_result.dtype, \
            "octave dtype={} != numpy dtype={}".format(mfunc_result.dtype, local_result.dtype)
        assert mfunc_result.shape == local_result.shape, \
            "octave shape={} != numpy shape={}".format(mfunc_result.shape, local_result.shape)
    assert np.allclose(mfunc_result, local_result), \
        "octave value={} != numpy value={}".format(mfunc_result, local_result)
    return local_result
# Route every @mimic-decorated call through the comparison callback.
MANAGER.callback = callback

# (a, b) argument pairs fed to both implementations.
TEST_CASES = [(1, 2), (3, 4), (5, 6)]


def test_whatever_ported():
    """Cross-check ``whatever_ported`` against its reference on every case.

    Calling the decorated function triggers ``callback``, which performs the
    comparison and raises AssertionError on mismatch.  A plain loop is used
    instead of nose-style ``yield`` tests, which pytest no longer collects;
    the failing case is added to the error message to keep failures traceable.
    """
    for a, b in TEST_CASES:
        try:
            whatever_ported(a, b)
        except AssertionError as err:
            raise AssertionError("case {}: {}".format((a, b), err))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function c = whatever(a, b)
% WHATEVER Reference implementation mirrored by the Python whatever_ported.
%   C = WHATEVER(A, B) returns A + B.
c = a + b;  % trailing semicolon: don't echo "c = ..." on every engine call
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rt_insert import mimic


@mimic('whatever')
def whatever_ported(a, b):
    """Python port of the MATLAB/Octave function ``whatever``.

    ``mimic('whatever')`` records the reference function's name so that a
    callback installed on the shared manager can compare the two results.
    """
    total = a + b
    return total
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.