Created
March 4, 2019 16:47
-
-
Save michaelosthege/6953a2af7417da6ebdd41771a9e7e7a8 to your computer and use it in GitHub Desktop.
Custom Theano Op for wrapping around an ODE-integrator.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library
import base64
import hashlib

# Third-party
import theano
import theano.tensor as tt
def make_hashable(obj):
    """Recursively convert *obj* into a hashable equivalent.

    Tuples and lists become tuples of hashable elements; dicts become
    sorted tuples of ``(key, hashable_value)`` pairs; sets and frozensets
    become sorted tuples.  Any other object is returned unchanged (it is
    assumed to be hashable already, or at least to have a stable repr).
    """
    if isinstance(obj, (tuple, list)):
        return tuple(make_hashable(e) for e in obj)
    if isinstance(obj, dict):
        # Sort items so that dict insertion order does not affect the result.
        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
    if isinstance(obj, (set, frozenset)):
        # Sort elements so that set iteration order does not affect the result.
        return tuple(sorted(make_hashable(e) for e in obj))
    return obj
def make_hash_sha256(obj):
    """Return a base64-encoded sha256 digest identifying *obj*.

    The object is first normalized with :func:`make_hashable` so that
    equivalent containers (e.g. a list vs. the same values as a tuple)
    produce the same digest; the normalized repr is then hashed.
    """
    digest = hashlib.sha256(repr(make_hashable(obj)).encode()).digest()
    return base64.b64encode(digest).decode()
class IntegrationOp(theano.Op):
    """Theano Op wrapping a user-supplied ODE solver as a graph node.

    The Op is not differentiable, because the 'solver' is an opaque
    Python callable provided by the user.
    """
    __props__ = ("solver",)

    def __init__(self, solver):
        # solver: callable taking concrete (y0, x, theta) and returning a matrix
        self.solver = solver
        super().__init__()

    def __hash__(self):
        # Combine the Op type with a content hash of the solver so that
        # Ops wrapping different solvers hash differently.
        return hash((hash(type(self)), make_hash_sha256(self.solver)))

    def make_node(self, y0: list, x, theta: list):
        # NOTE: theano does not allow a list of tensors as one input, so the
        # y0/theta lists are tt.stack()ed — which also merges them into one dtype!
        # TODO: check dtypes and raise warnings
        stacked_y0 = tt.stack([tt.as_tensor_variable(v) for v in y0])
        stacked_theta = tt.stack([tt.as_tensor_variable(v) for v in theta])
        x_tensor = tt.as_tensor_variable(x)
        # Symbolic inputs are [y0, x, theta]; the symbolic output is Y_hat.
        # NOTE: supporting different dtypes for the transient variables would
        # require a list of dvectors/svectors as the output type instead.
        return theano.Apply(
            self,
            [stacked_y0, x_tensor, stacked_theta],
            [tt.dmatrix()],
        )

    def perform(self, node, inputs, output_storage):
        # Run the actual simulation: the solver receives concrete
        # y0/x/theta values and returns a matrix of states over all x.
        y0, x, theta = inputs
        output_storage[0][0] = self.solver(y0, x, theta)

    def grad(self, inputs, outputs):
        # No gradient flows through the Python-wrapped solver.
        return [
            theano.gradient.grad_undefined(
                self, idx, variable,
                'No gradient defined through Python-wrapping IntegrationOp.')
            for idx, variable in enumerate(inputs)
        ]

    def infer_shape(self, node, input_shapes):
        # Output shape is (len(y0), len(x)): one row per state variable,
        # one column per integration point.
        s_y0, s_x, _s_theta = input_shapes
        return [(s_y0[0], s_x[0])]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment