Created
December 29, 2011 02:30
-
-
Save josharian/1531289 to your computer and use it in GitHub Desktop.
Sample theano trampoline for supporting generic variables
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
sample theano trampoline | |
""" | |
import numpy | |
import theano | |
import theano.tensor as T | |
def landing_key(value):
    """Return the cache key used to pick a landed variable for ``value``.

    The key is currently nothing more than the concrete Python type of
    ``value``.
    """
    # TODO: Include numpy shape, dtype, stride; autocasting?
    value_type = type(value)
    return value_type
class BounceVariable(object):
    """Named placeholder whose concrete theano variable is chosen lazily.

    The placeholder is "landed" on a theano variable the first time a value
    with a given landing key is seen; the result is remembered per key in
    ``landing_cache``.
    """

    def __init__(self, name):
        self.name = name
        # Maps landing_key(value) -> theano variable created for that key.
        self.landing_cache = {}

    def land(self, value):
        """Return the theano variable matching ``value``, creating it once.

        ``int``/``float`` values land on ``T.dscalar``; everything else
        lands on ``T.dmatrix``.
        """
        cache_key = landing_key(value)
        try:
            return self.landing_cache[cache_key]
        except KeyError:
            pass
        make_variable = T.dscalar if isinstance(value, (int, float)) else T.dmatrix
        concrete = make_variable(self.name)
        self.landing_cache[cache_key] = concrete
        return concrete

    def __add__(self, other):
        """Defer ``self + other`` as a BounceOp."""
        return BounceOp("__add__", self, other)
class BounceOp(object):
    """Deferred method call over bounce variables, nested ops and constants.

    ``wrapped_op`` is the name of the method to invoke and ``variables`` are
    its operands: the first operand is the receiver, the remaining ones are
    passed as positional arguments once every operand has been landed.
    """

    def __init__(self, wrapped_op, *variables):
        self.wrapped_op = wrapped_op
        self.variables = variables

    def generate(self, inputs_dict):
        """Land every operand and apply the wrapped operation.

        ``inputs_dict`` maps each BounceVariable to the concrete value it is
        being called with. Nested BounceOps are generated recursively and
        plain constants (int, float, ndarray, list) are used as-is.

        Raises AssertionError for an operand of an unsupported type and
        NotImplementedError when ``wrapped_op`` is not a method name.
        """
        landed_variables = []
        for variable in self.variables:
            # The operand kinds are disjoint, so the order of these checks
            # does not affect which branch is taken.
            if isinstance(variable, BounceOp):
                landed = variable.generate(inputs_dict)
            elif isinstance(variable, (int, float, numpy.ndarray, list)):
                landed = variable
            elif isinstance(variable, BounceVariable):
                landed = variable.land(inputs_dict[variable])
            else:
                raise AssertionError(
                    "unsupported operand of type %s" % type(variable).__name__)
            landed_variables.append(landed)
        # TODO: Recurse to new ops
        if isinstance(self.wrapped_op, str):
            receiver = landed_variables[0]
            return getattr(receiver, self.wrapped_op)(*landed_variables[1:])
        else:
            # TODO: Handle "standalone" ops
            raise NotImplementedError()

    def __add__(self, other):
        """Defer ``self + other`` as a new BounceOp."""
        return BounceOp("__add__", self, other)

    def __getattr__(self, name):
        """Turn an unknown method call into a deferred BounceOp.

        ``op.foo(a, b)`` becomes ``BounceOp("foo", op, a, b)``, i.e. this op
        is the receiver of the deferred call.
        """
        # Protocol probes (copy, pickle, hasattr checks, ...) look up dunder
        # names and expect AttributeError when they are absent; answering
        # them with a deferred call would silently corrupt those protocols.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        def deferred_method(*args):
            return BounceOp(name, self, *args)

        return deferred_method
class FunctionTrampoline(object):
    """Callable that compiles one theano function per combination of input kinds.

    ``inputs`` is the ordered sequence of BounceVariables and ``output`` the
    BounceOp to evaluate. Compiled functions are cached in
    ``cached_functions`` under the tuple of landing keys of the call
    arguments.
    """

    def __init__(self, inputs, output):
        self.inputs = inputs
        self.output = output
        # Maps tuple(landing_key(arg) for arg in args) -> compiled function.
        self.cached_functions = {}

    def __call__(self, *args):
        """Evaluate the output for ``args``, compiling on first use.

        Raises TypeError when the number of arguments does not match the
        number of declared inputs.
        """
        if len(args) != len(self.inputs):
            raise TypeError("expected %d arguments, got %d"
                            % (len(self.inputs), len(args)))
        land_key = tuple(map(landing_key, args))
        if land_key not in self.cached_functions:
            inputs_dict = dict(zip(self.inputs, args))
            # Land in declaration order: the compiled function is later
            # called positionally with ``args``, so its inputs must line up
            # with them. Iterating the dict would not guarantee that.
            landed_inputs_list = [inp.land(arg)
                                  for inp, arg in zip(self.inputs, args)]
            landed_output = self.output.generate(inputs_dict)
            self.cached_functions[land_key] = theano.function(
                landed_inputs_list, landed_output)
        return self.cached_functions[land_key](*args)
def create_scalar_adder():
    """Compile a theano function that sums three ``dscalar`` inputs."""
    x = T.dscalar('x')
    y = T.dscalar('y')
    # The third input is labelled 'w' so each variable has a distinct name.
    w = T.dscalar('w')
    z = x + y + w
    f = theano.function([x, y, w], z)
    return f
def create_matrix_adder():
    """Compile a theano function that sums three ``dmatrix`` inputs."""
    operands = [T.dmatrix(label) for label in ('x', 'y', 'w')]
    x, y, w = operands
    return theano.function(operands, x + y + w)
def create_generic_adder():
    """Build a FunctionTrampoline that sums three BounceVariables x, y, w."""
    operands = [BounceVariable(label) for label in ('x', 'y', 'w')]
    x, y, w = operands
    return FunctionTrampoline(operands, x + y + w)
def main():
    """Exercise the fixed-type adders and then the generic adder."""
    add_scalars = create_scalar_adder()
    assert 9 == add_scalars(2, 3, 4)

    add_matrices = create_matrix_adder()
    assert [[9]] == add_matrices([[2]], [[3]], [[4]])

    # Each fixed-type adder is expected to raise TypeError when given the
    # other kind of input; not raising is treated as a failure.
    try:
        add_scalars([[2]], [[3]], [[4]])
    except TypeError:
        pass
    else:
        raise AssertionError

    try:
        add_matrices(2, 3, 4)
    except TypeError:
        pass
    else:
        raise AssertionError

    # The generic adder must accept both kinds of input.
    add_anything = create_generic_adder()
    assert 9 == add_anything(2, 3, 4)
    assert [[9]] == add_anything([[2]], [[3]], [[4]])
# Run the sample checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment