Created
December 16, 2011 20:24
-
-
Save jseabold/1487815 to your computer and use it in GitHub Desktop.
Use a decorator to transform the params input to a likelihood function.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Use a decorator to transform the params input to a likelihood function. | |
""" | |
from functools import wraps | |
from contextlib import contextmanager | |
import numpy as np | |
import inspect | |
def make_doc(func, wrapper):
    """
    Build a docstring that leads with the call signature of `func`.

    Decorated methods lose their visible signature in interactive help, so
    the signature of the undecorated function is put on the first line.

    Parameters
    ----------
    func : callable
        The undecorated function; its name and signature are shown.
    wrapper : callable
        The function whose current ``__doc__`` is appended after the
        signature line.

    Returns
    -------
    str
        The name and signature of `func`, a newline, then the docstring of
        `wrapper` (empty if `wrapper` has no docstring).
    """
    if hasattr(inspect, "signature"):
        # Python 3: getargspec/formatargspec were removed in 3.11.
        formatted = str(inspect.signature(func))
    else:
        # Python 2 has no inspect.signature; use the legacy API.
        argspec = inspect.getargspec(func)
        formatted = inspect.formatargspec(argspec.args,
                                          varargs=argspec.varargs,
                                          defaults=argspec.defaults)
    # __name__ exists on Python 2 and 3; func_name is Python 2 only.
    # A missing docstring is rendered as empty text, not the word "None".
    return "%s%s\n%s" % (func.__name__, formatted, wrapper.__doc__ or "")
# use a factory to make the decorators | |
class GenericTransform(object):
    """
    Decorator that reparameterizes the first argument of a method.

    An instance is built around a `transform` callable and then used to
    decorate a method.  When the decorated method is called and the
    owning instance has a truthy ``_transparams`` attribute, the first
    positional argument after ``self`` is replaced by
    ``transform(argument)`` before the method body runs.

    Parameters
    ----------
    transform : callable
        Function applied to the first positional argument.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, func):
        """Return `func` wrapped so its first argument may be transformed."""
        transform = self.transform

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Treat a missing flag as "off" so the method can be called on
            # an instance whose flag has never been set, instead of
            # failing with AttributeError.
            if getattr(self, "_transparams", False):
                # Only a positional first argument is transformed.
                args = (transform(args[0]),) + args[1:]
            return func(self, *args, **kwargs)

        # Put the undecorated signature on top of the docstring.
        wrapper.__doc__ = make_doc(func, wrapper)
        return wrapper
def transform_factory(transform):
    """
    Create a method decorator that applies `transform`.

    Parameters
    ----------
    transform : callable
        Function handed to GenericTransform.

    Returns
    -------
    GenericTransform
        A decorator object built around `transform`.
    """
    decorator = GenericTransform(transform)
    return decorator
# allow the decorator to explicitly take the transformation function | |
class transform2(object):
    """
    Decorator that takes the transformation function explicitly.

    Used as ``@transform2(some_function)`` on a method.  When the
    decorated method is called and the owning instance has a truthy
    ``_transparams`` attribute, the first positional argument after
    ``self`` is replaced by ``some_function(argument)``.

    Parameters
    ----------
    func : callable
        Function applied to the first positional argument.
    """

    def __init__(self, func):
        self.transform = func

    def __call__(self, func, *args, **kwargs):
        """
        Return `func` wrapped so its first argument may be transformed.

        The extra ``*args`` and ``**kwargs`` are accepted but not used.
        """
        transform = self.transform

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Treat a missing flag as "off" so the method can be called on
            # an instance whose flag has never been set, instead of
            # failing with AttributeError.
            if getattr(self, "_transparams", False):
                # Only a positional first argument is transformed.
                args = (transform(args[0]),) + args[1:]
            return func(self, *args, **kwargs)

        # Put the undecorated signature on top of the docstring.
        wrapper.__doc__ = make_doc(func, wrapper)
        return wrapper
def set_transform(func):
    """
    Decorate a method so ``self._transparams`` is True while it runs.

    The flag is set to True before `func` is called and set back to False
    afterwards, whether `func` returns normally or raises.

    Parameters
    ----------
    func : callable
        Method to wrap; called as ``func(self, *args, **kwargs)``.

    Returns
    -------
    callable
        The wrapped method, with the undecorated signature on top of its
        docstring.
    """
    @contextmanager
    def transparams(self):
        self._transparams = True
        try:
            yield
        finally:
            # Without the finally, an exception raised inside the with
            # block would leave the flag stuck at True.
            self._transparams = False

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with transparams(self):
            return func(self, *args, **kwargs)

    wrapper.__doc__ = make_doc(func, wrapper)
    return wrapper
####### Example ######### | |
# an example transformation function | |
def _olsen_reparam(params): | |
""" | |
Go from true parameters to gamma and theta of Olsen | |
gamma = beta/sigma | |
theta = 1/sigma | |
""" | |
beta, sigma = params[:-1], params[-1] | |
theta = 1./sigma | |
gamma = beta/sigma | |
return gamma, theta | |
class A(object):
    """
    Example of the decorator approach: fit switches the transformation on
    and loglike receives the reparameterized params.
    """
    # Decorator that applies the Olsen reparameterization to the first
    # argument of the decorated method.
    _transform = transform_factory(_olsen_reparam)

    # Equivalent alternative: @transform2(_olsen_reparam)
    @_transform
    def loglike(self, params, extra=None):
        """
        I am the help of the loglike
        """
        return params

    @set_transform
    def fit(self, params):
        """
        I am the help of the fit function.
        """
        return self.loglike(params)
class B(object):
    """
    No decorators, spaghetti-code

    The same behavior as the decorator example, written by hand: fit
    turns the ``_transparams`` flag on around its call to loglike, and
    loglike reparameterizes its input while the flag is on.
    """
    # Off by default so loglike can be called without going through fit.
    _transparams = False

    def loglike(self, params):
        """Return params, reparameterized when the flag is on."""
        if self._transparams:
            params = _olsen_reparam(params)
        return params

    def fit(self, params):
        """Call loglike with the transformation switched on."""
        self._transparams = True
        try:
            return self.loglike(params)
        finally:
            # Always switch the flag back off, even if loglike raises.
            self._transparams = False
# Run the decorator-based and the hand-written variant on the same input.
a = A()
b = B()
a.fit(np.array([1, 2, 3]))
b.fit(np.array([1, 2, 3]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment