Conversion from a JAX array to a Numpy array is slower after calling a JAX function #23996

Answered by jakevdp
jiayingqi asked this question in Q&A

I'm not able to reproduce this with JAX v0.4.33 on either a Colab CPU or GPU runtime. That said, I have a guess as to why this may be happening for you: I suspect your timings are being skewed by JAX's asynchronous dispatch. When you call a JAX operation, the Python function returns before the result is actually computed. So your first call to compute essentially just queues up the computations, which begin running in the background. By the time you get to your second call, the queue is full, so the Python function must wait for the previous iterations to finish before it can enqueue its own computations.

If I'm right, then adding this to your first line should make all runs take the same …
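
To illustrate the asynchronous dispatch described above, here is a minimal sketch. It is not the snippet elided from the answer; the function `work` and the array size are hypothetical, chosen only for illustration. It times the dispatch separately from the wait for the result by calling `block_until_ready()` before converting to NumPy, so that pending compute time is not attributed to the conversion.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def work(x):
    # Hypothetical workload standing in for the asker's function.
    return jnp.tanh(x @ x)


x = jnp.ones((500, 500))

# Warm up once so compilation time is excluded from the timings below.
work(x).block_until_ready()

# The call may return as soon as the work has been dispatched.
t0 = time.perf_counter()
y = work(x)
dispatch_s = time.perf_counter() - t0

# Wait explicitly for the computation to finish.
t0 = time.perf_counter()
y.block_until_ready()
wait_s = time.perf_counter() - t0

# With the result already computed, this measures only the conversion.
t0 = time.perf_counter()
arr = np.asarray(y)
convert_s = time.perf_counter() - t0

print(f"dispatch: {dispatch_s:.6f}s  wait: {wait_s:.6f}s  convert: {convert_s:.6f}s")
```

Without the explicit `block_until_ready()` call, `np.asarray(y)` would itself have to wait for the result, so the conversion would appear slower than it really is.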

Answer selected by jiayingqi