Hi, I wonder if there is a good way to know how XLA decides to allocate memory?
```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

# Example sizes and PRNG key (assumed values; the original snippet left these undefined).
n, x_dim, b_dim = 1024, 64, 32
key = jrandom.PRNGKey(0)

@jax.jit
def g(x, a, b):
    # Jacobian of dot(a, x) with respect to x, shape (n, x_dim)
    jac = jax.jacfwd(jnp.dot, argnums=1)(a, x)
    return jnp.einsum("ni,nj,jk->ik", jac, jac, b)

a = jrandom.normal(key, (n, x_dim))
x = jrandom.normal(key, (x_dim,))
b = jrandom.normal(key, (x_dim, b_dim))
res = g(x, a, b)
jax.block_until_ready(res)
```
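For reference, a minimal sketch (not part of the original snippet): `jax.make_jaxpr` prints the traced program that JAX hands to XLA, showing the ops and intermediate shapes (including `jac`), but not which intermediates get real buffers; that decision is made later by XLA during compilation.

```python
# Print the jaxpr for g: the ops and intermediate shapes JAX traces,
# before XLA's fusion and buffer-assignment passes run.
print(jax.make_jaxpr(g)(x, a, b))
```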
For example, with the above code, as far as I can tell, XLA has at least the following choices:
- allocate a big chunk of memory of shape `(n, x_dim, x_dim, b_dim)` or `(n, x_dim, x_dim)` for the einsum;
- allocate a chunk of memory of shape `(n, x_dim)` for the variable `jac`;
- allocate only the output's shape `(x_dim, b_dim)` and perform a single reduction (rough sizes for each choice are sketched below).
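To make the difference concrete, here is a rough size calculation for those candidate buffers, assuming float32 and the example dims defined above (the specific numbers are illustrative, not from the original post):

```python
# Rough float32 buffer sizes for the three candidate strategies,
# using the example dims n=1024, x_dim=64, b_dim=32 from above.
bytes_per_elem = 4
print(n * x_dim * x_dim * b_dim * bytes_per_elem)  # option 1: ~512 MiB
print(n * x_dim * bytes_per_elem)                  # option 2 (jac): 256 KiB
print(x_dim * b_dim * bytes_per_elem)              # option 3 (output only): 8 KiB
```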
My best guess is that JAX+XLA is doing the second option, but is there a way to know for sure?
I tried `jax.profiler`, but it ignores allocations made inside `jit`.
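One way to check, sketched below with `jax.jit`'s ahead-of-time lowering API (`lower`/`compile`/`as_text` are the standard AOT entry points; whether `memory_analysis()` returns useful stats depends on the backend and JAX version):

```python
# Inspect what XLA actually compiled instead of guessing.
lowered = g.lower(x, a, b)    # StableHLO, before XLA optimization
compiled = lowered.compile()  # HLO after fusion and buffer assignment

# The optimized HLO text shows which intermediates survive fusion
# and therefore get their own buffers.
print(compiled.as_text())

# Aggregate buffer statistics (may be None on some backends/versions).
stats = compiled.memory_analysis()
if stats is not None:
    print("temp bytes:", stats.temp_size_in_bytes)
```

Setting `XLA_FLAGS=--xla_dump_to=/tmp/xla_dump` before running also dumps the HLO modules (including the post-optimization ones) to disk for inspection.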