
visualize_array_sharding/inspect_array_sharding doesn't work with shard_map #23936

Open
Findus23 opened this issue Sep 26, 2024 · 0 comments
Labels: bug (Something isn't working)

Findus23 commented Sep 26, 2024

Description

I really like using `inspect_array_sharding`, and my own `sharding_info()` helper built on top of it, to better understand how sharding works in JAX and what is going wrong in my code.
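The `sharding_info()` helper itself is not shown in this issue. Purely as an illustration, a helper of that kind could wrap `jax.debug.inspect_array_sharding` roughly as sketched below; the name `sharding_info`, its `label` argument and the module-level `seen` list are hypothetical. Called eagerly, the callback fires immediately with the array's concrete sharding; under `jax.jit` it fires while the function is being compiled and receives the sharding the compiler chose.

```python
import os

# Fake 4 CPU devices; this must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

seen = []  # hypothetical: every sharding reported so far, for later inspection


def sharding_info(x, label="array"):
    """Hypothetical helper: report the sharding JAX sees for each leaf of `x`."""
    def report(sharding):
        seen.append(sharding)
        print(f"{label}: {len(sharding.device_set)} device(s), {sharding}")

    jax.debug.inspect_array_sharding(x, callback=report)


mesh = Mesh(mesh_utils.create_device_mesh((4,)), axis_names=("a",))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, PartitionSpec("a")))

sharding_info(x, "eager input")  # callback runs immediately


@jax.jit
def double(v):
    y = 2 * v
    sharding_info(y, "inside jit")  # callback runs while `double` is compiled
    return y


double(x)
```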

Recently, however, I have been using `shard_map` more, and there the `inspect_array_sharding` callback seems to be broken:

```python
import os

# XLA_FLAGS must be set before JAX initializes its backends, so set it first.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.debug import visualize_array_sharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

devices = mesh_utils.create_device_mesh((4,))

mesh = Mesh(devices, axis_names=('a',))
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('a'))

def some_function():
    a = jnp.zeros(1000)
    visualize_array_sharding(a)
    return a

some_function()

# ┌───────┐
# │ CPU 0 │
# └───────┘

some_function_jitted = jax.jit(some_function, out_shardings=sharding)

some_function_jitted()

# ┌───────┬───────┬───────┬───────┐
# │ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │
# └───────┴───────┴───────┴───────┘

some_function_shard_map = shard_map(
    some_function,
    mesh=mesh,
    in_specs=PartitionSpec(None),
    out_specs=PartitionSpec("a"),
    # check_rep=False
)

some_function_shard_map()
```

Calling `some_function_shard_map()` fails with:

```
Traceback (most recent call last):
  File ".../inspect_shard_map.py", line 48, in <module>
    some_function_shard_map()
  File ".../inspect_shard_map.py", line 22, in some_function
    visualize_array_sharding(a)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
    inspect_array_sharding(arr, callback=_visualize)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
    tree_util.tree_map(_inspect, value)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
    inspect_sharding_p.bind(val, callback=callback)
NotImplementedError: No replication rule for inspect_sharding. As a workaround, pass the `check_rep=False` argument to `shard_map`. To get this fixed, open an issue at https://github.com/google/jax/issues
```

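Even with `check_rep=False` passed to `shard_map`, the code still fails. This time the error is raised from the sharding callback while the computation is being compiled: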

```
Traceback (most recent call last):
  File ".../inspect_shard_map.py", line 48, in <module>
    some_function_shard_map()
  File ".../inspect_shard_map.py", line 22, in some_function
    visualize_array_sharding(a)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
    inspect_array_sharding(arr, callback=_visualize)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
    tree_util.tree_map(_inspect, value)
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
    inspect_sharding_p.bind(val, callback=callback)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error calling inspect_sharding: Traceback (most recent call last):
  File ".../inspect_shard_map.py", line 48, in <module>
  File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 191, in wrapped
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 473, in bind
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 745, in _shard_map_impl
  File "venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
  File ".../inspect_shard_map.py", line 22, in some_function
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 627, in visualize_array_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 621, in inspect_array_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in tree_map
  File "venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 344, in <genexpr>
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 620, in _inspect
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 439, in bind
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "venv/lib/python3.12/site-packages/jax/experimental/shard_map.py", line 835, in process_primitive
  File "venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 2782, in bind
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "venv/lib/python3.12/site-packages/jax/_src/core.py", line 949, in process_primitive
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
  File "venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1651, in _pjit_call_impl_python
  File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
  File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2827, in from_hlo
  File "venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2639, in _cached_compilation
  File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 426, in compile_or_get_cached
  File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 654, in _compile_and_write_cache
  File "venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 266, in backend_compile
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 385, in _hlo_sharding_callback
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 626, in _visualize
  File "venv/lib/python3.12/site-packages/jax/_src/debugging.py", line 496, in visualize_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 200, in devices_indices_map
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 49, in common_devices_indices_map
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 217, in shard_shape
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
  File "venv/lib/python3.12/site-packages/jax/_src/sharding.py", line 58, in _common_shard_shape
  File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 745, in _to_xla_hlo_sharding
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 302, in wrapper
  File "venv/lib/python3.12/site-packages/jax/_src/util.py", line 296, in cached
  File "venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 617, in _positional_sharding_to_xla_hlo_sharding
ValueError: not enough values to unpack (expected 1, got 0)
```

The failing line in `_positional_sharding_to_xla_hlo_sharding` is:

```python
set_size, = {len(device_set) for device_set in self._ids.flat}
```
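The line `set_size, = {len(device_set) for device_set in self._ids.flat}` unpacks a set comprehension into a single name, which succeeds only when the set contains exactly one element. Because the error reports `got 0`, `self._ids.flat` apparently yielded no entries at all, i.e. the positional sharding that `visualize_sharding` received inside `shard_map` contained no device ids. The plain-Python sketch below reproduces only that unpacking behavior; `common_set_size` is a made-up name and does not exist in JAX.

```python
def common_set_size(device_id_groups):
    # Same pattern as the failing line: collect the length of every group
    # into a set, then unpack it into exactly one name. This only works when
    # the set has exactly one element, i.e. all groups share one length.
    set_size, = {len(group) for group in device_id_groups}
    return set_size


print(common_set_size([[0, 1], [2, 3]]))  # 2

try:
    common_set_size([])  # no groups at all -> empty set
except ValueError as err:
    print(err)  # not enough values to unpack (expected 1, got 0)

try:
    common_set_size([[0, 1], [2]])  # lengths 2 and 1 -> set has two elements
except ValueError as err:
    print(err)
```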

It would be great if inspect_array_sharding could work inside a shard_map the same way it already does inside a jitted function.
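Until that works, one possible workaround (a sketch, not an official recommendation) is to inspect the global array that `shard_map` returns rather than the local block inside it, and to identify each block from inside the body with `jax.lax.axis_index`. The function name `per_device_block` is hypothetical, and the `try`/`except` import assumes that newer JAX releases expose `shard_map` at the top level, while 0.4.33 only has `jax.experimental.shard_map`.

```python
import os

# Fake 4 CPU devices; this must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec

try:  # assumption: newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # JAX 0.4.x, as used in this issue
    from jax.experimental.shard_map import shard_map

mesh = Mesh(mesh_utils.create_device_mesh((4,)), axis_names=("a",))


def per_device_block():
    a = jnp.zeros(1000)
    # Inside shard_map `a` is only this device's local block. Instead of
    # inspecting its sharding, tag it with the block's position along "a".
    return a + jax.lax.axis_index("a")


f = jax.jit(shard_map(per_device_block, mesh=mesh,
                      in_specs=(), out_specs=PartitionSpec("a")))
out = f()  # global array of shape (4000,), built from four (1000,) blocks

# Outside shard_map the result is an ordinary global array, so the existing
# inspection tools work on it again.
jax.debug.inspect_array_sharding(out, callback=print)
for shard in out.addressable_shards:
    print(shard.device, shard.index, float(shard.data[0]))
```

Each printed line pairs a device with the slice of the global array it holds and the mesh position that produced it. If `rich` is installed, calling `jax.debug.visualize_array_sharding(out)` on the returned array should render the same 4-way layout as the jitted case.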

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.0.1
python: 3.12.6 (main, Sep  7 2024, 14:20:15) [GCC 14.2.0]
jax.devices (4 total, 4 local): [CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='standpc', release='6.10.9-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.9-1 (2024-09-08)', machine='x86_64')
Findus23 added the bug (Something isn't working) label on Sep 26, 2024.