Skip to content

Commit

Permalink
Replace deprecated jax.random.shuffle with jax.random.permutation
Browse files Browse the repository at this point in the history
jax.random.shuffle has long been deprecated, because it cannot operate in-place like np.random.shuffle, and because its functionality can be performed with jax.random.permutation (with independent=True in the case of multi-dimensional arrays).

PiperOrigin-RevId: 622905545
  • Loading branch information
vanderplas authored and tensorflower-gardener committed Apr 8, 2024
1 parent 51f1ed3 commit 63b2100
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _shuffle_jax(value, seed=None, name=None): # pylint: disable=unused-argumen
import jax.random as jaxrand # pylint: disable=g-import-not-at-top
if seed is None:
raise ValueError('Must provide PRNGKey to sample in JAX.')
return jaxrand.shuffle(seed, value, axis=0)
return jaxrand.permutation(seed, value, axis=0, independent=True)


def _truncated_normal(
Expand Down

0 comments on commit 63b2100

Please sign in to comment.