Is there a way to do something like

    with jax.api.disable_jit():
        func(x, y)

but scoped to omnistaging instead? For example:

    with jax.api.disable_omnistage():
        jit(func)(x, y)

My use case is a function that depends on a mask created from non-static arguments. Jitting it works very well without omnistaging, but the rest of the code still needs omnistaging enabled. :)
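For reference, here is a minimal sketch of the existing `disable_jit` pattern the question starts from, using the top-level `jax.disable_jit` context manager. The function `masked_sum` and the array values are hypothetical, chosen only to show a value-dependent mask; `jax.api.disable_omnistage` above is the asker's proposed name, not an existing API.

```python
import jax
import jax.numpy as jnp

def masked_sum(x, y):
    # The mask depends on the *values* of y, so x[mask] has a
    # data-dependent shape (hypothetical example function).
    mask = y > 0
    return x[mask].sum()

x = jnp.arange(4.0)                      # [0., 1., 2., 3.]
y = jnp.array([1.0, -1.0, 2.0, -2.0])    # mask -> [True, False, True, False]

with jax.disable_jit():
    # Inside this context, jax.jit(...) calls run op-by-op on concrete
    # arrays, so the value-dependent boolean index is allowed.
    print(jax.jit(masked_sum)(x, y))  # 0.0 + 2.0 = 2.0
```

Outside the context manager, the same `jax.jit(masked_sum)(x, y)` call fails, because `y > 0` is a tracer at trace time and the size of `x[mask]` cannot be known.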
There is such a function, but it's an internal API function, and we might break it in the future. I don't want to cause you more headaches!
As you may have seen in the omnistaging upgrade guide, oftentimes if you `import numpy as np` and `import jax.numpy as jnp`, for code you don't want to stage out it can be a better idea to use `np` rather than `jnp`. That is, you can control what gets staged out to XLA simply by only using `jnp` for things that you might want to stage out.

Would that work? Otherwise, can you provide an example?
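To illustrate the `np` vs `jnp` suggestion, here is a minimal sketch in which the same mask is built once with `np` and once with `jnp` inside a jitted function. The function names and the constant `N` are hypothetical. The NumPy mask is an ordinary concrete array at trace time, so the data-dependent boolean index can be resolved; the `jax.numpy` mask is staged out as a tracer, so JAX rejects the boolean index. This assumes the mask can be computed from Python or NumPy values rather than from traced arguments; calling `np` functions on a traced argument raises an error instead.

```python
import numpy as np
import jax
import jax.numpy as jnp

N = 5  # hypothetical array length, a plain Python constant

@jax.jit
def sum_tail(x):
    # Built with NumPy from a Python constant: the mask stays a concrete
    # array during tracing and is never staged out to XLA.
    mask = np.arange(N) >= 2
    # Boolean indexing needs a concrete mask, because the output shape
    # depends on how many entries are True.
    return x[mask].sum()

@jax.jit
def sum_tail_staged(x):
    # Same mask built with jax.numpy: under omnistaging this is a tracer,
    # so the boolean index below cannot be resolved at trace time.
    mask = jnp.arange(N) >= 2
    return x[mask].sum()

x = jnp.arange(N, dtype=jnp.float32)  # [0., 1., 2., 3., 4.]
print(sum_tail(x))  # 2 + 3 + 4 = 9.0

try:
    sum_tail_staged(x)
except Exception as e:  # recent JAX versions raise jax.errors.NonConcreteBooleanIndexError
    print(type(e).__name__)
```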