Last active
August 15, 2016 08:20
-
-
Save nasimrahaman/4fd66225901808586f02adabf9b7775d to your computer and use it in GitHub Desktop.
A dictionary-ready wrapper for theano functions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
__author__ = "nasim.rahaman at iwr.uni-heidelberg.de" | |
__doc__ = """A few bells and whistles for the theano function callable. | |
Examples: | |
import theano.tensor as T | |
x = T.scalar() | |
y = T.scalar() | |
f1 = function(inputs={'x': x, 'y': y}, outputs={'z1': x + y, 'z2': x + 2*y}) | |
f1(x=2, y=3) | |
# Output: {'z1': 5, 'z2': 8} | |
f2 = function(inputs={'x': x, 'y': y}, outputs={'z12': [x + y, x + 2*y]}) | |
f2(x=2, y=3) | |
# Output: {'z12': [5, 8]} | |
f3 = function(inputs=[x, y], outputs=[x + y, x + 2*y]) | |
f3(2, 3) | |
# Output: (5, 8) | |
f4 = function(inputs=[x, y], outputs={'z1': x + y, 'z2': x + 2*y}) | |
f4(2, 3) | |
# Output: {'z1': 5, 'z2': 8} | |
Can be useful for e.g. having a theano function return the gradients | |
of a variable w.r.t. multiple theano variables alongside (say) the cost scalar, | |
without having to worry about the output ordering/bookkeeping. | |
P.S. Dictionary outputs for theano functions are built in, but they can't (yet)
organize outputs into lists. That's what this wrapper is intended for.
""" | |
import theano as th | |
import numpy as np | |
class pyk:
    """Namespace of small list-handling helpers used by the function wrapper."""

    @staticmethod
    def obj2list(obj, ndarray2list=True):
        """
        Convert a list-like object to a list, or wrap anything else in a one-element list.

        :type obj: any
        :param obj: Object to convert.

        :type ndarray2list: bool
        :param ndarray2list: Whether numpy arrays count as list-like (and get unpacked along their first axis)
                             or are treated as a single object to be wrapped.

        :return: A new list.
        """
        listlike = (list, tuple, np.ndarray) if ndarray2list else (list, tuple)
        # An isinstance check is used on purpose: calling list() blindly would also unpack arbitrary iterators.
        if isinstance(obj, listlike):
            return list(obj)
        return [obj]

    @staticmethod
    def delist(l):
        """Return the sole element of a one-element list/tuple; return anything else unchanged."""
        if isinstance(l, (list, tuple)) and len(l) == 1:
            return l[0]
        return l

    @staticmethod
    def unflatten(l, lenlist):
        """
        Fold a flat sequence into chunks of the given lengths. Chunks of length 1 are unwrapped.

        For l = [a, b, c, d, e] and lenlist = [1, 1, 2, 1], the result is [a, b, [c, d], e].

        :type l: list or tuple
        :param l: Flat sequence to fold. It is not modified.

        :type lenlist: list
        :param lenlist: Chunk lengths; must sum up to len(l).

        :return: List with one entry per element of lenlist.
        :raises AssertionError: If the lengths in lenlist do not add up to len(l).
        """
        assert len(l) == sum(lenlist), "Provided length list is not consistent with the list length."
        outlist = []
        start = 0
        # Slice by running offset instead of popping from the front: linear time, and works for any
        # sliceable sequence (tuples included), not just lists.
        for len_ in lenlist:
            stop = start + len_
            outlist.append(pyk.delist(list(l[start:stop])))
            start = stop
        return outlist

    @staticmethod
    def flatten(*args):
        """
        Lazily yield the leaves of arbitrarily nested lists/tuples, depth first and left to right.

        Only list and tuple instances are descended into; every other object is yielded as is.
        """
        for item in args:
            if isinstance(item, (tuple, list)):
                for leaf in pyk.flatten(*item):
                    yield leaf
            else:
                yield item

    @staticmethod
    def smartlen(l):
        """Return len(l) for a list/tuple and 1 for any other object."""
        if isinstance(l, (list, tuple)):
            return len(l)
        return 1
