
Why Jacfwd matrix is full of NaN after jax.scipy.signal.fftconvolve #24010

Answered by dfm
jecampagne asked this question in Q&A

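Yes. The problem here is that, for the specific value of cxy you're passing, the argument to the sqrt in dist_from_center, (X - center[0])**2 + (Y - center[1])**2, is exactly zero at one of the points. The derivative of sqrt(x) is infinite at x = 0, and the chain rule multiplies that infinity by the derivative of the squared distance, which is also zero at that point, so the result is NaN (inf * 0). Because fftconvolve mixes every input pixel into every output pixel, that single NaN then spreads across the whole Jacobian.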

You can confirm this by replacing the tst definition with (for example):

tst = jacfwd(f)(cxy + 1e-5, psf)

which doesn't have any NaNs in it.
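A minimal reproduction, assuming a hypothetical stand-in for dist_from_center evaluated on a 3x3 grid (the grid, the center value, and the function body are illustrative assumptions, not the asker's code). With the center placed exactly on a grid point, jacfwd returns NaN only at that pixel; shifting the center by 1e-5, as in the check above, removes it.

```python
import jax
import jax.numpy as jnp


def dist_from_center(center, X, Y):
    # Hypothetical stand-in for the asker's dist_from_center:
    # Euclidean distance from `center` to every grid point.
    return jnp.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)


# 3x3 grid of pixel coordinates; (1.0, 1.0) lands exactly on a grid point.
X, Y = jnp.meshgrid(jnp.arange(3.0), jnp.arange(3.0))
center = jnp.array([1.0, 1.0])

# Jacobian with respect to `center` only (argnums=0): shape (3, 3, 2).
jac = jax.jacfwd(dist_from_center)(center, X, Y)
# Boolean mask of pixels whose Jacobian row contains NaN.
print(jnp.isnan(jac).any(axis=-1))

# Nudging the center off the grid point avoids sqrt'(0) entirely.
jac_shifted = jax.jacfwd(dist_from_center)(center + 1e-5, X, Y)
print(jnp.isnan(jac_shifted).any())
```

Only the pixel at index [1, 1], where the distance is exactly zero, is flagged in the first mask.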

The way you should handle this depends on your specific needs, but you can often wrap the offending computation (the sqrt in your case) in a jnp.where. Make sure you also read the relevant entry in the JAX FAQ ("Gradients contain NaN where using where"), which explains why the input to the sqrt needs to be guarded as well.
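A minimal sketch of the "double where" pattern from that FAQ entry, applied to the same hypothetical stand-in for dist_from_center (the function name safe_dist_from_center and the placeholder value 1.0 are illustrative choices). The inner where keeps sqrt away from zero so no inf or NaN is produced in either branch; the outer where restores the correct distance. The inner guard is what keeps reverse mode (grad, jacrev) NaN-free as well.

```python
import jax
import jax.numpy as jnp


def safe_dist_from_center(center, X, Y):
    # Hypothetical "double where" version of dist_from_center.
    r2 = (X - center[0]) ** 2 + (Y - center[1]) ** 2
    at_center = r2 == 0.0
    # Inner where: hand sqrt a harmless value (1.0) wherever r2 == 0, so sqrt
    # and its derivative are never evaluated at zero.
    safe_r2 = jnp.where(at_center, 1.0, r2)
    # Outer where: restore the true distance (0.0) at those points.
    return jnp.where(at_center, 0.0, jnp.sqrt(safe_r2))


X, Y = jnp.meshgrid(jnp.arange(3.0), jnp.arange(3.0))
center = jnp.array([1.0, 1.0])

# Both forward- and reverse-mode Jacobians, shape (3, 3, 2) each.
jac_fwd = jax.jacfwd(safe_dist_from_center)(center, X, Y)
jac_rev = jax.jacrev(safe_dist_from_center)(center, X, Y)
print(jnp.isnan(jac_fwd).any(), jnp.isnan(jac_rev).any())
```

At the center pixel this assigns a derivative of 0. That is a convention rather than the "true" value, since the distance has no derivative there, so whether 0 is appropriate depends on how the Jacobian is used downstream.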

Hope this helps!

dfm Oct 2, 2024
Collaborator

Answer selected by jecampagne