Skip to content

Commit

Permalink
[JAX] Update users of jax.tree.map() to be more careful about how the…
Browse files Browse the repository at this point in the history
…y handle Nones.

Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix user code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 641952015
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Jun 10, 2024
1 parent b59c0a8 commit abff241
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ def unflatten_tree(tree, xs):
def map_tree_up_to(shallow, fn, tree, *rest):
"""`map_tree` with recursion depth defined by depth of `shallow`."""

def wrapper(_, *rest):
return fn(*rest)
def wrapper(x, *rest):
return None if x is None else fn(*rest)

return tree_util.tree_map(wrapper, shallow, tree, *rest)
return tree_util.tree_map(
wrapper, shallow, tree, *rest, is_leaf=lambda x: x is None
)


def get_shallow_tree(is_leaf, tree):
Expand Down

0 comments on commit abff241

Please sign in to comment.