You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
...
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 817, in jaxpr_subcomp
ans = lowering_rules[eqn.primitive](
^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 1676, in _convert_element_type_lowering_rule
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 724, in f_lowered
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/jax/jax/_src/pallas/mosaic/lowering.py", line 823, in jaxpr_subcomp
raise LoweringException(
jax._src.pallas.mosaic.lowering.LoweringException: Exception while lowering eqn:
a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
With context:
LoweringRuleContext(lowering_context=LoweringContext(ir_context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fb888e55af0>, grid_sizes=(), grid_names=None, mapped_dims=(), user_grid_indices=(), block_shapes=[None], name_stack=NameStack(stack=()), mesh_context=None, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x36886c0>: loc(callsite("main.<locals>.kernel"("/home/ayx/jax/2.py":14:13) at callsite("main"("/home/ayx/jax/2.py":17:10) at "<module>"("/home/ayx/jax/2.py":21:4)))), <jaxlib.xla_extension.Traceback object at 0x3688900>: loc(callsite("main.<locals>.kernel"("/home/ayx/jax/2.py":14:12) at callsite("main"("/home/ayx/jax/2.py":17:10) at "<module>"("/home/ayx/jax/2.py":21:4))))}, location_cache={(<code object kernel at 0x7fb937780030, file "/home/ayx/jax/2.py", line 9>, 82): loc("main.<locals>.kernel"("/home/ayx/jax/2.py":14:13)), (<code object main at 0x7fb93730c800, file "/home/ayx/jax/2.py", line 8>, 160): loc("main"("/home/ayx/jax/2.py":17:10)), (<code object <module> at 0x7fb9377342d0, file "/home/ayx/jax/2.py", line 1>, 124): loc("<module>"("/home/ayx/jax/2.py":21:4)), (<code object kernel at 0x7fb937780030, file "/home/ayx/jax/2.py", line 9>, 142): loc("main.<locals>.kernel"("/home/ayx/jax/2.py":14:12))}, canonical_name_cache={'/home/ayx/jax/2.py': '/home/ayx/jax/2.py'}, is_user_file_cache={'/home/ayx/jax/jax/_src/source_info_util.py': False, '/home/ayx/jax/jax/_src/interpreters/partial_eval.py': False, '/home/ayx/jax/jax/_src/pjit.py': False, '/home/ayx/jax/jax/_src/core.py': False, '/home/ayx/jax/jax/_src/traceback_util.py': False, '/home/ayx/jax/jax/_src/numpy/ufunc_api.py': False, '/home/ayx/jax/jax/_src/numpy/array_methods.py': False, '/home/ayx/jax/2.py': True, '/home/ayx/jax/jax/_src/linear_util.py': False, '/home/ayx/jax/jax/_src/profiler.py': False, '/home/ayx/jax/jax/_src/pallas/pallas_call.py': False, '/home/ayx/jax/jax/_src/lax/lax.py': False, 
'/home/ayx/jax/jax/_src/numpy/lax_numpy.py': False}), for_verification=False), avals_in=[ShapedArray(uint32[])], avals_out=[ShapedArray(int64[])], block_shapes=[None])
With inval shapes=[None]
With inval types=[IntegerType(i32)]
In jaxpr:
{ lambda ; a:u32[]. let
b:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
in (b,) }
Exception: maximum recursion depth exceeded
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
System info (python version, jaxlib version, accelerator, etc.)
Description
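`int64` is not supported on Pallas TPU. However, the error message is not very helpful: instead of reporting that the dtype is unsupported, lowering `convert_element_type` to `int64` recurses until it fails with "maximum recursion depth exceeded".
Error: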
System info (python version, jaxlib version, accelerator, etc.)
The text was updated successfully, but these errors were encountered: