How do I do this? There's this section in Autodidax which roughly covers what I was thinking about. The issue is (unlike I have primitives that are active for one interpreter, but if I perform another transformation first - that other transformation might call It seems like this removes the possibility of my outer transformation processing that primitive. Edit: is
Replies: 1 comment
I figured out how to do this, but I had to rewrite a few of my transformations to always use my own interpreter pattern (which always stages out to a `Jaxpr` before walking, so each transformation which uses this interpreter gets to see every primitive). I'm not sure there's a general way to do it without doing this because (as mentioned in Autodidax) JAX has dataflow assumptions as an optimization, so I'm sort of breaking those assumptions and had to go lower level and make my own interpreter stack. This is totally okay for my use cases, however.
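
For anyone landing here later, this is roughly what "stage out to a `Jaxpr`, then walk it" can look like. It's only a minimal sketch, not my actual code: `walk_jaxpr`, `handlers`, and `seen` are hypothetical names made up for illustration, and it assumes the function only contains simple first-order primitives (no nested control flow or call primitives, which would need extra handling).

```python
import jax
import jax.numpy as jnp
from jax import lax


def walk_jaxpr(fn, handlers=None, seen=None):
    """Hypothetical helper: stage `fn` out to a Jaxpr, then interpret it.

    `handlers` maps a primitive name to a custom rule; any primitive without
    a handler falls back to its normal evaluation via `bind`.
    `seen` (optional list) records every primitive name the walker visits.
    """
    handlers = handlers or {}

    def wrapped(*args):
        closed = jax.make_jaxpr(fn)(*args)  # stage out first
        jaxpr, consts = closed.jaxpr, closed.consts
        env = {}

        def read(var):
            # Literals carry their value in `.val`; ordinary variables do not.
            return var.val if hasattr(var, "val") else env[var]

        for var, val in zip(jaxpr.constvars, consts):
            env[var] = val
        for var, val in zip(jaxpr.invars, args):
            env[var] = val

        for eqn in jaxpr.eqns:  # every primitive is visible here
            name = eqn.primitive.name
            if seen is not None:
                seen.append(name)
            invals = [read(v) for v in eqn.invars]
            if name in handlers:
                outvals = handlers[name](*invals, **eqn.params)
            else:
                outvals = eqn.primitive.bind(*invals, **eqn.params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            for var, val in zip(eqn.outvars, outvals):
                env[var] = val

        return [read(v) for v in jaxpr.outvars]

    return wrapped


def f(x):
    return lax.mul(lax.sin(x), 2.0)


seen = []
out = walk_jaxpr(f, seen=seen)(jnp.float32(1.0))[0]

# Swap in a custom rule for `sin` to show the walker intercepting a primitive.
out_custom = walk_jaxpr(f, handlers={"sin": lambda x, **params: jnp.cos(x)})(
    jnp.float32(1.0)
)[0]
```

Because the function is always staged out before the walk, the interpreter sees each equation in the `Jaxpr` regardless of which other transformation ran first, which is the property I needed. The cost is that you are working outside JAX's usual tracing stack, so composing with built-in transformations is on you.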