-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixing the gradient of the L2 distance at the origin (#121)
* Providing a zero-safe custom JVP for L2-distance * lint * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adding news and simplifying test case * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
c1893ca
commit 2bdccea
Showing
4 changed files
with
49 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Fixed bug where the gradient of the L2 distance would return NaN when the | ||
distance was zero. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,11 +9,18 @@ | |
:ref:`tutorials` for a more complete introduction. | ||
""" | ||
|
||
__all__ = ["kernels", "noise", "solvers", "transforms", "GaussianProcess"] | ||
__all__ = [ | ||
"__version__", | ||
"kernels", | ||
"noise", | ||
"solvers", | ||
"transforms", | ||
"GaussianProcess", | ||
] | ||
|
||
from tinygp import kernels, noise, solvers, transforms | ||
from tinygp.gp import GaussianProcess | ||
from tinygp.tinygp_version import version as __version__ | ||
from tinygp.tinygp_version import __version__ | ||
|
||
__author__ = "Dan Foreman-Mackey" | ||
__email__ = "[email protected]" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# -*- coding: utf-8 -*- | ||
# mypy: ignore-errors | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from jax.test_util import check_grads | ||
|
||
from tinygp.kernels import distance | ||
|
||
|
||
def check(comp, expect, args, order=2, **kwargs): | ||
np.testing.assert_allclose(expect(*args), comp(*args)) | ||
np.testing.assert_allclose(jax.grad(comp)(*args), jax.grad(expect)(*args)) | ||
check_grads(comp, args, order=order, **kwargs) | ||
|
||
|
||
def test_l2_distance_grad_at_zero(): | ||
expect = lambda x1, x2: jnp.sqrt(jnp.sum(jnp.square(x1 - x2))) | ||
comp = distance.L2Distance().distance | ||
|
||
x1 = 0.0 | ||
x2 = 1.5 | ||
check(comp, expect, (x1, x2)) | ||
|
||
x1 = jnp.array([0.0, 0.1]) | ||
x2 = jnp.array([1.5, -0.2]) | ||
check(comp, expect, (x1, x2)) | ||
|
||
g = jax.grad(comp)(x1, x1) | ||
np.testing.assert_allclose(expect(x1, x1), comp(x1, x1)) | ||
assert np.all(np.isfinite(g)) |