Segmentation Fault in jax-metal with variable indexing and partial slice in scan #23931

Open
bsarkar321 opened this issue Sep 25, 2024 · 0 comments
Labels: Apple GPU (Metal) plugin; bug (Something isn't working)

Comments

@bsarkar321

Description

I'm observing a consistent segmentation fault when running the `jax.lax.scan` in the code below on a jax-metal device (M1 Ultra Mac Studio, macOS Sequoia 15.0). I've tried to reduce it to the minimum conditions needed to trigger the bug.

import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'METAL') # fine when setting to 'cpu'

n_layer = 5
state_width = 2  # fine when setting equal to to_add
to_add = 1
x_size = 5

out_state = jax.random.uniform(jax.random.key(0), (n_layer, state_width))
print(out_state)
def do_loop(x, i):
    s1 = out_state[i, :to_add]  # fine when replacing i with constant
    to_update = jnp.concat((s1, x[:-to_add]))
    # to_update = jnp.concat((s1, jnp.zeros(x_size - to_add))) # still has same bug
    # to_update = s1 @ jnp.zeros((to_add, x_size)) # still has same bug
    return to_update, i

x = jnp.zeros(x_size)
for i in range(n_layer): # fine with python for-loop
    x, _ = do_loop(x, i)
print("correct output", x)

x = jnp.zeros(x_size)
print("   scan output", jax.lax.scan(do_loop, x, jnp.arange(n_layer))[0]) # segmentation fault

Some notes:

  • this bug only happens on metal; CPU is perfectly fine
  • this bug is only triggered when indexing with a scan input AND taking a partial slice (it does not occur when only one of these conditions is met)
  • this bug is not triggered on python for-loops (even when using jit)
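
The two conditions in the notes above can be toggled independently with a small harness. This is only a sketch and was not part of the original report: `make_body`, `run_loop`, `run_scan`, and the `dynamic_index` / `partial_slice` flags are hypothetical names introduced here, and "no partial slice" is modeled by taking the full row rather than by setting `state_width = to_add`. The platform is left at its default so the harness can be run on CPU as a reference first; it has not been verified on a Metal device.

```python
import jax
import jax.numpy as jnp

n_layer, state_width, to_add, x_size = 5, 2, 1, 5
out_state = jax.random.uniform(jax.random.key(0), (n_layer, state_width))

def make_body(dynamic_index: bool, partial_slice: bool):
    """Build a scan body that enables each suspected trigger independently."""
    def body(x, i):
        row = i if dynamic_index else 0  # scan input vs. constant index
        # Partial slice of the row vs. the full row (assumed "safe" variant).
        s1 = out_state[row, :to_add] if partial_slice else out_state[row, :]
        # Shift x right by len(s1) so the carry keeps shape (x_size,).
        to_update = jnp.concatenate((s1, x[:-s1.shape[0]]))
        return to_update, i
    return body

def run_loop(body):
    """Reference result from a plain Python for-loop."""
    x = jnp.zeros(x_size)
    for i in range(n_layer):
        x, _ = body(x, i)
    return x

def run_scan(body):
    """Same computation through jax.lax.scan."""
    return jax.lax.scan(body, jnp.zeros(x_size), jnp.arange(n_layer))[0]

for dynamic_index in (False, True):
    for partial_slice in (False, True):
        body = make_body(dynamic_index, partial_slice)
        # Flush before the scan so the variant name survives a hard crash.
        print(f"dynamic_index={dynamic_index} partial_slice={partial_slice}",
              flush=True)
        match = bool(jnp.allclose(run_loop(body), run_scan(body)))
        print("  loop == scan:", match, flush=True)
```

On CPU all four variants should report a match. If the notes hold, on Metal the process would be expected to crash right after printing the `dynamic_index=True partial_slice=True` line.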

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.1.1
python: 3.12.6 | packaged by conda-forge | (main, Sep 22 2024, 14:07:06) [Clang 17.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Mac.attlocal.net', release='24.0.0', version='Darwin Kernel Version 24.0.0: Mon Aug 12 20:51:54 PDT 2024; root:xnu-11215.1.10~2/RELEASE_ARM64_T6000', machine='arm64')

Additionally, jax-metal is 0.1.0
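
For anyone comparing environments, the block above appears to be the output of `jax.print_environment_info()`, which does not report the plugin version. A minimal sketch for collecting both, assuming the plugin is installed under the distribution name `jax-metal` (the helper `plugin_version` is a hypothetical name used only here):

```python
from importlib.metadata import PackageNotFoundError, version

import jax

def plugin_version(dist_name: str = "jax-metal") -> str:
    """Return the installed version of a distribution, or a placeholder."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "not installed"

# Prints jax, jaxlib, numpy, python, device, and platform information.
jax.print_environment_info()
print("jax-metal:", plugin_version())
```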

@bsarkar321 bsarkar321 added the bug Something isn't working label Sep 25, 2024