
[Pallas TPU] Wrong value when using input_output_aliases with multiple arrays #24023

Open
ayaka14732 opened this issue Sep 30, 2024 · 0 comments
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

ayaka14732 (Collaborator) commented Sep 30, 2024

Repro

import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

@functools.partial(
    pl.pallas_call,
    out_shape=(
        jax.ShapeDtypeStruct((2,), jnp.float32),
        jax.ShapeDtypeStruct((2,), jnp.float32),
    ),
    grid=1,
    input_output_aliases={0: 0, 1: 1},
)
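def kernel(_, _2, x_ref, y_ref):
    # Refs arrive as (inputs..., outputs...): `_` and `_2` are the two input
    # refs; `x_ref` and `y_ref` are the outputs aliased to them. Nothing is written.
    pass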

def main():
    x = jnp.array([1, 1], dtype=jnp.float32)
    y = jnp.array([2, 2], dtype=jnp.float32)

    x_out, y_out = kernel(x, y)

    print(x_out)
    print(y_out)

if __name__ == '__main__':
    main()

Expected behaviour

Prints out

[1. 1.]
[2. 2.]

because this is the correct behaviour, which can be confirmed by running the same kernel in interpret mode.
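A minimal sketch of that check: the same kernel with interpret=True added to pl.pallas_call, so it runs on the Pallas interpreter (which also works on CPU) instead of being compiled for TPU. The kernel and ref names are hypothetical; apart from the interpret flag, the call matches the repro.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=(
        jax.ShapeDtypeStruct((2,), jnp.float32),
        jax.ShapeDtypeStruct((2,), jnp.float32),
    ),
    grid=1,
    input_output_aliases={0: 0, 1: 1},
    interpret=True,  # use the Pallas interpreter instead of compiling for TPU
)
def kernel_interpret(x_in_ref, y_in_ref, x_out_ref, y_out_ref):
    # Intentionally empty: each aliased output should keep its input's value.
    pass


x = jnp.array([1, 1], dtype=jnp.float32)
y = jnp.array([2, 2], dtype=jnp.float32)
x_out, y_out = kernel_interpret(x, y)
print(x_out)  # expected: [1. 1.]
print(y_out)  # expected: [2. 2.]
```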

The kernel does nothing, so each output should simply hold the value of the input it is aliased to.
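For comparison, a sketch of a kernel that performs that pass-through explicitly, copying each input block into its aliased output so the result does not depend on the aliased buffer's initial contents. It is shown in interpret mode; the names are hypothetical, and whether this explicit copy avoids the zeros on TPU has not been verified here.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=(
        jax.ShapeDtypeStruct((2,), jnp.float32),
        jax.ShapeDtypeStruct((2,), jnp.float32),
    ),
    grid=1,
    input_output_aliases={0: 0, 1: 1},
    interpret=True,  # hypothetical comparison run; drop to compile for TPU
)
def copy_kernel(x_in_ref, y_in_ref, x_out_ref, y_out_ref):
    # Write each input into the output it is aliased to.
    x_out_ref[...] = x_in_ref[...]
    y_out_ref[...] = y_in_ref[...]


x_out, y_out = copy_kernel(
    jnp.array([1, 1], dtype=jnp.float32),
    jnp.array([2, 2], dtype=jnp.float32),
)
```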

Actual behaviour

Prints out

[0. 0.]
[0. 0.]

Note that this issue does not happen when only one array is aliased.
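A sketch of that single-array case (kernel and ref names are hypothetical). It is shown with interpret=True so it runs anywhere; removing that flag compiles it for TPU, where, going by the note above, it should print [1. 1.] rather than zeros.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
    grid=1,
    input_output_aliases={0: 0},
    interpret=True,  # remove to run on TPU
)
def single_kernel(x_in_ref, x_out_ref):
    # Intentionally empty, as in the two-array repro.
    pass


x_out = single_kernel(jnp.array([1, 1], dtype=jnp.float32))
print(x_out)
```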

The repro is originally from the test test_swap (def test_swap(self):). It was identified while working on #23967.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34.dev20240924+85a466d73
jaxlib: 0.4.33
numpy:  2.1.0
python: 3.12.4 (main, Jun  8 2024, 18:29:57) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
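The report above follows the format printed by jax.print_environment_info(). A minimal sketch for collecting the same details when reproducing the issue on another machine:

```python
import jax

# return_string=True returns the report instead of printing it directly.
info = jax.print_environment_info(return_string=True)
print(info)
```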
ayaka14732 added the bug and pallas labels and self-assigned this issue on Sep 30, 2024.