[Pallas TPU] Unrelated 'maximum recursion depth exceeded' error when using .astype(jnp.int64) #23988

Open · ayaka14732 opened this issue Sep 28, 2024 · 0 comments
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

@ayaka14732 (Collaborator)

Description

int64 is not supported on Pallas TPU. However, instead of reporting that, lowering the kernel below fails with an unrelated 'maximum recursion depth exceeded' error:

import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

def main():
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int64),
    )
    def kernel(o_ref):
        x = (jnp.uint32(1) + jnp.uint32(2)).astype(jnp.int64)
        o_ref[...] = x.reshape((1, 1))

    out = kernel()
    print(out)

if __name__ == '__main__':
    main()
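A possible workaround, sketched under the assumption that the result fits in 32 bits, is to keep every value inside the kernel in int32 and widen to int64 only outside the `pallas_call`, where ordinary XLA code handles 64-bit integers once `jax_enable_x64` is on. `add_as_int32` is a hypothetical helper name, and this is a sketch rather than a confirmed fix; passing `interpret=True` runs the kernel in the Pallas interpreter so it can be tried without a TPU.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Same setting as the repro: lets the code outside the kernel use int64.
jax.config.update('jax_enable_x64', True)


def add_as_int32(interpret: bool = False) -> jax.Array:
    """Hypothetical helper: compute in int32 inside the kernel, widen outside."""

    @functools.partial(
        pl.pallas_call,
        # The kernel output stays 32-bit, so no int64 appears inside the kernel.
        out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32),
        # interpret=True executes the kernel in the Pallas interpreter (no TPU needed).
        interpret=interpret,
    )
    def kernel(o_ref):
        x = (jnp.uint32(1) + jnp.uint32(2)).astype(jnp.int32)
        o_ref[...] = x.reshape((1, 1))

    # The int32 -> int64 conversion happens in regular JAX code, not in Pallas.
    return kernel().astype(jnp.int64)


if __name__ == '__main__':
    print(add_as_int32(interpret=True))
```

The original int64 version, by contrast, fails during lowering as shown below.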

Error:

...
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [... the three frames above (f_lowered, jaxpr_subcomp, _convert_element_type_lowering_rule) repeat many more times until Python's recursion limit is hit ...]
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
    ^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
    return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 823, in jaxpr_subcomp
    raise LoweringException(
jax._src.pallas.mosaic.lowering.LoweringException: Exception while lowering eqn:
  a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
With context:
  LoweringRuleContext(lowering_context=LoweringContext(ir_context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fb888e55af0>, grid_sizes=(), grid_names=None, mapped_dims=(), user_grid_indices=(), block_shapes=[None], name_stack=NameStack(stack=()), mesh_context=None, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x36886c0>: loc(callsite("main.<locals>.kernel"("/home/ayx/jax/2.py":14:13) at callsite("main"("/home/ayx/jax/2.py":17:10) at "<module>"("/home/ayx/jax/2.py":21:4)))), <jaxlib.xla_extension.Traceback object at 0x3688900>: loc(callsite("main.<locals>.kernel"("/home/ayx/jax/2.py":14:12) at callsite("main"("/home/ayx/jax/2.py":17:10) at "<module>"("/home/ayx/jax/2.py":21:4))))}, location_cache={(<code object kernel at 0x7fb937780030, file "/home/ayx/jax/2.py", line 9>, 82): loc("main.<locals>.kernel"("/home/ayx/jax/2.py":14:13)), (<code object main at 0x7fb93730c800, file "/home/ayx/jax/2.py", line 8>, 160): loc("main"("/home/ayx/jax/2.py":17:10)), (<code object <module> at 0x7fb9377342d0, file "/home/ayx/jax/2.py", line 1>, 124): loc("<module>"("/home/ayx/jax/2.py":21:4)), (<code object kernel at 0x7fb937780030, file "/home/ayx/jax/2.py", line 9>, 142): loc("main.<locals>.kernel"("/home/ayx/jax/2.py":14:12))}, canonical_name_cache={'/home/ayx/jax/2.py': '/home/ayx/jax/2.py'}, is_user_file_cache={'/home/ayx/jax/jax/_src/source_info_util.py': False, '/home/ayx/jax/jax/_src/interpreters/partial_eval.py': False, '/home/ayx/jax/jax/_src/pjit.py': False, '/home/ayx/jax/jax/_src/core.py': False, '/home/ayx/jax/jax/_src/traceback_util.py': False, '/home/ayx/jax/jax/_src/numpy/ufunc_api.py': False, '/home/ayx/jax/jax/_src/numpy/array_methods.py': False, '/home/ayx/jax/2.py': True, '/home/ayx/jax/jax/_src/linear_util.py': False, '/home/ayx/jax/jax/_src/profiler.py': False, '/home/ayx/jax/jax/_src/pallas/pallas_call.py': False, '/home/ayx/jax/jax/_src/lax/lax.py': False, 
'/home/ayx/jax/jax/_src/numpy/lax_numpy.py': False}), for_verification=False), avals_in=[ShapedArray(uint32[])], avals_out=[ShapedArray(int64[])], block_shapes=[None])
With inval shapes=[None]
With inval types=[IntegerType(i32)]
In jaxpr:
{ lambda ; a:u32[]. let
    b:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
  in (b,) }
Exception: maximum recursion depth exceeded
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
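The traceback is a three-frame cycle: `f_lowered` calls `jaxpr_subcomp`, which calls `_convert_element_type_lowering_rule`, which calls `lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype), ...)`, which lands back in `f_lowered`. Reading the frames, it appears that tracing `_convert_helper` for a `uint32 -> int64` cast emits another `convert_element_type` with the same types. If so, lowering never reaches a base case, Python's recursion limit is hit, and the resulting `RecursionError` is wrapped in a `LoweringException`. The toy model below is plain Python, not JAX. Every name in it (`SUPPORTED`, `convert_helper`, `guarded_convert_rule`, and so on) is a hypothetical stand-in that only mirrors the shape of the frames. It reproduces that cycle and sketches one possible fix: reject an unsupported target dtype before delegating, so the user sees a clear message.

```python
import functools


class LoweringException(Exception):
    """Toy stand-in for the wrapper exception raised around a failed rule."""


# Assumption for this toy only: target dtypes that have a direct lowering.
SUPPORTED = {"int32", "uint32", "float32"}


def jaxpr_subcomp(eqns, rules):
    """Lower each (primitive, src_dtype, dst_dtype) equation with its rule."""
    lowered = []
    for prim, src, dst in eqns:
        try:
            lowered.append(rules[prim](src, dst, rules))
        except LoweringException:
            raise  # already wrapped further down; keep the innermost message
        except Exception as err:
            raise LoweringException(
                f"while lowering {prim}[{src} -> {dst}]: {err}") from err
    return lowered


def lower_fun(helper, src, rules):
    """Stand-in for f_lowered: trace `helper` into equations, then lower them."""
    return jaxpr_subcomp(helper(src), rules)


def convert_helper(src, to_dtype):
    # Tracing this helper emits the very same conversion again, so the
    # recursion has no base case: the cycle seen in the traceback.
    return [("convert_element_type", src, to_dtype)]


def convert_rule(src, dst, rules):
    if dst in SUPPORTED:
        return f"cast {src} -> {dst}"
    return lower_fun(functools.partial(convert_helper, to_dtype=dst), src, rules)


def guarded_convert_rule(src, dst, rules):
    # One possible fix: reject unsupported targets before delegating.
    if dst not in SUPPORTED:
        raise NotImplementedError(
            f"casting to {dst} is not supported on this backend")
    return convert_rule(src, dst, rules)


def lower(eqns, rule):
    return jaxpr_subcomp(eqns, {"convert_element_type": rule})


eqns = [("convert_element_type", "uint32", "int64")]
for rule in (convert_rule, guarded_convert_rule):
    try:
        lower(eqns, rule)
    except LoweringException as exc:
        print(f"{rule.__name__}: {exc}")
```

With `convert_rule`, the printed message ends in Python's "maximum recursion depth exceeded" text, matching the report. With `guarded_convert_rule`, it names the unsupported target dtype instead. Supported casts such as `uint32 -> int32` lower normally under both rules.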

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34.dev20240924+85a466d73
jaxlib: 0.4.33
numpy:  2.1.0
python: 3.12.4 (main, Jun  8 2024, 18:29:57) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
ayaka14732 added the bug and pallas labels on Sep 28, 2024
ayaka14732 self-assigned this on Sep 28, 2024