Skip to content

Instantly share code, notes, and snippets.

@so298
Last active January 16, 2025 11:57
Show Gist options
  • Save so298/f19ab543ff30bf13ce4a8b1e7651162f to your computer and use it in GitHub Desktop.
Save so298/f19ab543ff30bf13ce4a8b1e7651162f to your computer and use it in GitHub Desktop.
Simple example of JAX's pytree utility
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
from jax.tree_util import register_dataclass, tree_flatten, tree_leaves
from jax import vmap
import jax
@dataclass
class MyStruct:
    """Small dataclass registered with JAX as a custom pytree node.

    ``x`` and ``y`` are registered as data fields, so they are the leaves seen
    by ``tree_flatten`` / ``tree_leaves`` and by transformations like ``vmap``.
    ``z`` is registered as a meta field: it is not a leaf, and its value is
    carried inside the treedef instead.
    """

    x: int  # data field (leaf); the examples also store jax arrays here
    y: int  # data field (leaf); the examples also store jax arrays here
    z: int = 100  # meta field: kept in the treedef, not flattened as a leaf


# register_dataclass returns the class it registers, so rebinding the name is
# equivalent to applying it as a decorator on top of @dataclass.
MyStruct = register_dataclass(
    MyStruct,
    data_fields=("x", "y"),
    meta_fields=("z",),
)


@partial(vmap, in_axes=(0,))
def vsimple_fn(s: MyStruct) -> jax.Array:
    """Compute ``(s.x + s.y) * s.z`` for every element of a batched struct.

    ``vmap`` with ``in_axes=(0,)`` maps over axis 0 of every leaf of ``s``.
    For ``MyStruct`` those leaves are ``x`` and ``y``. ``z`` is a meta field,
    not a leaf, so it is not mapped and is used as the same value for every
    element.

    Args:
        s: A ``MyStruct`` whose ``x`` and ``y`` have a leading batch axis.
            Scalar leaves have no axis 0, so vmap cannot map over them.

    Returns:
        An array whose leading axis is the batch axis (vmap's default
        ``out_axes=0``), holding one result per batch element.
    """
    return (s.x + s.y) * s.z


# Struct built from Python scalars: the data fields x and y become the leaves,
# while the meta field z (default 100) shows up inside the treedef.
s = MyStruct(1, 2)
print(f"{tree_flatten(s)=}")
print(f"{tree_leaves(s)=}")
# Output:
#   tree_flatten(s)=([1, 2], PyTreeDef(CustomNode(MyStruct[(100,)], [*, *])))
#   tree_leaves(s)=[1, 2]

# vsimple_fn(s) raises an error: vmap maps over axis 0 of every leaf, but the
# leaves here are the scalars 1 and 2, which have no axis to map over.

# Struct built from 1-D arrays: each leaf now has a leading axis of length 2.
s2 = MyStruct(jnp.array([1, 2]), jnp.array([3, 4]))
print(f"{tree_flatten(s2)=}")
print(f"{tree_leaves(s2)=}")
# Output:
#   tree_flatten(s2)=([Array([1, 2], dtype=int32), Array([3, 4], dtype=int32)], PyTreeDef(CustomNode(MyStruct[(100,)], [*, *])))
#   tree_leaves(s2)=[Array([1, 2], dtype=int32), Array([3, 4], dtype=int32)]

res = vsimple_fn(s2)  # works: vmap maps over the length-2 axis of x and y
print(f"{res=}")
# Output:
#   res=Array([400, 600], dtype=int32)

def tree_transpose(list_of_trees):
    """Stack a sequence of identically-structured pytrees into one pytree.

    The result has the same structure as each input tree. Each of its leaves
    is ``jnp.array`` of the matching leaves from all input trees. A leaf of
    shape ``S`` therefore becomes an array of shape ``(len(list_of_trees), *S)``
    with a new leading (batch) axis. Note that the leaves are arrays, not
    Python lists.

    Args:
        list_of_trees: A non-empty iterable of pytrees that all share the same
            treedef. For registered dataclasses the treedef also contains the
            meta-field values (e.g. ``MyStruct``'s ``z``), so those must be
            equal across trees as well. Leaves at the same position must
            have the same shape so they can be stacked.

    Returns:
        A pytree with the shared structure whose leaves are stacked arrays.
    """
    # jax.tree.map(f, first, *rest) walks all trees in lockstep and calls f
    # once per leaf position, passing the matching leaf from every tree.
    return jax.tree.map(lambda *xs: jnp.array(xs), *list_of_trees)


# Vectorize input: stack a list of scalar-valued structs into a single struct
# whose data fields are arrays with a leading batch axis, so that vsimple_fn
# can vmap over it. The meta field z is part of the shared treedef, so it is
# not stacked and stays 100.
s_list = [
    MyStruct(1, 2),
    MyStruct(3, 4),
    MyStruct(5, 6),
]
tt = tree_transpose(s_list)
print(f"{tt=}")
print(f"{vsimple_fn(tt)=}")
# Output:
#   tt=MyStruct(x=Array([1, 3, 5], dtype=int32), y=Array([2, 4, 6], dtype=int32), z=100)
#   vsimple_fn(tt)=Array([ 300,  700, 1100], dtype=int32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment