"""Utilities for stacking and unstacking JAX pytrees (e.g. to batch inputs for ``jax.vmap`` and to split its outputs back into individual objects).

Originally published as GitHub gist KeAWang/f420ba439a012d969b04211a42f6c9de.
"""
import jax
import jax.numpy as jnp
import jax.tree_util
def tree_stack(trees):
    """Stack corresponding leaves of a list of pytrees into a single pytree.

    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')). Useful for turning a list
    of objects into something you can feed to a vmapped function.

    JAX treats ``None`` as an empty subtree, so it contributes no leaves;
    e.g. ``tree_stack([None, None])`` returns ``None``.

    Args:
        trees: A non-empty ``list`` of pytrees that all share the same
            structure (identical ``PyTreeDef``).

    Returns:
        One pytree with the shared structure, whose every leaf is
        ``jnp.stack`` of the corresponding input leaves (a new leading axis
        of length ``len(trees)``).

    Raises:
        AssertionError: If ``trees`` is not a list, or the trees do not all
            have the same structure. Raised explicitly rather than via
            ``assert`` so the checks still run under ``python -O``.
        ValueError: If ``trees`` is empty.
    """
    if not isinstance(trees, list):
        raise AssertionError("trees must be a list of pytrees")
    if not trees:
        # No tree means no structure to stack into; fail with a clear message
        # instead of an opaque tuple-unpacking error.
        raise ValueError("tree_stack requires at least one tree")

    flattened = [jax.tree_util.tree_flatten(tree) for tree in trees]
    if len({treedef for _, treedef in flattened}) != 1:
        # Without this check, zip() below would silently truncate or
        # mis-pair leaves coming from differently structured trees.
        raise AssertionError("all pytrees must be the same")

    treedef = flattened[0][1]
    # zip(*...) groups the i-th leaf of every tree together.
    grouped_leaves = zip(*(leaves for leaves, _ in flattened))
    return treedef.unflatten([jnp.stack(group) for group in grouped_leaves])
def tree_unstack(tree):
    """Split a pytree whose leaves share a leading axis into a list of pytrees.

    Inverse of ``tree_stack`` for trees that contain at least one leaf.
    For example, given a tree ((a, b), c), where a, b, and c all have first
    dimension k, returns k trees
    [((a[0], b[0]), c[0]), ..., ((a[k-1], b[k-1]), c[k-1])].
    Useful for turning the output of a vmapped function into normal objects.

    Args:
        tree: A pytree whose leaves all have the same size along axis 0
            (leaves are indexed via ``.shape[0]``, so they are assumed to be
            arrays with at least one dimension).

    Returns:
        A list of pytrees with the same structure as ``tree``. If ``tree``
        has no leaves at all (e.g. ``None`` or ``[None, None]``) there is no
        leading axis to split, so ``[tree]`` is returned. Consequently
        ``tree_unstack(tree_stack([None, None]))`` is ``[None]``, not
        ``[None, None]``.

    Raises:
        AssertionError: If the leaves disagree on their leading-axis size.
            Raised explicitly rather than via ``assert`` so the check still
            runs under ``python -O``.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    if not leaves:
        return [tree]

    leading_sizes = {leaf.shape[0] for leaf in leaves}
    if len(leading_sizes) != 1:
        # zip() below would otherwise silently truncate to the shortest leaf.
        raise AssertionError(
            "All non-None pytrees leaves must be of the same size in the leading axis"
        )

    # zip(*leaves) walks every leaf along axis 0 in lockstep, yielding one
    # tuple of per-index slices for each output tree.
    return [treedef.unflatten(slices) for slices in zip(*leaves)]
def _trees_equal(a, b):
    """Return True if pytrees ``a`` and ``b`` share a structure and all leaves are equal.

    Comparing flattened leaves with ``jnp.array_equal`` avoids Python ``==``
    on JAX arrays, which returns an array and only coerces to ``bool`` for
    single-element arrays.
    """
    leaves_a, treedef_a = jax.tree_util.tree_flatten(a)
    leaves_b, treedef_b = jax.tree_util.tree_flatten(b)
    return treedef_a == treedef_b and all(
        bool(jnp.array_equal(x, y)) for x, y in zip(leaves_a, leaves_b)
    )


def _self_test():
    """Sanity checks for tree_stack / tree_unstack round-tripping and edge cases."""
    inputs = [[None], [(None, jnp.ones(1)), (None, jnp.zeros(1))]]
    for inp in inputs:
        output = tree_unstack(tree_stack(inp))
        assert _trees_equal(inp, output), (inp, output)

    # Leafless trees have no leading axis to split, so they come back wrapped.
    assert tree_unstack(None) == [None]
    assert tree_unstack([None]) == [[None]]
    assert tree_unstack([None, None]) == [[None, None]]
    assert tree_stack([None, None]) is None

    # Non-list input must be rejected. Use try/except/else so the check fails
    # when nothing is raised; wrapping `assert tree_stack(None)` in
    # `except AssertionError: pass` would pass for any return value.
    try:
        tree_stack(None)
    except AssertionError:
        pass
    else:
        raise AssertionError("tree_stack(None) should reject non-list input")


if __name__ == "__main__":
    # Run the checks only when executed as a script, not on every import.
    _self_test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment