Skip to content

Commit

Permalink
Removing deprecation warning and jitting predict (#120)
Browse files Browse the repository at this point in the history
* removing deprecation warning and jitting predict

* adding news

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfm and pre-commit-ci[bot] authored Oct 27, 2022
1 parent 8c8cb34 commit c1893ca
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
2 changes: 2 additions & 0 deletions news/120.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Removed deprecation warning from ``predict`` method and wrapped it in a
``jax.jit`` in order to support interactive use.
13 changes: 5 additions & 8 deletions src/tinygp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ def condition(

return ConditionResult(log_prob, gp)

@partial(
jax.jit,
static_argnums=(0,),
static_argnames=("include_mean", "return_var", "return_cov"),
)
def predict(
self,
y: JAXArray,
Expand Down Expand Up @@ -250,14 +255,6 @@ def predict(
returned with shape ``(N_test,)`` or ``(N_test, N_test)``
respectively.
"""
import warnings

warnings.warn(
"The 'predict' method is deprecated and 'condition' should be preferred",
DeprecationWarning,
stacklevel=2,
)

_, cond = self.condition(
y, X_test, kernel=kernel, include_mean=include_mean
)
Expand Down

0 comments on commit c1893ca

Please sign in to comment.