multi conditions/switch with predicate functions #14264

Answered by jakevdp
jecampagne asked this question in Q&A

One way to do this is with jnp.select, which checks a list of conditions in order and returns the value paired with the first condition that is true, falling back to default when none is. For example:

import jax.numpy as jnp
import jax

@jax.jit
def f(x):
  x = abs(x)
  # Conditions are checked in order; the first true one picks its value.
  return jnp.select(
      [x == 0, x <= 4, x <= 8],
      [x, x - 1, x - 2],
      default=x)

print(f(0))
# 0
print(f(-2))
# 1
print(f(7))
# 5
print(f(10))
# 10
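
As a possible alternative, not part of the answer above: jnp.select receives all of its candidate values already computed and then picks among them. If each case is better expressed as a function, and x is a scalar, the same predicates could drive jax.lax.switch instead. The sketch below assumes a scalar input; the name f_switch and the trick of turning the predicates into an index with argmax are illustrative choices, not something from the original answer.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_switch(x):  # hypothetical name, for illustration only
  x = jnp.abs(x)
  # Predicates in priority order, plus a final always-true entry
  # that plays the role of the default case.
  preds = jnp.stack([x == 0, x <= 4, x <= 8, jnp.array(True)])
  # argmax returns the first maximal entry, i.e. the first true predicate.
  index = jnp.argmax(preds.astype(jnp.int32))
  branches = [
      lambda v: v,      # x == 0
      lambda v: v - 1,  # x <= 4
      lambda v: v - 2,  # x <= 8
      lambda v: v,      # default
  ]
  # Scalar x only: lax.switch takes a single integer index.
  return jax.lax.switch(index, branches, x)

for value in (0, -2, 7, 10):
  print(f_switch(value))
```

For the four inputs used above this should agree with the jnp.select version. For array inputs, jnp.select remains the simpler choice because it selects elementwise.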

Answer selected by jecampagne