Make jax._src.api._std_basis part of the main tree_util API
#10081
-
Worth noting that jax._src.api._std_basis([jax.numpy.zeros(3), jax.numpy.zeros(3)]) produces

    [DeviceArray([[1., 0., 0.],
                  [0., 1., 0.],
                  [0., 0., 1.],
                  [0., 0., 0.],
                  [0., 0., 0.],
                  [0., 0., 0.]], dtype=float32),
     DeviceArray([[0., 0., 0.],
                  [0., 0., 0.],
                  [0., 0., 0.],
                  [1., 0., 0.],
                  [0., 1., 0.],
                  [0., 0., 1.]], dtype=float32)]

which is just a length-2 list of arrays. It is not a length-2 list of length-2 lists of arrays. That is to say, the "PyTree axes" don't get duplicated out the way the "array axes" do.

It's very easy to write a version that does duplicate them:

    import jax
    import jax.numpy as jnp

    def std_basis(pytree):
        leaves, structure = jax.tree_util.tree_flatten(pytree)
        # Nest the tree structure inside itself: one inner tree per outer leaf.
        eye_structure = structure.compose(structure)
        eye_leaves = []
        for i1, l1 in enumerate(leaves):
            for i2, l2 in enumerate(leaves):
                if i1 == i2:
                    # Diagonal block: identity reshaped to shape(l1) + shape(l2).
                    eye_leaves.append(jnp.eye(jnp.size(l1)).reshape(jnp.shape(l1) + jnp.shape(l2)))
                else:
                    # Off-diagonal block: all zeros.
                    eye_leaves.append(jnp.zeros(jnp.shape(l1) + jnp.shape(l2)))
        return jax.tree_util.tree_unflatten(eye_structure, eye_leaves)

and of course any time you need this "matrix PyTree" to interact with a "vector PyTree" you can do so by just [...]
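One possible way to contract such a "matrix PyTree" with a "vector PyTree" is sketched below. This is only an illustration under assumptions: the helper name tree_matvec is hypothetical (it is not part of JAX), and it assumes the matrix tree was built with structure.compose(structure) as in std_basis above, so that its flattened leaves are ordered row block by row block.

```python
import jax
import jax.numpy as jnp


def std_basis(pytree):
    # Same construction as in the comment above.
    leaves, structure = jax.tree_util.tree_flatten(pytree)
    eye_structure = structure.compose(structure)
    eye_leaves = []
    for i1, l1 in enumerate(leaves):
        for i2, l2 in enumerate(leaves):
            if i1 == i2:
                eye_leaves.append(jnp.eye(jnp.size(l1)).reshape(jnp.shape(l1) + jnp.shape(l2)))
            else:
                eye_leaves.append(jnp.zeros(jnp.shape(l1) + jnp.shape(l2)))
    return jax.tree_util.tree_unflatten(eye_structure, eye_leaves)


def tree_matvec(matrix_tree, vector_tree):
    # Hypothetical helper: block (i, j) has shape shape_i + shape_j and is
    # contracted with vector leaf j over all of that leaf's axes.
    vec_leaves, structure = jax.tree_util.tree_flatten(vector_tree)
    mat_leaves = jax.tree_util.tree_leaves(matrix_tree)
    n = len(vec_leaves)
    out_leaves = []
    for i in range(n):
        acc = 0.0
        for j in range(n):
            block = mat_leaves[i * n + j]
            acc = acc + jnp.tensordot(block, vec_leaves[j], axes=vec_leaves[j].ndim)
        out_leaves.append(acc)
    return jax.tree_util.tree_unflatten(structure, out_leaves)


vector = [jnp.arange(3.0), jnp.ones((2, 2))]
result = tree_matvec(std_basis(vector), vector)
print(result)
```

Because std_basis plays the role of an identity matrix, result should have the same leaves as vector.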
-
As @patrick-kidger has shown, there are two kinds of "std_basis":
-
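BTW, you can implement the variant meant for vmap like this:

    import jax
    import jax.numpy as jnp
    from jax.flatten_util import ravel_pytree

    def std_basis_for_vmap(pytree):
        # Flatten the pytree to a 1-D vector and keep the inverse mapping.
        flat, unravel_pytree = ravel_pytree(pytree)
        # Unravel every row of the identity matrix back into the pytree structure.
        return jax.vmap(unravel_pytree)(jnp.eye(flat.size, dtype=flat.dtype))

    print(std_basis_for_vmap([jnp.zeros(3), jnp.zeros(3)]))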
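As a rough sketch of why this shape is the convenient one for vmap: each leaf of the result has a leading axis that indexes the basis vectors, so mapping a VJP function over that axis yields the rows of a Jacobian. The function f and the input x below are hypothetical examples, not taken from the thread.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def std_basis_for_vmap(pytree):
    # One basis element per scalar entry of the flattened pytree, stacked on axis 0.
    flat, unravel_pytree = ravel_pytree(pytree)
    return jax.vmap(unravel_pytree)(jnp.eye(flat.size, dtype=flat.dtype))


def f(x):
    # Hypothetical example function with a pytree (list) output.
    return [jnp.sin(x), x ** 2]


x = jnp.array([0.1, 0.2, 0.3])
y, vjp_fn = jax.vjp(f, x)

# Feed every standard basis cotangent through the VJP in one batched call.
basis = std_basis_for_vmap(y)
(jacobian,) = jax.vmap(vjp_fn)(basis)

# Cross-check against jax.jacrev, stacking its per-output blocks along axis 0.
reference = jnp.concatenate(jax.jacrev(f)(x), axis=0)
print(jacobian.shape, jnp.allclose(jacobian, reference))
```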
-
Hey, I've now repeatedly come across situations in which a jnp.eye function generalized to pytrees would be quite helpful. It surprised me, though, that the backend has a _std_basis function for pytrees that is hidden from the main API. Why is this the case?

I'd say that it would be helpful to add this to the main API, maybe as part of the jax.tree_util API?

An example use case I had was evaluating each intermediate Jacobian within a function composition; in other terms, the adjoint VJPs at each discrete time step t fed through the standard basis.

Don't pay too much attention to the function F itself, it's just an example. But this would be useful for neural networks in, e.g., layer-wise (block-wise) factorizations of the Hessian/Fisher.

This is one example where the std-basis is helpful. It can also be helpful for eliminating off-diagonal elements in any ArrayTree, or for defining complex matrix-vector products with pytrees.
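The off-diagonal elimination mentioned above might look roughly like the following sketch. It assumes a nested "matrix" basis (one block per pair of leaves, as returned by a std_basis built with structure.compose(structure)) and uses it as an elementwise mask on a Hessian pytree. The loss function and params are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp


def std_basis(pytree):
    # Nested identity: block (i, j) has shape shape_i + shape_j.
    leaves, structure = jax.tree_util.tree_flatten(pytree)
    eye_structure = structure.compose(structure)
    eye_leaves = []
    for i1, l1 in enumerate(leaves):
        for i2, l2 in enumerate(leaves):
            if i1 == i2:
                eye_leaves.append(jnp.eye(jnp.size(l1)).reshape(jnp.shape(l1) + jnp.shape(l2)))
            else:
                eye_leaves.append(jnp.zeros(jnp.shape(l1) + jnp.shape(l2)))
    return jax.tree_util.tree_unflatten(eye_structure, eye_leaves)


def loss(params):
    # Hypothetical scalar loss over a list of two arrays.
    w, b = params
    return jnp.sum(w ** 2) * b[0] + w[0] * w[1]


params = [jnp.array([1.0, 2.0]), jnp.array([3.0])]

# jax.hessian returns a list of lists here, with block (i, j) of shape shape_i + shape_j.
hessian = jax.hessian(loss)(params)

# Multiplying by the nested identity keeps only the diagonal entries.
mask = std_basis(params)
diagonal_only = jax.tree_util.tree_map(lambda h, m: h * m, hessian, mask)
print(diagonal_only)
```

The same mask could be replaced by a block-wise one (ones on the diagonal blocks instead of identities) if a block-diagonal rather than strictly diagonal approximation is wanted.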
What do you all think?