[Pallas TPU] jnp.uint32 comparisons give the wrong result with large numbers #23972

Open
ayaka14732 opened this issue Sep 27, 2024 · 0 comments · May be fixed by #24086
Labels: bug, pallas

Comments

@ayaka14732
Collaborator

Description

import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp

def main():
    @functools.partial(
        pl.pallas_call,
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        out_shape=jax.ShapeDtypeStruct((2,), jnp.bool_),
    )
    def kernel(o_ref):
        o_ref[0] = jnp.uint32(5) <= jnp.uint32(4294967294)
        o_ref[1] = jnp.uint32(4294967292) > jnp.uint32(4)

    out = kernel()
    print(out)  # prints "[False False]"

if __name__ == '__main__':
    main()

Expected behaviour:

Outputs [True True] since 5 <= 4294967294 and 4294967292 > 4

Actual behaviour:

Outputs [False False]

Note that this only happens with values above the int32 maximum. For example, 4294967294 exceeds the upper bound of jnp.int32 (2147483647), but does not exceed the upper bound of jnp.uint32 (4294967295). Therefore, this may be related to an unexpected cast from uint32 to int32 somewhere, or to the comparison treating the uint32 bit patterns as signed.
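To illustrate the suspected cause, here is a minimal pure-Python sketch (no JAX or TPU needed) of what happens when the same 32-bit patterns are compared as signed instead of unsigned. The helper `as_int32` is hypothetical and only models two's-complement reinterpretation; it is not part of JAX or Pallas.

```python
def as_int32(x: int) -> int:
    """Reinterpret the low 32 bits of x as a signed two's-complement value.

    Hypothetical helper for illustration only.
    """
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x


# The two comparisons from the kernel, evaluated as unsigned values.
unsigned_results = [5 <= 4294967294, 4294967292 > 4]

# The same bit patterns, evaluated as signed 32-bit values:
# 4294967294 becomes -2 and 4294967292 becomes -4.
signed_results = [
    as_int32(5) <= as_int32(4294967294),
    as_int32(4294967292) > as_int32(4),
]

print(unsigned_results)  # [True, True]
print(signed_results)    # [False, False]
```

The signed interpretation reproduces the `[False False]` output reported above, which is consistent with a signed comparison being applied to uint32 operands.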

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34.dev20240924+85a466d73
jaxlib: 0.4.33
numpy:  2.1.0
python: 3.12.4 (main, Jun  8 2024, 18:29:57) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
ayaka14732 added the bug and pallas labels on Sep 27, 2024
ayaka14732 self-assigned this on Sep 27, 2024
copybara-service bot pushed a commit that referenced this issue Oct 3, 2024
Fixes #23972

In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`). In this particular issue, we need `ule` but it's currently lowered to `sle`.

PiperOrigin-RevId: 681672271
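
The commit message above describes picking the signed or unsigned comparison predicate depending on the JAX dtype, since both map to `i32`. A rough sketch of that selection logic follows. The function `cmpi_predicate` and the lookup tables are hypothetical names invented for this example and are not the actual Pallas lowering code; only the predicate names (`sle`, `ule`, and so on) come from the MLIR `arith` dialect.

```python
# Hypothetical sketch: choose an arith.cmpi predicate from the comparison
# kind and the signedness of the JAX dtype. Not the real Pallas code.
_SIGNED = {"lt": "slt", "le": "sle", "gt": "sgt", "ge": "sge"}
_UNSIGNED = {"lt": "ult", "le": "ule", "gt": "ugt", "ge": "uge"}


def cmpi_predicate(op: str, is_unsigned: bool) -> str:
    """Return the integer comparison predicate name for `op`."""
    if op in ("eq", "ne"):
        # Equality does not depend on signedness.
        return op
    table = _UNSIGNED if is_unsigned else _SIGNED
    return table[op]


print(cmpi_predicate("le", is_unsigned=True))   # ule
print(cmpi_predicate("le", is_unsigned=False))  # sle
```

Under this sketch, the bug corresponds to always taking the signed table, so a `<=` on `jnp.uint32` operands ends up as `sle` rather than `ule`.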