How can I stop gradient when losses include nan #14217
-
When I run the code below, the resulting gradient contains nan at the masked entries:

```python
import jax
import jax.numpy as jnp
from jax import lax, random

key = random.PRNGKey(0)

def loss(x):
    nan_mask = random.uniform(key, x.shape) > 0.5
    x = x * 2.0
    x = x / (~nan_mask)   # division by zero where masked -> inf
    x = jnp.nan_to_num(x)
    x = jnp.where(nan_mask, 0, x)
    x = (1 - nan_mask) * x
    x = jnp.where(nan_mask, lax.stop_gradient(x), x)
    return x.mean()

print(jax.grad(loss)(jnp.ones(10,)))
```

What I expected is a gradient of zero at the masked entries. Is there any dynamic way to stop (or ignore) the gradient when losses include nan?
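The nan here does not come from the forward value (which the `where`/`nan_to_num` calls do mask) but from the backward pass: the division by zero still produces an `inf`, and the zero cotangent multiplied by `inf` is nan. A minimal sketch isolating that effect (the function name `f` is just illustrative, not from the discussion):

```python
import jax
import jax.numpy as jnp

def f(x):
    y = x / 0.0                      # forward value: inf
    return jnp.where(True, 0.0, y)   # forward result: 0.0, y is masked out

# The untaken branch still participates in the backward pass:
# the zero cotangent is multiplied by d(y)/dx = 1/0.0 = inf, and 0 * inf = nan.
print(jax.grad(f)(1.0))  # nan
```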
Replies: 1 comment
-
This looks related to the situation covered in the following FAQ entry: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

But if your goal is to simply have the specified entries in `x` not contribute to the gradient, you can do so by zeroing them out:

```python
def loss(x):
    nan_mask = random.uniform(key, x.shape) > 0.5
    x = x * 2.0
    x = jnp.where(nan_mask, 0, x)
    return x.mean()

print(jax.grad(loss)(jnp.ones(10,)))
# [0.2 0.  0.2 0.2 0.2 0.2 0.2 0.2 0.  0. ]
```

Is that the output you're hoping to see?
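For reference, the fix described in that FAQ entry is the "double where" trick: substitute a safe dummy value *before* the dangerous operation, so no inf/nan ever enters the backward pass. A minimal sketch under a different, illustrative loss (`safe_log_loss` and its values are not from the discussion):

```python
import jax
import jax.numpy as jnp

def safe_log_loss(x):
    mask = x > 0
    # First where: replace invalid inputs with a harmless dummy (1.0),
    # so jnp.log never sees them in the forward or backward pass.
    safe_x = jnp.where(mask, x, 1.0)
    # Second where: mask the output, as in the answer above.
    return jnp.where(mask, jnp.log(safe_x), 0.0).mean()

print(jax.grad(safe_log_loss)(jnp.array([-1.0, 1.0, 2.0])))
# [0.         0.33333334 0.16666667]
```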