# Generic class for functions | |
class function(object):
    """
    Callable wrapper around a compiled theano function.

    Inputs and outputs may be given as (nested) lists or as dicts mapping names to variables (or lists of
    variables). The wrapper flattens them for theano.function and folds the results back into the layout
    of `outputs` on every call.
    """

    def __init__(self, inputs, outputs, mode=None, updates=None, givens=None, no_default_updates=False,
                 accept_inplace=False, name=None, rebuild_strict=True, allow_input_downcast=None, profile=None,
                 on_unused_input='raise'):
        """
        A simple wrapper for theano functions (can be used with Lasagne), with added syntactic sugar.
        The theano function is compiled immediately on construction.

        :type inputs: list or dict
        :param inputs: List of inputs, or alternatively a dict with {'name1': var1, ...}.

        :type outputs: list or dict
        :param outputs: List of outputs, or alternatively a dict with {'name1': var1, ...}.

        :type mode: str or theano.function.Mode
        :param mode: Compilation Mode.

        :type updates: list or tuple or dict
        :param updates: Expressions for new SharedVariable values. Must be iterable over pairs of
                        (shared_variable, update expression)

        :type givens: list or tuple or dict
        :param givens: Substitutions to make in the computational graph. Must be iterable over pairs of variables
                       (var1, var2) where var2 replaces var1 in the computational graph.

        :type no_default_updates: bool or list
        :param no_default_updates: If True: whether to update variables. See official theano documentation here:
                                   http://deeplearning.net/software/theano/library/compile/function.html#function.function

        :type accept_inplace: bool
        :param accept_inplace: See official theano documentation:
                               http://deeplearning.net/software/theano/library/compile/function.html#function.function

        :type name: str
        :param name: Name of the function. Useful for profiling.

        :type rebuild_strict: bool
        :param rebuild_strict: See official theano documentation:
                               http://deeplearning.net/software/theano/library/compile/function.html#function.function

        :type allow_input_downcast: bool
        :param allow_input_downcast: Whether to allow the input to be downcasted to floatX.

        :type profile: bool
        :param profile: Whether to profile function. See official theano documentation:
                        http://deeplearning.net/software/theano/library/compile/function.html#function.function

        :type on_unused_input: str
        :param on_unused_input: What to do if an input is not used.
        """
        # Meta
        self.inputs = inputs
        self.outputs = outputs
        self.mode = mode
        self.updates = updates
        self.givens = givens
        self.no_default_updates = no_default_updates
        self.accept_inplace = accept_inplace
        self.name = name
        self.rebuild_strict = rebuild_strict
        self.allow_input_downcast = allow_input_downcast
        self.profile = profile
        self.on_unused_input = on_unused_input
        # Function containers
        self._thfunction = None
        self._function = self.__call__
        # Compile function
        self.compile()

    @staticmethod
    def _spec2list(spec):
        """Return an input/output specification (list, dict or single object) as a real list."""
        if isinstance(spec, list):
            return spec
        if isinstance(spec, dict):
            # Materialize the values: on Python 3 a dict view is not a list/tuple, so pyk.flatten would
            # otherwise treat the whole view as one leaf.
            return list(spec.values())
        return [spec]

    def _order_kwargs(self, kwargs):
        """
        Return the keyword argument values in the key order of self.inputs (a dict).

        :raises TypeError: If a named input is missing from kwargs or kwargs has a name that is not an input.
        """
        missing = [key for key in self.inputs if key not in kwargs]
        unexpected = [key for key in kwargs if key not in self.inputs]
        if missing or unexpected:
            raise TypeError("Keyword arguments do not match the function inputs "
                            "(missing: {}, unexpected: {}).".format(missing, unexpected))
        return [kwargs[key] for key in self.inputs]

    def compile(self):
        """
        Compile the underlying theano function from the stored settings.

        The compiled function is stored on the instance and also returned.

        :return: The compiled theano function.
        """
        # Step 1. Compile theano function.
        # Flatten the (possibly nested / dict) input specification to a flat list
        inplist = list(pyk.flatten(self._spec2list(self.inputs)))
        # Same for the outputs; a single output is handed to theano as a bare variable
        outlist = pyk.delist(list(pyk.flatten(self._spec2list(self.outputs))))
        # Compile
        thfunction = th.function(inputs=inplist, outputs=outlist, mode=self.mode, updates=self.updates,
                                 givens=self.givens, no_default_updates=self.no_default_updates,
                                 accept_inplace=self.accept_inplace, name=self.name,
                                 rebuild_strict=self.rebuild_strict,
                                 allow_input_downcast=self.allow_input_downcast, profile=self.profile,
                                 on_unused_input=self.on_unused_input)
        # Write to container
        self._thfunction = thfunction
        return thfunction

    def __call__(self, *args, **kwargs):
        """
        Evaluate the compiled theano function and fold its results into the layout of self.outputs.

        Positional arguments are expected when the inputs were given as a list, keyword arguments when they
        were given as a dict. Keyword arguments may be passed in any order.

        :return: A dict keyed like self.outputs if it is a dict, a tuple if it is a list, and otherwise the
                 result itself (unwrapped if it is a single value).
        :raises AssertionError: If positional arguments are mixed up with dict inputs (or keywords with list
                                inputs), or if the number of results is not the number of expected outputs.
        :raises TypeError: If the keyword names do not match the keys of a dict input specification.
        """
        # Don't allow args if self.inputs is a dictionary. This is because the user can not be expected to know
        # exactly how a dictionary is ordered, unless the dictionary is ordered.
        args = list(args)
        if isinstance(self.inputs, dict):
            assert not args, "Antipasti function object expects keyword arguments because the " \
                             "provided input was a dict."
        if isinstance(self.inputs, list):
            assert not kwargs, "Keywords could not be parsed by the Antipasti function object."
        # Flatten kwargs or args
        if args:
            funcargs = list(pyk.flatten(args))
        elif isinstance(self.inputs, dict):
            # The theano function was compiled with the inputs in the order of self.inputs, so the values
            # must be looked up by name instead of relying on the order the caller wrote the keywords in.
            funcargs = list(pyk.flatten(self._order_kwargs(kwargs)))
        else:
            funcargs = list(pyk.flatten(list(kwargs.values())))
        # Evaluate function
        outlist = pyk.obj2list(self._thfunction(*funcargs), ndarray2list=False)
        # Parse output list (dict views are materialized so that they are not mistaken for a single output)
        if isinstance(self.outputs, dict):
            expoutputs = list(self.outputs.values())
        else:
            expoutputs = pyk.obj2list(self.outputs, ndarray2list=False)
        # Make sure the theano function has returned the correct number of outputs
        assert len(outlist) == len(list(pyk.flatten(expoutputs))), "Number of outputs returned by the theano function " \
                                                                   "is not consistent with the number of expected " \
                                                                   "outputs."
        # Unflatten theano function output (outlist)
        # Get list with sublist lengths
        lenlist = [pyk.smartlen(expoutput) for expoutput in expoutputs]
        # Unflatten outlist
        outputs = pyk.unflatten(outlist, lenlist)
        # Write to dictionary if self.outputs is a dictionary
        if isinstance(self.outputs, dict):
            outputs = dict(zip(self.outputs.keys(), outputs))
        elif isinstance(self.outputs, list):
            outputs = tuple(outputs)
        else:
            outputs = pyk.delist(outputs)
        return outputs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment