Skip to content

Commit

Permalink
update jax grid
Browse files Browse the repository at this point in the history
  • Loading branch information
alanlujan91 committed Oct 13, 2023
1 parent c5fe27e commit d9f15d4
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/multinterp/regular.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import cupy as cp
import jax.numpy as jnp
import numpy as np

from multinterp.backend._numba import numba_get_coordinates, numba_map_coordinates
Expand Down Expand Up @@ -41,6 +42,7 @@ def get_methods():

get_coords["jax"] = jax_get_coordinates
map_coords["jax"] = jax_map_coordinates
get_grad["jax"] = jnp.gradient
except ImportError:
pass

Expand Down

0 comments on commit d9f15d4

Please sign in to comment.