Skip to content

Commit

Permalink
no jax
Browse files Browse the repository at this point in the history
  • Loading branch information
alanlujan91 committed Sep 8, 2023
1 parent 29d85d0 commit 6272c93
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
3 changes: 2 additions & 1 deletion src/multinterp/backend/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import numpy as np
from numba import njit, prange, typed
from scipy.ndimage import map_coordinates

# from scipy.ndimage import map_coordinates
from multinterp.backend.numba_jax import map_coordinates
from multinterp.core import MC_KWARGS


Expand Down
43 changes: 24 additions & 19 deletions src/multinterp/backend/numba_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@
from collections.abc import Callable, Sequence

import numpy as np
from jax._src.typing import Array, ArrayLike


def _mirror_index_fixer(index: Array, size: int) -> Array:
def _mirror_index_fixer(index: np.ndarray, size: int) -> np.ndarray:
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return np.abs((index + s) % (2 * s) - s)


def _reflect_index_fixer(index: Array, size: int) -> Array:
def _reflect_index_fixer(index: np.ndarray, size: int) -> np.ndarray:
return np.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2)


_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = {
_INDEX_FIXERS: dict[str, Callable[[np.ndarray, int], np.ndarray]] = {
"constant": lambda index, size: index,
"nearest": lambda index, size: np.clip(index, 0, size - 1),
"wrap": lambda index, size: index % size,
Expand All @@ -26,17 +25,21 @@ def _reflect_index_fixer(index: Array, size: int) -> Array:
}


def _round_half_away_from_zero(a: Array) -> Array:
def _round_half_away_from_zero(a: np.ndarray) -> np.ndarray:
return a if np.issubdtype(a.dtype, np.integer) else np.round(a)


def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
def _nearest_indices_and_weights(
coordinate: np.ndarray,
) -> list[tuple[np.ndarray, np.ndarray]]:
index = _round_half_away_from_zero(coordinate).astype(np.int32)
weight = coordinate.dtype.type(1)
return [(index, weight)]


def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
def _linear_indices_and_weights(
coordinate: np.ndarray,
) -> list[tuple[np.ndarray, np.ndarray]]:
lower = np.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
Expand All @@ -45,12 +48,12 @@ def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLik


def _map_coordinates(
input: ArrayLike,
coordinates: Sequence[ArrayLike],
input: np.ndarray,
coordinates: Sequence[np.ndarray],
order: int,
mode: str,
cval: ArrayLike,
) -> Array:
cval: np.ndarray,
) -> np.ndarray:
input_arr = np.asarray(input)
coordinate_arrs = [np.asarray(c) for c in coordinates]
cval = np.asarray(cval, input_arr.dtype)
Expand Down Expand Up @@ -89,7 +92,7 @@ def is_valid(index, size):
raise NotImplementedError(msg)

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
for coordinate, size in zip(coordinate_arrs, input_arr.shape, strict=True):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
Expand All @@ -100,25 +103,27 @@ def is_valid(index, size):

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = zip(*items)
indices, validities, weights = zip(*items, strict=True)
if all(valid is True for valid in validities):
# fast path
contribution = input_arr[indices]
else:
all_valid = np.all(validities)
all_valid = np.all(validities, axis=0)
contribution = np.where(all_valid, input_arr[indices], cval)
outputs.append(np.prod(weights) * contribution)
result = np.sum(outputs)
outputs.append(np.prod(weights, axis=0) * contribution)
result = np.sum(outputs, axis=0)
if np.issubdtype(input_arr.dtype, np.integer):
result = _round_half_away_from_zero(result)
return result.astype(input_arr.dtype)


def map_coordinates(
input: ArrayLike,
coordinates: Sequence[ArrayLike],
input: np.ndarray,
coordinates: Sequence[np.ndarray],
order: int,
output: None,
prefilter: None,
mode: str = "constant",
cval: ArrayLike = 0.0,
cval: np.ndarray = 0.0,
):
return _map_coordinates(input, coordinates, order, mode, cval)

0 comments on commit 6272c93

Please sign in to comment.