Problem in jitting methods of a class #23975
-
Hi, I'm trying to JIT a method of a class that updates that class's properties.
This class is almost okay, but I noticed that the method
I would like to avoid generating a new class instance every time I update the class parameters, because that doesn't seem efficient to me. Removing the … Do you have any suggestions that could help me?
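A minimal sketch of the behaviour described above, assuming a class registered as a pytree with a jitted method that assigns to `self`. The original class isn't shown, so `Model` and `update_in_place` are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Model:
    """Hypothetical stand-in for the class in the question."""

    def __init__(self, w):
        self.w = w

    def tree_flatten(self):
        # The array attribute becomes a pytree leaf; no static aux data.
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def update_in_place(self, step):
        # Inside jit, `self` is a fresh instance rebuilt from traced leaves,
        # so this assignment only touches that temporary copy.
        self.w = self.w - step


model = Model(jnp.array([1.0, 2.0]))
model.update_in_place(0.5)
print(model.w)  # [1. 2.]  (the caller's instance was not modified)
```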
-
Thanks for the question – we have a FAQ entry on this topic here, and it looks like you're already using the pytree approach, which is great.
This is working as expected: JIT-compiled functions must be pure and cannot operate via side effects. By design, executing the function will not mutate its inputs, because that would be a side effect (see the JAX Sharp Bits section on pure functions for more discussion).
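To make the purity point concrete, here is a small sketch using only core JAX. The function names `set_first` and `double` are hypothetical. The first part shows that a jitted function returns new arrays instead of modifying its input; the second shows that a Python side effect inside `jax.jit` runs only at trace time, not on every call:

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_first(x):
    # JAX arrays are immutable: `.at[0].set(...)` builds a new array
    # rather than writing into the caller's `x`.
    return x.at[0].set(10.0)


x = jnp.zeros(3)
y = set_first(x)
print(x)  # [0. 0. 0.]  (the input is unchanged)
print(y)  # a new array whose first element is 10

calls = []


@jax.jit
def double(v):
    # Python side effects like this run only while JAX traces the function,
    # not on every call, so they can't be relied on to update outside state.
    calls.append("traced")
    return v * 2


double(jnp.ones(2))
double(jnp.ones(2))  # same shape/dtype: reuses the cached trace
print(len(calls))  # 1
```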
Your options are either to drop JIT compilation and use in-place mutation, or to keep JIT compilation and return the modified class instance. You can't have both JIT and mutation. That said, if you're calling …
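A sketch of both options, assuming the same kind of pytree-registered class as in the question. `Model`, `update_in_place`, `updated`, and the helper `_shifted` are hypothetical names. Option 1 keeps the method un-jitted so it can mutate `self`, while the numeric work runs in a jitted pure helper; option 2 jits the method and returns a new instance that the caller rebinds:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Model:
    """Hypothetical stand-in for the class in the question."""

    def __init__(self, w):
        self.w = w

    def tree_flatten(self):
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    # Option 1: no jit on the method, so assigning to `self` really mutates
    # the caller's object; the array math is delegated to a jitted helper.
    def update_in_place(self, step):
        self.w = _shifted(self.w, step)

    # Option 2: jit the method and return a *new* instance instead.
    @jax.jit
    def updated(self, step):
        return Model(self.w - step)


@jax.jit
def _shifted(w, step):
    # Pure helper: arrays in, new array out.
    return w - step


m = Model(jnp.array([1.0, 2.0]))
m.update_in_place(0.5)  # option 1: m.w is now [0.5, 1.5]
m = m.updated(0.5)      # option 2: rebind the name to the returned instance
print(m.w)  # [0. 1.]
```

Rebinding `m = m.updated(...)` is the usual functional pattern. The new instance is only a thin Python wrapper around the output arrays, so creating it is typically cheap relative to the computation itself. If reusing the input buffers matters, `jax.jit` also accepts `donate_argnums`, which lets XLA reuse donated inputs on backends that support donation.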