Created
December 12, 2021 19:04
-
-
Save enijkamp/354f037bd1dd00df1a714d86f89ed0a7 to your computer and use it in GitHub Desktop.
leave_names.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id):
    """Map every leaf of ``pytree`` to a slash-separated name built from keys.

    Mapping nodes (anything exposing a truthy ``items`` attribute) append
    ``/<key>`` to the path. Sequence nodes (anything exposing
    ``__getitem__``) add no path component, so all elements of a list/tuple
    share their parent's name. ``str``/``bytes`` are treated as atoms, never
    iterated, because iterating a one-character string yields itself forever.

    Args:
        pytree: Nested structure of mappings and sequences to traverse.
        is_leaf: Predicate; nodes for which it returns True (including the
            root) are recorded and not descended into.
        path: Name prefix for ``pytree`` itself ("" at the root).
        to_id: Function producing the dict key for a recorded node (default
            ``id``). Its result is used as a dict key, so it must be hashable.

    Returns:
        Dict mapping ``to_id(node)`` to that node's name. Nodes that produce
        the same key (e.g. one object reachable by two paths) keep the last
        name visited.
    """
    if is_leaf(pytree):
        return {to_id(pytree): path}

    id_to_name = {}
    if getattr(pytree, "items", None):
        for key, child in pytree.items():
            # Forward to_id so nested levels key nodes exactly like the root.
            id_to_name.update(
                tree_flatten_with_names(
                    child, is_leaf, path=f"{path}/{key}", to_id=to_id
                )
            )
    elif not isinstance(pytree, (str, bytes)) and getattr(
        pytree, "__getitem__", None
    ):
        # Sequence elements are deliberately not indexed: they inherit `path`.
        for child in pytree:
            id_to_name.update(
                tree_flatten_with_names(child, is_leaf, path=path, to_id=to_id)
            )
    else:
        # Neither a leaf nor a traversable container: record it under `path`.
        id_to_name[to_id(pytree)] = path
    return id_to_name
def tree_leaves_with_names(pytree, to_id=id):
    """Name each JAX leaf of ``pytree`` by its dict-key path.

    Leaves are the objects returned by ``jax.tree_util.tree_leaves``. Names
    come from ``tree_flatten_with_names``: dict keys are joined with "/", and
    list/tuple elements share their parent's name.

    Args:
        pytree: A JAX pytree, presumably built from dicts and lists/tuples.
        to_id: Key function for leaves (default ``id``). It is used both to
            detect leaves and as the key of the returned dict, so it must be
            hashable.

    Returns:
        Dict mapping ``to_id(leaf)`` to the leaf's name.
    """
    # `leaves` stays referenced until return. id() values are only unique
    # among objects whose lifetimes overlap, so this keeps id-based keys
    # unambiguous during the traversal.
    leaves = jax.tree_util.tree_leaves(pytree)
    # Compute leaf keys once so each is_leaf membership test is O(1).
    leaf_ids = {to_id(leaf) for leaf in leaves}

    def is_leaf(node):
        # Lists are always traversed as containers, never matched as leaves.
        return not isinstance(node, list) and to_id(node) in leaf_ids

    # Forward to_id so the returned keys match those used for leaf detection.
    return tree_flatten_with_names(pytree, is_leaf, to_id=to_id)
def get_tree_leaves_names_reduced(pytree) -> List[str]:
    """Return the dict-key path name of every leaf of ``pytree``, in leaf order.

    "Reduced" names omit list/tuple indices, so elements of the same sequence
    share one name. The result is aligned index-for-index with
    ``jax.tree_util.tree_leaves(pytree)``.

    Raises:
        KeyError: if a leaf was not reached by ``tree_leaves_with_names``.
            Presumably this happens when the leaf sits inside a container
            type that traversal does not descend into; verify for custom
            pytree nodes.
    """
    # Hold the leaves for the whole call so their id() values stay unique.
    leaves = jax.tree_util.tree_leaves(pytree)
    names_by_id = tree_leaves_with_names(pytree, to_id=id)
    return [names_by_id[id(leaf)] for leaf in leaves]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment