Created July 11, 2016 18:50
Scan multi arg in tensorflow
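A variant of `tf.scan` that threads multiple accumulators through the loop: `fn` receives each accumulator from the previous step followed by the current element of `elems`, and returns the list of updated accumulators.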
import tensorflow as tf


def listify(x):
  # while_loop hands the loop variables back as a tuple; normalize slices of
  # them to lists so the accumulators and tensor arrays can be concatenated.
  if isinstance(x, tuple):
    return list(x)
  return x
def awesome_scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
                 swap_memory=False, name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  This scan operator repeatedly applies the callable `fn` to a sequence of
  elements from first to last. The elements are made of the tensors unpacked
  from `elems` on dimension 0. Unlike `tf.scan`, `fn` may carry several
  accumulators: it is called with each accumulator value from the preceding
  invocation followed by the current element, and must return a list with one
  updated tensor per accumulator. If `initializer` is None, `elems` must
  contain at least one element, and its first element is used to seed a
  single accumulator.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The
  shape of the i-th result tensor is `[len(values)]` concatenated with the
  shape of the i-th accumulator.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) A list of tensors holding the initial value of
      each accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A list of tensors, one per accumulator, each packing that accumulator's
    values while applying `fn` to the tensors unpacked from `elems`, from
    first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4]
    sums = awesome_scan(lambda s1, s2, x: [s1 + x + s2, 2 * s2], elems,
                        initializer=[tf.constant(0, name="name_0"),
                                     tf.constant(1, name="name_1")])
    [s.eval() for s in sums]
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  with tf.op_scope([elems], name, "awesome_scan"):
    # Any get_variable calls in fn will cache the first call locally
    # and not issue repeated network I/O requests for each iteration.
    varscope = tf.get_variable_scope()
    varscope_caching_device_was_none = False
    if varscope.caching_device is None:
      # TODO(ebrevdo): Change to using colocate_with here and in other methods.
      varscope.set_caching_device(lambda op: op.device)
      varscope_caching_device_was_none = True

    # Convert elems to a tensor array.
    elems = tf.convert_to_tensor(elems, name="elems")
    n = tf.shape(elems)[0]
    elems_ta = tf.TensorArray(dtype=elems.dtype, size=n,
                              dynamic_size=False,
                              infer_shape=True)
    elems_ta = elems_ta.unpack(elems)

    if initializer is None:
      a = [elems_ta.read(0)]
      i = tf.constant(1)
    else:
      a = initializer
      i = tf.constant(0)

    # Create one tensor array per accumulator to store the intermediate values.
    acc_ta = [tf.TensorArray(dtype=each_a.dtype, size=n,
                             dynamic_size=False,
                             infer_shape=True) for each_a in a]
    if initializer is None:
      acc_ta = [acc_ta[0].write(0, a[0])]

    def should_continue(i, *args):
      return i < n

    num_args = len(a)

    def compute(i, *args):
      # Loop variables arrive as (accumulators..., tensor arrays...).
      a = listify(args[:num_args])
      ta = listify(args[num_args:])
      x = elems_ta.read(i)
      a = fn(*(a + [x]))
      ta = [each_ta.write(i, each_a) for each_a, each_ta in zip(a, ta)]
      return [i + 1] + listify(a) + ta

    stuff = tf.nn.control_flow_ops.while_loop(
        should_continue, compute, [i] + listify(a) + listify(acc_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory)

    # The last num_args loop variables are the filled tensor arrays.
    r_a = stuff[-num_args:]
    results = [each_r_a.pack() for each_r_a in r_a]
    for result in results:
      result.set_shape(elems.get_shape().with_rank_at_least(1)[0:1].concatenate(
          result.get_shape()[1:]))
    if varscope_caching_device_was_none:
      varscope.set_caching_device(None)
    return results
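For reference, a minimal end-to-end run of the docstring example, assuming the TensorFlow 0.9-era (mid-2016) API this gist targets; `tf.op_scope`, `TensorArray.unpack`/`.pack`, and `tf.nn.control_flow_ops.while_loop` were renamed or removed in later releases, so this sketch will not run on modern TensorFlow:

```python
import tensorflow as tf

# Scan [1, 2, 3, 4] with two accumulators: a running total (seeded with 0)
# that also adds in the second accumulator, and a value that doubles at
# every step (seeded with 1).
elems = tf.constant([1, 2, 3, 4])
sums = awesome_scan(lambda s1, s2, x: [s1 + x + s2, 2 * s2], elems,
                    initializer=[tf.constant(0), tf.constant(1)])

with tf.Session() as sess:
    # Two packed tensors: [2, 6, 13, 25] and [2, 4, 8, 16].
    print([s.eval(session=sess) for s in sums])
```

Because the while_loop carries `[i] + accumulators + tensor arrays` as its loop variables, the number of accumulators is fixed at graph-construction time by `len(initializer)`.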