Why is the jacfwd matrix full of NaN after jax.scipy.signal.fftconvolve? #24010
Answered by dfm · jecampagne asked this question in Q&A
Hello,

```python
import jax
import jax.numpy as jnp
import jax.scipy.signal
import numpy as np
from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian
jax.config.update("jax_enable_x64", True)
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['font.size'] = 18
print(jax.__version__)  # 0.4.33 (September 16, 2024)
print(jax.devices())    # CPU

def low_pass(r, X, n=30):
    # Butterworth-like fall-off: ~1 for X << r, ~0 for X >> r
    u = X/r
    return 1./jnp.sqrt(1 + u**(2.*n))

def create_disk(h, w, radius, center):
    # Smooth disk of the given radius, centered at center = (x, y)
    Y, X = jnp.ogrid[:h, :w]
    dist_from_center = jnp.sqrt((X - center[0])**2 + (Y-center[1])**2)
    smooth_mask = low_pass(radius, dist_from_center)
    return smooth_mask
```

Then

```python
img = create_disk(128,128,radius=32,center=(64,64))
psf = create_disk(32,32,radius=8,center=(16,16))
psf /= jnp.sum(psf)
res = jax.scipy.signal.fftconvolve(img,psf)
fig, axs = plt.subplots(1,3,figsize=(15,2))
a0=axs[0].imshow(img,cmap="gray");plt.colorbar(a0,ax=axs[0])
a1=axs[1].imshow(psf,cmap="gray");plt.colorbar(a1,ax=axs[1])
a2=axs[2].imshow(res,cmap="gray");plt.colorbar(a2,ax=axs[2]);
```

Now, let us try the following to investigate the Jacobian:

```python
def f(cxy,psf):
    img = create_disk(128,128,radius=32.,center=cxy)
    return jax.scipy.signal.fftconvolve(img,psf)

# A shift in center position: OK
psf = create_disk(32,32,radius=8.,center=(16,16))
psf /= jnp.sum(psf)
cxy=jnp.array([50.,64.])
tmp = f(cxy,psf)
plt.imshow(tmp,cmap="gray");plt.colorbar();
```
But for the Jacobian `tst` (computed with `jacfwd`), `jnp.all(jnp.isnan(tst))` returns `True`, that is to say all elements of the Jacobian are NaN.
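A minimal, self-contained reproducer of the symptom is sketched below, with the plotting left out. The post does not show how `tst` was defined, so this sketch assumes it is the forward-mode Jacobian of `f` with respect to `cxy`, i.e. `jacfwd(f)(cxy, psf)` (the title mentions `jacfwd`).

```python
import jax
import jax.numpy as jnp
from jax import jacfwd
from jax.scipy.signal import fftconvolve

jax.config.update("jax_enable_x64", True)

def low_pass(r, X, n=30):
    u = X / r
    return 1. / jnp.sqrt(1 + u**(2. * n))

def create_disk(h, w, radius, center):
    Y, X = jnp.ogrid[:h, :w]
    dist_from_center = jnp.sqrt((X - center[0])**2 + (Y - center[1])**2)
    return low_pass(radius, dist_from_center)

def f(cxy, psf):
    img = create_disk(128, 128, radius=32., center=cxy)
    return fftconvolve(img, psf)  # mode="full": output is (128+32-1, 128+32-1)

psf = create_disk(32, 32, radius=8., center=(16, 16))
psf /= jnp.sum(psf)
cxy = jnp.array([50., 64.])

# Assumption: the post does not show the tst line; we take jacfwd w.r.t. cxy.
tst = jacfwd(f)(cxy, psf)
print(tst.shape)                # (159, 159, 2)
print(jnp.all(jnp.isnan(tst)))  # True
```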
Answered by dfm, Oct 2, 2024
Does anyone have an idea what causes the failure?
Oh, I see. Thanks!

```python
def create_disk(h, w, radius, center):
    Y, X = jnp.ogrid[:h, :w]
    # Squared distance: no sqrt, so no infinite derivative at the center pixel
    dist2_from_center = (X - center[0])**2 + (Y-center[1])**2
    smooth_mask = low_pass(radius**2, dist2_from_center)
    return smooth_mask
```

Here is the Jacobian for `cxy=(64,64)`.
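To check the squared-distance version without the plot, the sketch below recomputes the Jacobian at `cxy=(64,64)` and confirms every entry is finite. It keeps the post's `low_pass` and `f` and swaps in only the new `create_disk`; the variable name `jac` is illustrative.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd
from jax.scipy.signal import fftconvolve

# float64 matters here: u = dist2/radius**2 reaches 8 at the image corners,
# and 8**60 would overflow float32.
jax.config.update("jax_enable_x64", True)

def low_pass(r, X, n=30):
    u = X / r
    return 1. / jnp.sqrt(1 + u**(2. * n))

def create_disk(h, w, radius, center):
    # Squared distance instead of distance: only polynomials, no sqrt at 0
    Y, X = jnp.ogrid[:h, :w]
    dist2_from_center = (X - center[0])**2 + (Y - center[1])**2
    return low_pass(radius**2, dist2_from_center)

def f(cxy, psf):
    img = create_disk(128, 128, radius=32., center=cxy)
    return fftconvolve(img, psf)

psf = create_disk(32, 32, radius=8., center=(16, 16))
psf /= jnp.sum(psf)

jac = jacfwd(f)(jnp.array([64., 64.]), psf)
print(bool(jnp.all(jnp.isfinite(jac))))  # True
```

One side effect worth knowing: `u` is now `(d/r)**2`, so `u**(2.*n)` equals `(d/r)**(4*n)` and the disk edge is sharper than in the original. Passing `n=15` to `low_pass` in this version reproduces the original profile obtained with `n=30`.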
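Answer from dfm:

Yes. The problem here is that for the specific value of `cxy` that you're passing, the argument to the `sqrt` in `dist_from_center` (`(X - center[0])**2 + (Y - center[1])**2`) is exactly zero for one of the points: both coordinates of `cxy = [50., 64.]` are whole numbers, so the center falls exactly on the pixel in row 64, column 50. The derivative of `sqrt(x)` at `x = 0` is infinite, and the chain rule multiplies it by the derivative of the squared distance, which is also zero at that pixel, so the derivative there evaluates to 0 × ∞ = NaN. Because `fftconvolve` mixes every input pixel into every output pixel, that single NaN then spreads to every element of the Jacobian.

You can confirm this by replacing the `tst` definition with (for example): … which doesn't have any NaNs in it.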
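The explanation above can be checked step by step. This sketch reuses the post's sqrt-based `low_pass` and `create_disk` and shows three things: the scalar 0 × ∞ mechanism, that the Jacobian of the image alone is NaN only at the one zero-distance pixel, and that `fftconvolve` spreads that NaN to every output. Since convolution is linear in the image, `spread` below is exactly the `cxy[0]` column of the full Jacobian; the names `jac_img` and `spread` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd
from jax.scipy.signal import fftconvolve

jax.config.update("jax_enable_x64", True)

# 1) Scalar mechanism: sqrt'(0) is infinite, and the chain rule multiplies it
#    by the (zero) derivative of x**2 at 0, giving 0 * inf = NaN.
print(jax.grad(jnp.sqrt)(0.0))                  # inf
print(jax.grad(lambda x: jnp.sqrt(x**2))(0.0))  # nan

# 2) The post's sqrt-based disk: the Jacobian of the image alone is NaN
#    only at the pixel where the distance is exactly zero.
def low_pass(r, X, n=30):
    u = X / r
    return 1. / jnp.sqrt(1 + u**(2. * n))

def create_disk(h, w, radius, center):
    Y, X = jnp.ogrid[:h, :w]
    dist_from_center = jnp.sqrt((X - center[0])**2 + (Y - center[1])**2)
    return low_pass(radius, dist_from_center)

cxy = jnp.array([50., 64.])
jac_img = jacfwd(lambda c: create_disk(128, 128, radius=32., center=c))(cxy)
# Indices of NaN entries: pixel (row 64, col 50), for both components of cxy
print(jnp.argwhere(jnp.isnan(jac_img)))

# 3) fftconvolve mixes every input pixel into every output pixel, so
#    convolving one Jacobian column (d img / d cxy[0]) with psf is all NaN.
psf = create_disk(32, 32, radius=8., center=(16, 16))
psf /= jnp.sum(psf)
spread = fftconvolve(jac_img[..., 0], psf)
print(bool(jnp.all(jnp.isnan(spread))))  # True
```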
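The way that you should handle this depends on your specific needs, but you can often wrap the offending computation (the `sqrt` in your case) in a `jnp.where`. Make sure that you also check the relevant entry in the JAX FAQ (about gradients containing NaN when using `where`): a single `jnp.where` around the result is not enough, because the NaN from the branch that is not selected still reaches the gradient. Hope this helps!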
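A sketch of the `jnp.where` approach for this case follows, using the commonly recommended "double `where`" pattern: the inner `where` feeds `sqrt` a harmless value (1.0) wherever the squared distance is zero, so neither its value nor its derivative is infinite, and the outer `where` then selects 0.0 there. The helper name `safe_dist` is hypothetical, not a JAX API.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd
from jax.scipy.signal import fftconvolve

jax.config.update("jax_enable_x64", True)

def low_pass(r, X, n=30):
    u = X / r
    return 1. / jnp.sqrt(1 + u**(2. * n))

def safe_dist(d2):
    # Hypothetical helper ("double where"):
    # inner where: sqrt never sees 0, so its derivative stays finite;
    # outer where: the returned distance is still exactly 0 there.
    is_zero = d2 == 0
    safe_d2 = jnp.where(is_zero, 1.0, d2)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_d2))

def create_disk(h, w, radius, center):
    Y, X = jnp.ogrid[:h, :w]
    dist_from_center = safe_dist((X - center[0])**2 + (Y - center[1])**2)
    return low_pass(radius, dist_from_center)

def f(cxy, psf):
    img = create_disk(128, 128, radius=32., center=cxy)
    return fftconvolve(img, psf)

psf = create_disk(32, 32, radius=8., center=(16, 16))
psf /= jnp.sum(psf)

# Same cxy as in the question, where the plain sqrt version gave all NaN
tst = jacfwd(f)(jnp.array([50., 64.]), psf)
print(bool(jnp.all(jnp.isfinite(tst))))  # True
```

At the zero-distance pixel this sets the derivative of the distance to 0. That is harmless here because `low_pass` is flat at zero distance; for other functions of the distance, check that the value chosen at zero is the one you want.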