@JonathanRaiman
Created July 11, 2016 18:50
Multi-argument scan in TensorFlow
import tensorflow as tf


def listify(x):
    # Normalize the accumulator argument(s) to a plain list. Unlike a
    # tuple-only conversion, a lone value is wrapped in a list so that a
    # single-accumulator `fn` may return a bare tensor.
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]
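
# Example behavior of the fixed listify above (illustrative):
#   listify((1, 2)) -> [1, 2]
#   listify([1, 2]) -> [1, 2]
#   listify(5)      -> [5]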


def awesome_scan(fn, elems, initializer=None, parallel_iterations=10,
                 back_prop=True, swap_memory=False, name=None):
    """Scan on the list of tensors unpacked from `elems` on dimension 0.

    This scan operator repeatedly applies the callable `fn` to a sequence
    of elements from first to last. The elements are made of the tensors
    unpacked from `elems` on dimension 0. Unlike `tf.scan`, the accumulator
    may be a list of tensors: `fn` receives each accumulated value from the
    preceding invocation followed by the current element, and must return
    one new value per accumulator. If `initializer` is None, `elems` must
    contain at least one element, and its first element is used as the
    initializer.

    Suppose that `elems` is unpacked into `values`, a list of tensors. The
    shape of each result tensor is
    `[len(values)] + fn(initializer, values[0]).shape`.

    Args:
        fn: The callable to be performed.
        elems: A tensor to be unpacked on dimension 0.
        initializer: (optional) The initial value(s) for the accumulator(s).
        parallel_iterations: (optional) The number of iterations allowed to
            run in parallel.
        back_prop: (optional) True enables back propagation.
        swap_memory: (optional) True enables GPU-CPU memory swapping.
        name: (optional) Name prefix for the returned tensors.

    Returns:
        A list of tensors, each packing the results of applying `fn` to the
        tensors unpacked from `elems`, from first to last.

    Raises:
        TypeError: if `fn` is not callable.

    Example:
        ```python
        elems = [1, 2, 3, 4]
        sums = awesome_scan(
            lambda s1, s2, x: [s1 + x + s2, 2 * s2], elems,
            initializer=[tf.constant(0, name="name_0"),
                         tf.constant(1, name="name_1")])
        [s.eval() for s in sums]
        ```
    """
    if not callable(fn):
        raise TypeError("fn must be callable.")
    with tf.op_scope([elems], name, "awesome_scan"):
        # Any get_variable calls in fn will cache the first call locally
        # and not issue repeated network I/O requests for each iteration.
        varscope = tf.get_variable_scope()
        varscope_caching_device_was_none = False
        if varscope.caching_device is None:
            # TODO(ebrevdo): Change to using colocate_with here and in
            # other methods.
            varscope.set_caching_device(lambda op: op.device)
            varscope_caching_device_was_none = True

        # Convert elems to a tensor array.
        elems = tf.convert_to_tensor(elems, name="elems")
        n = tf.shape(elems)[0]
        elems_ta = tf.TensorArray(dtype=elems.dtype, size=n,
                                  dynamic_size=False,
                                  infer_shape=True)
        elems_ta = elems_ta.unpack(elems)

        if initializer is None:
            a = [elems_ta.read(0)]
            i = tf.constant(1)
        else:
            # listify allows the initializer to be a single tensor or a
            # list of tensors.
            a = listify(initializer)
            i = tf.constant(0)
        # Create one tensor array per accumulator to store the
        # intermediate values.
        acc_ta = [tf.TensorArray(dtype=each_a.dtype, size=n,
                                 dynamic_size=False,
                                 infer_shape=True) for each_a in a]
        if initializer is None:
            acc_ta = [acc_ta[0].write(0, a[0])]
        def should_continue(i, *args):
            return i < n

        num_args = len(a)

        def compute(i, *args):
            # The loop variables are the accumulators followed by their
            # tensor arrays.
            a = listify(args[:num_args])
            ta = listify(args[num_args:])
            x = elems_ta.read(i)
            # listify lets fn return either a single tensor or a list.
            a = listify(fn(*(a + [x])))
            ta = [each_ta.write(i, each_a)
                  for each_a, each_ta in zip(a, ta)]
            return [i + 1] + a + ta
        stuff = tf.nn.control_flow_ops.while_loop(
            should_continue, compute, [i] + listify(a) + listify(acc_ta),
            parallel_iterations=parallel_iterations,
            back_prop=back_prop, swap_memory=swap_memory)
        r_a = stuff[-num_args:]
        results = [each_r_a.pack() for each_r_a in r_a]

        for result in results:
            result.set_shape(
                elems.get_shape().with_rank_at_least(1)[0:1].concatenate(
                    result.get_shape()[1:]))
        if varscope_caching_device_was_none:
            varscope.set_caching_device(None)
        return results
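
A minimal end-to-end sketch of driving the gist, assuming the TensorFlow 0.x
API the code targets (`TensorArray.unpack`/`pack`, `tf.op_scope`,
`tf.nn.control_flow_ops.while_loop`); the expected outputs follow directly
from the docstring example:

```python
import tensorflow as tf

# Two accumulators: a running sum (seeded with 0) and a value that
# doubles at every step (seeded with 1).
elems = [1, 2, 3, 4]
sums = awesome_scan(
    lambda s1, s2, x: [s1 + x + s2, 2 * s2], elems,
    initializer=[tf.constant(0, name="name_0"),
                 tf.constant(1, name="name_1")])

with tf.Session() as sess:
    print(sess.run(sums))
    # -> [array([ 2,  6, 13, 25]), array([ 2,  4,  8, 16])]
```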