Commit
Fixing the gradient of the L2 distance at the origin (#121)
* Providing a zero-safe custom JVP for L2-distance

* lint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adding news and simplifying test case

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
dfm and pre-commit-ci[bot] authored Oct 28, 2022
1 parent c1893ca commit 2bdccea
Showing 4 changed files with 49 additions and 4 deletions.
2 changes: 2 additions & 0 deletions news/121.bugfix
@@ -0,0 +1,2 @@
+Fixed bug where the gradient of the L2 distance would return NaN when the
+distance was zero.
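
The failure mode described in this news entry can be reproduced with a few lines of JAX outside of tinygp. This is a minimal sketch (not part of the commit) showing how the naive sqrt-of-squared-distance formulation yields a NaN gradient when the two inputs coincide:

import jax
import jax.numpy as jnp

# Naive L2 distance: d(x1, x2) = sqrt(sum((x1 - x2)**2)).
# At x1 == x2 the derivative of sqrt is infinite, and the chain rule
# multiplies it by a zero difference, so autodiff produces NaN.
naive_l2 = lambda x1, x2: jnp.sqrt(jnp.sum(jnp.square(x1 - x2)))

x = jnp.array([1.5, -0.2])
print(jax.grad(naive_l2)(x, x))  # -> [nan nan]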
11 changes: 9 additions & 2 deletions src/tinygp/__init__.py
@@ -9,11 +9,18 @@
 :ref:`tutorials` for a more complete introduction.
 """
 
-__all__ = ["kernels", "noise", "solvers", "transforms", "GaussianProcess"]
+__all__ = [
+    "__version__",
+    "kernels",
+    "noise",
+    "solvers",
+    "transforms",
+    "GaussianProcess",
+]
 
 from tinygp import kernels, noise, solvers, transforms
 from tinygp.gp import GaussianProcess
-from tinygp.tinygp_version import version as __version__
+from tinygp.tinygp_version import __version__
 
 __author__ = "Dan Foreman-Mackey"
 __email__ = "[email protected]"
8 changes: 6 additions & 2 deletions src/tinygp/kernels/distance.py
@@ -4,7 +4,7 @@
 with multivariate data. By default, all
 :class:`tinygp.kernels.stationary.Stationary` kernels will use either an
 :class:`L1Distance` or :class:`L2Distance`, when applied in multiple dimensions,
-but it is possible to define custom metrics, as dicussed in the :ref:`geometry`
+but it is possible to define custom metrics, as discussed in the :ref:`geometry`
 tutorial.
 """
 
@@ -51,7 +51,11 @@ class L2Distance(Distance):
"""The L2 or Euclidean distance between two coordinates"""

def distance(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
return jnp.sqrt(self.squared_distance(X1, X2))
r1 = L1Distance().distance(X1, X2)
r2 = self.squared_distance(X1, X2)
zeros = jnp.equal(r2, 0)
r2 = jnp.where(zeros, jnp.ones_like(r2), r2)
return jnp.where(zeros, r1, jnp.sqrt(r2))

def squared_distance(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
return jnp.sum(jnp.square(X1 - X2))
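
The change above applies the usual "double where" trick for making piecewise functions safe under autodiff: the zero entries of the squared distance are replaced by ones before the square root is taken, so the branch that would otherwise differentiate sqrt at zero never sees a zero, and the outer jnp.where returns the equal-valued L1 distance at those points, whose gradient is finite. Below is a standalone sketch of the same pattern, written without the tinygp classes; the helper name safe_l2 is ours, for illustration only:

import jax
import jax.numpy as jnp

def safe_l2(x1, x2):
    r1 = jnp.sum(jnp.abs(x1 - x2))     # L1 distance; equals 0 where x1 == x2
    r2 = jnp.sum(jnp.square(x1 - x2))  # squared L2 distance
    zeros = jnp.equal(r2, 0)
    # Mask the zeros *inside* the sqrt so the discarded branch still has a
    # finite derivative; the outer where selects the safe value.
    r2 = jnp.where(zeros, jnp.ones_like(r2), r2)
    return jnp.where(zeros, r1, jnp.sqrt(r2))

x = jnp.array([1.5, -0.2])
print(jax.grad(safe_l2)(x, x))  # finite gradient (all zeros), no NaN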
32 changes: 32 additions & 0 deletions tests/test_kernels/test_distance.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# mypy: ignore-errors
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.test_util import check_grads
+
+from tinygp.kernels import distance
+
+
+def check(comp, expect, args, order=2, **kwargs):
+    np.testing.assert_allclose(expect(*args), comp(*args))
+    np.testing.assert_allclose(jax.grad(comp)(*args), jax.grad(expect)(*args))
+    check_grads(comp, args, order=order, **kwargs)
+
+
+def test_l2_distance_grad_at_zero():
+    expect = lambda x1, x2: jnp.sqrt(jnp.sum(jnp.square(x1 - x2)))
+    comp = distance.L2Distance().distance
+
+    x1 = 0.0
+    x2 = 1.5
+    check(comp, expect, (x1, x2))
+
+    x1 = jnp.array([0.0, 0.1])
+    x2 = jnp.array([1.5, -0.2])
+    check(comp, expect, (x1, x2))
+
+    g = jax.grad(comp)(x1, x1)
+    np.testing.assert_allclose(expect(x1, x1), comp(x1, x1))
+    assert np.all(np.isfinite(g))
