From abff24140830e1a226b0f6d73a265b0972196008 Mon Sep 17 00:00:00 2001 From: phawkins Date: Mon, 10 Jun 2024 10:42:41 -0700 Subject: [PATCH] [JAX] Update users of jax.tree.map() to be more careful about how they handle Nones. Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself. Fix user code that was relying on this bug. Most commonly, the fix is to write `jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`. PiperOrigin-RevId: 641952015 --- spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index bb0805c4d5..116d66321a 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -68,10 +68,12 @@ def unflatten_tree(tree, xs): def map_tree_up_to(shallow, fn, tree, *rest): """`map_tree` with recursion depth defined by depth of `shallow`.""" - def wrapper(_, *rest): - return fn(*rest) + def wrapper(x, *rest): + return None if x is None else fn(*rest) - return tree_util.tree_map(wrapper, shallow, tree, *rest) + return tree_util.tree_map( + wrapper, shallow, tree, *rest, is_leaf=lambda x: x is None + ) def get_shallow_tree(is_leaf, tree):