Skip to content

Commit

Permalink
Merge pull request #8 from alanlujan91/gradients
Browse files Browse the repository at this point in the history
implement derivatives for regular grids
  • Loading branch information
alanlujan91 authored Oct 12, 2023
2 parents eb76432 + b30d40f commit c5fe27e
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,4 @@ cython_debug/

# Built Visual Studio Code Extensions
*.vsix
.vscode/settings.json
207 changes: 207 additions & 0 deletions examples/Multivariate Interpolation with Derivatives.ipynb

Large diffs are not rendered by default.

33 changes: 31 additions & 2 deletions src/multinterp/regular.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import cupy as cp
import numpy as np

from multinterp.backend._numba import numba_get_coordinates, numba_map_coordinates
from multinterp.backend._scipy import scipy_get_coordinates, scipy_map_coordinates
from multinterp.core import (
Expand All @@ -19,12 +22,17 @@ def get_methods():
"scipy": scipy_map_coordinates,
"numba": numba_map_coordinates,
}
get_grad = {
"scipy": np.gradient,
"numba": np.gradient,
}

try:
from multinterp.backend._cupy import cupy_get_coordinates, cupy_map_coordinates

get_coords["cupy"] = cupy_get_coordinates
map_coords["cupy"] = cupy_map_coordinates
get_grad["cupy"] = cp.gradient
except ImportError:
pass

Expand All @@ -36,10 +44,10 @@ def get_methods():
except ImportError:
pass

return get_coords, map_coords
return get_coords, map_coords, get_grad


GET_COORDS, MAP_COORDS = get_methods()
GET_COORDS, MAP_COORDS, GET_GRAD = get_methods()

AVAILABLE_BACKENDS, BACKEND_MODULES = import_backends()

Expand Down Expand Up @@ -67,6 +75,7 @@ def __init__(self, values, grids, backend="scipy", options=None):

super().__init__(values, grids, backend=backend)
self._parse_mc_options(options)
self._gradient = {}

def _parse_mc_options(self, options):
self.mc_kwargs = MC_KWARGS if self.backend != "jax" else JAX_MC_KWARGS
Expand Down Expand Up @@ -136,3 +145,23 @@ def _map_coordinates(self, coords):
"""

return MAP_COORDS[self.backend](self.values, coords, **self.mc_kwargs)

def diff(self, axis=None, edge_order=1):
# if axis is not an integer less than or equal to the number
# of dimensions of the input array, then a ValueError is raised.
if axis is None:
msg = "Must specify axis to differentiate along."
raise ValueError(msg)
if axis >= self.ndim:
msg = "Axis must be less than number of dimensions."
raise ValueError(msg)

grad = self._gradient.get(axis)
if grad is None:
grad = GET_GRAD[self.backend](
self.values, self.grids[axis], axis=axis, edge_order=edge_order
)

return MultivariateInterp(
grad, self.grids, backend=self.backend, options=self.mc_kwargs
)

0 comments on commit c5fe27e

Please sign in to comment.