input_output_aliases
Repro

```python
import functools

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


@functools.partial(
    pl.pallas_call,
    out_shape=(
        jax.ShapeDtypeStruct((2,), jnp.float32),
        jax.ShapeDtypeStruct((2,), jnp.float32),
    ),
    grid=1,
    input_output_aliases={0: 0, 1: 1},
)
def kernel(_, _2, x_ref, y_ref):
    pass


def main():
    x = jnp.array([1, 1], dtype=jnp.float32)
    y = jnp.array([2, 2], dtype=jnp.float32)
    x_out, y_out = kernel(x, y)
    print(x_out)
    print(y_out)


if __name__ == '__main__':
    main()
```
Expected behaviour

Prints out

```
[1. 1.]
[2. 2.]
```

The kernel body is empty and each output is aliased to the corresponding input, so the kernel should do nothing except pass the inputs straight through to the outputs. This is also what happens in interpret mode.
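To check the expected behaviour without a TPU, the same kernel can be run through the Pallas interpreter by passing `interpret=True` to `pl.pallas_call`. This is a minimal sketch that assumes a JAX version close to the one in the system info; apart from the `interpret` flag it is the same kernel as the repro.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=(
        jax.ShapeDtypeStruct((2,), jnp.float32),
        jax.ShapeDtypeStruct((2,), jnp.float32),
    ),
    grid=1,
    input_output_aliases={0: 0, 1: 1},
    interpret=True,  # run with the Pallas interpreter instead of compiling for TPU
)
def kernel(_, _2, x_ref, y_ref):
    # Intentionally empty: each output aliases an input, so it keeps the input's values.
    pass


x = jnp.array([1, 1], dtype=jnp.float32)
y = jnp.array([2, 2], dtype=jnp.float32)
x_out, y_out = kernel(x, y)
print(x_out)  # [1. 1.]
print(y_out)  # [2. 2.]
```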
Actual behaviour

On TPU, prints out

```
[0. 0.]
[0. 0.]
```
Note that this issue does not happen when there is only one aliased array.
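For comparison, the single-array case can be written as below. This is a sketch: `single_alias_kernel` and `INTERPRET` are hypothetical names, and falling back to the interpreter when no TPU is present is an assumption made only so the snippet runs anywhere. Per the note, this variant reportedly returns the input unchanged (`[1. 1.]`) on TPU as well.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Assumption: compile for TPU when one is available, otherwise use the
# Pallas interpreter so the sketch also runs on CPU/GPU.
INTERPRET = jax.default_backend() != "tpu"


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
    grid=1,
    input_output_aliases={0: 0},  # the only output aliases the only input
    interpret=INTERPRET,
)
def single_alias_kernel(_, x_ref):
    # Intentionally empty: the aliased output should keep the input's values.
    pass


x = jnp.array([1, 1], dtype=jnp.float32)
x_out = single_alias_kernel(x)
print(x_out)
```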
The repro originally comes from a test in `jax/tests/pallas/ops_test.py` (line 1344 at commit ff1c2ac).
System info (python version, jaxlib version, accelerator, etc.)

```
jax: 0.4.34.dev20240924+85a466d73
jaxlib: 0.4.33
numpy: 2.1.0
python: 3.12.4 (main, Jun 8 2024, 18:29:57) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
```
ayaka14732