
abstracted_axes and eval_jaxpr #18567

Answered by mattjj
erick-xanadu asked this question in Ideas

Thanks for the question!

> This means that the jaxpr computation will have an implicit parameter corresponding to the value of the axis 0.

Actually, the parameter isn't implicit in the jaxpr. It appears as an ordinary input variable (`a:i32[]` in the printed jaxpr), so the caller passes the axis size as an ordinary argument:

```python
import jax
import jax.numpy as jnp
from jax._src.core import eval_jaxpr  # internal module, not public API

jax.config.update('jax_dynamic_shapes', True)

def f(x):
  return jnp.sin(x) + jnp.cos(x)

jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(jnp.arange(3.)).jaxpr
print(jaxpr)
# { lambda ; a:i32[] b:f32[a]. let
#     c:f32[a] = sin b
#     d:f32[a] = cos b
#     e:f32[a] = add c d
#   in (e,) }

# No consts ([]), then the size `a` (3), then the array `b`.
ans, = eval_jaxpr(jaxpr, [], 3, jnp.arange(3.))
```
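Because the size is just the jaxpr's first input, the same traced jaxpr can be reused for inputs of different lengths. The sketch below assumes the experimental `jax_dynamic_shapes` flag is still available in your JAX version and uses the internal `jax._src.core.eval_jaxpr`, which may change between releases. The helper `run` is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.core import eval_jaxpr  # internal module; may change between releases

# Experimental flag (assumption: still supported by the installed JAX).
jax.config.update('jax_dynamic_shapes', True)

def f(x):
  return jnp.sin(x) + jnp.cos(x)

# Trace once, abstracting axis 0 of the argument as the size variable 'n'.
jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(jnp.arange(3.)).jaxpr

# The size variable is an ordinary input: a scalar invar ahead of the array.
size_var, data_var = jaxpr.invars
print(size_var.aval, data_var.aval)

def run(x):
  """Hypothetical helper: evaluate the traced jaxpr on x, passing its length explicitly."""
  n = x.shape[0]
  out, = eval_jaxpr(jaxpr, [], n, x)  # consts=[], then size, then data
  return out

# Reuse the single trace for two different lengths and compare against f.
for n in (3, 5):
  x = jnp.arange(float(n))
  np.testing.assert_allclose(run(x), f(x), rtol=1e-6)
```

Tracing happens once; each call only supplies a new size and a matching array, which is what makes the size parameter explicit rather than implicit.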

Replies: 1 comment 7 replies

@soraros

@erick-xanadu

@mattjj

@soraros

@mattjj

Answer selected by erick-xanadu
Category: Ideas (3 participants)