You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I really like using inspect_array_sharding and my own sharding_info() based on it, to better understand how sharding works in jax and what is going wrong in my code.
But recently I have been using shard_map more and there the inspect_array_sharding callback seems to be broken.
import os
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.debug import visualize_array_sharding

devices = mesh_utils.create_device_mesh((4,))
mesh = Mesh(devices, axis_names=('a',))
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('a'))


def some_function():
    a = jnp.zeros(1000)
    visualize_array_sharding(a)
    return a


some_function()
# ┌───────┐
# │ CPU 0 │
# └───────┘

some_function_jitted = jax.jit(some_function, out_shardings=sharding)
some_function_jitted()
# ┌───────┬───────┬───────┬───────┐
# │ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │
# └───────┴───────┴───────┴───────┘

some_function_shard_map = shard_map(
    some_function,
    mesh=mesh,
    in_specs=PartitionSpec(None),
    out_specs=PartitionSpec("a"),
    # check_rep=False
)
some_function_shard_map()
This causes the following issue:
Traceback (most recent call last):
File ".../inspect_shard_map.py", line 48, in <module>
some_function_shard_map()
File ".../inspect_shard_map.py", line 22, in some_function
visualize_array_sharding(a)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
inspect_array_sharding(arr, callback=_visualize)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
tree_util.tree_map(_inspect, value)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
inspect_sharding_p.bind(val, callback=callback)
NotImplementedError: No replication rule for inspect_sharding. As a workaround, pass the `check_rep=False` argument to `shard_map`. To get this fixed, open an issue at https://github.com/google/jax/issues
But even with check_rep=False added, the code still fails:
Traceback (most recent call last):
File ".../inspect_shard_map.py", line 48, in <module>
some_function_shard_map()
File ".../inspect_shard_map.py", line 22, in some_function
visualize_array_sharding(a)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
inspect_array_sharding(arr, callback=_visualize)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
tree_util.tree_map(_inspect, value)
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
inspect_sharding_p.bind(val, callback=callback)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error calling inspect_sharding: Traceback (most recent call last):
File ".../inspect_shard_map.py", line 48, in <module>
File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 191, in wrapped
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 473, in bind
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 745, in _shard_map_impl
File "venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
File ".../inspect_shard_map.py", line 22, in some_function
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in tree_map
File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in <genexpr>
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 439, in bind
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 835, in process_primitive
File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 2782, in bind
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 949, in process_primitive
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1651, in _pjit_call_impl_python
File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2827, in from_hlo
File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2639, in _cached_compilation
File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 426, in compile_or_get_cached
File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 654, in _compile_and_write_cache
File "venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 266, in backend_compile
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 385, in _hlo_sharding_callback
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 626, in _visualize
File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 496, in visualize_sharding
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 200, in devices_indices_map
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 49, in common_devices_indices_map
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 217, in shard_shape
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 58, in _common_shard_shape
File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 745, in _to_xla_hlo_sharding
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 617, in _positional_sharding_to_xla_hlo_sharding
ValueError: not enough values to unpack (expected 1, got 0)
Description
I really like using `inspect_array_sharding` and my own `sharding_info()` based on it, to better understand how sharding works in jax and what is going wrong in my code.
But recently I have been using shard_map more and there the `inspect_array_sharding` callback seems to be broken.
This causes the following issue:
But even with `check_rep=False` added, the code still fails:
jax/jax/_src/sharding_impls.py
Line 617 in 1594d2f
It would be great if `inspect_array_sharding` could work inside a `shard_map` the same way it already does inside a `jit`ted function.
System info (python version, jaxlib version, accelerator, etc.)
The text was updated successfully, but these errors were encountered: