Hello. I was reading up on Google JAX because I have never used it before. A lot of the gotchas documentation says you cannot use impure functions with JAX. Does this restriction apply only when you JIT-compile a function? In other words, can you use impure functions in your program as long as you are not JIT-compiling those functions? For example, can I run two impure functions and then run two JIT-compiled functions afterward, all in the same program?
Hi, thanks for the question! You can use impure functions in JAX as long as you don't transform them (i.e. with `jit`, `vmap`, `grad`, `pmap`, etc.) or use them inside control flow operations (`fori_loop`, `while_loop`, `scan`, etc.). A typical pattern is to use impure functions as a set-up step (e.g. loading data from disk) and then keep the core parts of your algorithm pure.
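Below is a minimal sketch of that set-up-then-pure-core pattern, assuming a recent JAX install. `load_data`, `center_in_place`, and the tiny linear model are hypothetical stand-ins chosen for illustration, not code from this thread. The second half shows why the restriction exists: `jax.jit` traces a function once and reuses that trace, so a global read is frozen at trace time and a side effect runs only while tracing, not on every call.

```python
import jax
import jax.numpy as jnp
import numpy as np

# ---- Impure set-up: plain Python, never transformed by JAX ----
call_log = []  # global state mutated by the set-up step

def load_data(n):
    # Hypothetical stand-in for reading a file from disk.
    call_log.append("load_data")  # side effect: mutates a global list
    rng = np.random.default_rng(0)  # stateful NumPy RNG is fine here
    return rng.normal(size=(n, 3)).astype(np.float32)

def center_in_place(x):
    x -= x.mean(axis=0)  # side effect: mutates its NumPy argument
    return x

# ---- Pure core: depends only on its inputs, so it is safe to jit ----
@jax.jit
def predict(w, x):
    return x @ w

@jax.jit
def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

x = center_in_place(load_data(8))
w = jnp.ones(3)
y = jnp.zeros(8)
loss_value = loss(w, x, y)
print(loss_value)

# ---- What goes wrong if an impure function IS jitted ----
scale = 2.0

@jax.jit
def scaled(v):
    return v * scale  # global read is baked in when the function is traced

first = scaled(1.0)
scale = 10.0
second = scaled(1.0)  # cached trace is reused, so this still uses scale == 2.0
print(first, second)

trace_count = 0

@jax.jit
def add_one(v):
    global trace_count
    trace_count += 1  # runs while tracing, not on every call
    return v + 1

add_one(1.0)
add_one(2.0)  # same shape and dtype as before, so no retrace happens
print(trace_count)
```

The impure functions run first as ordinary Python, and the jitted functions afterward only see the arrays they produce, which is exactly the "two impure, then two jitted" ordering asked about above.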