Skip to content

Commit

Permalink
Add is_leaf to tree_{leaves,structure}.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 417783880
  • Loading branch information
tomhennigan authored and jax authors committed Dec 22, 2021
1 parent 927f3e5 commit 2f62574
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
8 changes: 4 additions & 4 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def tree_unflatten(treedef, leaves):
"""
return treedef.unflatten(leaves)

def tree_leaves(tree):
def tree_leaves(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
"""Gets the leaves of a pytree."""
return pytree.flatten(tree)[0]
return pytree.flatten(tree, is_leaf)[0]

def tree_structure(tree):
def tree_structure(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
"""Gets the treedef for a pytree."""
return pytree.flatten(tree)[1]
return pytree.flatten(tree, is_leaf)[1]

def treedef_tuple(treedefs):
"""Makes a tuple treedef from a list of child treedefs."""
Expand Down
41 changes: 30 additions & 11 deletions tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ def tree_unflatten(cls, meta, data):
# pytest expects "tree_util_test.ATuple"
STRS = []
for tree_str in TREE_STRINGS:
tree_str = re.escape(tree_str)
tree_str = tree_str.replace("__main__", ".*")
STRS.append(tree_str)
tree_str = re.escape(tree_str)
tree_str = tree_str.replace("__main__", ".*")
STRS.append(tree_str)
TREE_STRINGS = STRS

LEAVES = (
Expand Down Expand Up @@ -240,23 +240,42 @@ def testTreeMultimapWithIsLeafArgument(self):
self.assertEqual(out, (((1, [3]), (2, None)),
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))

def testFlattenIsLeaf(self):
@parameterized.parameters(
tree_util.tree_leaves,
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
def testFlattenIsLeaf(self, leaf_fn):
x = [(1, 2), (3, 4), (5, 6)]
leaves, _ = tree_util.tree_flatten(x, is_leaf=lambda t: False)
leaves = leaf_fn(x, is_leaf=lambda t: False)
self.assertEqual(leaves, [1, 2, 3, 4, 5, 6])
leaves, _ = tree_util.tree_flatten(
x, is_leaf=lambda t: isinstance(t, tuple))
leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(leaves, x)
leaves, _ = tree_util.tree_flatten(x, is_leaf=lambda t: isinstance(t, list))
leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, list))
self.assertEqual(leaves, [x])
leaves, _ = tree_util.tree_flatten(x, is_leaf=lambda t: True)
leaves = leaf_fn(x, is_leaf=lambda t: True)
self.assertEqual(leaves, [x])

y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
leaves, _ = tree_util.tree_flatten(
y, is_leaf=lambda t: isinstance(t, tuple))
leaves = leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(leaves, [(1,), (2,), (3,)])

@parameterized.parameters(
tree_util.tree_structure,
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1])
def testStructureIsLeaf(self, structure_fn):
x = [(1, 2), (3, 4), (5, 6)]
treedef = structure_fn(x, is_leaf=lambda t: False)
self.assertEqual(treedef.num_leaves, 6)
treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(treedef.num_leaves, 3)
treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, list))
self.assertEqual(treedef.num_leaves, 1)
treedef = structure_fn(x, is_leaf=lambda t: True)
self.assertEqual(treedef.num_leaves, 1)

y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(treedef.num_leaves, 3)

@parameterized.parameters(*TREES)
def testRoundtripIsLeaf(self, tree):
xs, treedef = tree_util.tree_flatten(
Expand Down

0 comments on commit 2f62574

Please sign in to comment.