
Is it possible to accelerate autograd by formula simplification? #11915

Answered by jakevdp
GiantElephant123 asked this question in Ideas

jax.jit already does this kind of acceleration automatically: the jitted function is compiled by XLA, which simplifies the resulting computation for you. I think your timings are probably being thrown off by JAX's asynchronous dispatch; for information about how to accurately assess microbenchmarks of JAX code, see https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code
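Here is a minimal sketch of the pitfall the FAQ describes. It assumes a plain Python script rather than Colab's `%timeit`. The first loop times calls without waiting for the results. The second loop waits on each result with `block_until_ready()`. The loop count `n` and the variable names are illustrative choices. The numbers depend on your hardware, and for a small scalar input the two timings may differ only slightly.

```python
import time

import jax
import jax.numpy as jnp


def f(x):
    return jnp.exp(-x**2)


fprime = jax.jit(jax.grad(jax.grad(f)))
x = jnp.array(1.)
n = 1000  # illustrative loop count

# Warm-up call: triggers tracing and XLA compilation so they are not timed.
fprime(x).block_until_ready()

# Pitfall: JAX dispatches work asynchronously, so without blocking this loop
# can finish before the computations do and under-report the cost per call.
start = time.perf_counter()
for _ in range(n):
    y = fprime(x)
naive = (time.perf_counter() - start) / n
y.block_until_ready()  # drain pending work before the next measurement

# Recommended: block on each result so the timing includes execution.
start = time.perf_counter()
for _ in range(n):
    y = fprime(x).block_until_ready()
blocked = (time.perf_counter() - start) / n

print(f"without block_until_ready: {naive * 1e6:.1f} µs/call")
print(f"with block_until_ready:    {blocked * 1e6:.1f} µs/call")
```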

When I benchmark these functions using the approaches recommended in that FAQ entry, I find that the automatic and manual versions of the function have comparable runtimes of ~4µs on a Colab CPU:

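```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(-x**2)

# Second derivative of f, built by automatic differentiation and compiled with jit
fprime = jax.jit(jax.grad(jax.grad(f)))
x = jnp.array(1.)

# Time the ahead-of-time lowering + compilation step (IPython/Colab magic)
%timeit fprime.lower(x).compile()
# 3…
```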
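For the "manual" side of the comparison, differentiating by hand gives f'(x) = -2x·exp(-x²) and f''(x) = (4x² - 2)·exp(-x²). The sketch below checks that this closed form agrees with the nested `jax.grad` version. The name `fprime_manual` is a hypothetical helper, not taken from the thread. The sketch then prints the optimized HLO that XLA produces for the automatic version via `Compiled.as_text()`. This is one way to see what the compiler actually runs, as opposed to the raw autodiff trace. Which simplifications appear there depends on the JAX/XLA version and the backend.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.exp(-x**2)


# Automatic: nested jax.grad gives the second derivative; jax.jit compiles it.
fprime_auto = jax.jit(jax.grad(jax.grad(f)))


# Manual (hypothetical helper): hand-simplified f''(x) = (4x^2 - 2) * exp(-x^2).
@jax.jit
def fprime_manual(x):
    return (4 * x**2 - 2) * jnp.exp(-x**2)


# jax.grad requires f to return a scalar, so evaluate the automatic version
# element by element with vmap; the manual version is already elementwise.
xs = jnp.linspace(-3.0, 3.0, 7)
auto_vals = jax.vmap(fprime_auto)(xs)
manual_vals = fprime_manual(xs)
max_err = float(jnp.max(jnp.abs(auto_vals - manual_vals)))
print("max abs difference:", max_err)

# Optimized HLO for the automatic version, after XLA's compiler passes.
hlo = fprime_auto.lower(jnp.array(1.)).compile().as_text()
print(hlo)
```

If the two versions agree and time the same under `block_until_ready()`, that suggests hand-simplifying the derivative formula buys little once XLA has compiled the function.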

Answer selected by GiantElephant123