lax.cond is broken within compute_on('device_host') #23887

Open
hanzhi713 opened this issue Sep 24, 2024 · 0 comments
Labels
bug Something isn't working


Description

Error message:

```
E0924 21:59:40.491062   22297 status_macros.cc:56] INTERNAL: RET_CHECK failure (platforms/xla/service/ba16c7433/target.cc:3042) TransferSizeUtil::HasSparseCoreLayout(*topology_, leaf_shape)
=== Source Location Trace: ===
third_party/tensorflow/compiler/xla/status_macros.cc:79

Stack trace:
    @     0x7b88d43e9982  (unknown)
    @     0x7b88d43e974e  (unknown)
    @     0x7b88d0a17407  (unknown)
    @     0x7b88cda6f868  (unknown)
    @     0x7b88cda6f562  (unknown)
    @     0x7b88cda6f60b  (unknown)
    @     0x7b88cda6f4e5  (unknown)
    @     0x7b88d424d51f  (unknown)
    @     0x7b88d424b2d6  (unknown)
    @     0x7b88c95421aa  (unknown)
    @     0x7b88cda6dfc0  (unknown)
    @     0x7b88cbe5e0e3  (unknown)
    @     0x7b88cbeb5ad4  (unknown)
    @     0x7b88cd7ee9b3  (unknown)
    @     0x7b88d4776f9e  (unknown)
    @     0x7b88d477d276  (unknown)
    @     0x7b88d4785da5  (unknown)
    @     0x7b88d4a27ae3  (unknown)
    @     0x7b8985b71134  (unknown)
Traceback (most recent call last):
  File "/root/lax_cond_host_fn.py", line 35, in <module>
    jit_fn = jit_fn.lower(x, y).compile()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/stages.py", line 660, in compile
    self._lowering.compile(**kw),  # pytype: disable=wrong-keyword-args
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2315, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2829, in from_hlo
    xla_executable = _cached_compilation(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2641, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 424, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 652, in _compile_and_write_cache
    executable = backend_compile(
                 ^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 270, in backend_compile
    raise e
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 264, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (platforms/xla/service/ba16c7433/target.cc:3042) TransferSizeUtil::HasSparseCoreLayout(*topology_, leaf_shape)
I0924 21:59:40.563519   21980 allocator_stats_reporter.cc:147] Stopping AllocatorStatsReporter. Reporting one last time first.
W0924 21:59:40.665513   23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665563   23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665574   23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665580   23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665587   23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:47.322119   22811 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
I0924 21:59:47.453375   22049 uptime_telemetry.cc:139] Successfully updated uptime metric
W0924 21:59:47.476233   22809 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
I0924 21:59:49.310701   22809 async_driver.cc:847] [/dev/vfio/2 tpu540:pe0:2] Driver closed.
W0924 21:59:49.453362   22809 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
W0924 21:59:49.463711   22755 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
I0924 21:59:50.395870   22811 async_driver.cc:847] [/dev/vfio/3 tpu540:pe0:3] Driver closed.
I0924 21:59:50.638586   22809 async_driver.cc:847] [/dev/vfio/1 tpu540:pe0:1] Driver closed.
I0924 21:59:50.697901   22755 async_driver.cc:847] [/dev/vfio/0 tpu540:pe0:0] Driver closed.
```

Simple repro:

```python
import os
import timeit
import jax
import jax.sharding
from jax.experimental.compute_on import compute_on
import jax.numpy as jnp

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
p_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")


@compute_on("device_host")
@jax.jit
def host_fn(gradient, opt_state):
    def true_fn(gradient, opt_state):
        opt_state = opt_state + jnp.sin(gradient)
        delta = opt_state * gradient
        return delta, opt_state

    def false_fn(gradient, opt_state):
        return gradient, opt_state

    return jax.lax.cond(jnp.sum(gradient) > 0, true_fn, false_fn, gradient, opt_state)

def test_fn(gradient, opt_state):
    gradient = jnp.log(gradient)
    return host_fn(gradient, opt_state)


x = jnp.arange(0, 1024*1024, dtype=jnp.float32)
y = jnp.arange(0, 1024*1024, dtype=jnp.float32)
y = jax.device_put(y, p_sharding)  # opt_state lives in pinned host memory

jit_fn = jax.jit(test_fn, in_shardings=(sharding, p_sharding), out_shardings=(sharding, p_sharding), donate_argnums=(0,1))
jit_fn = jit_fn.lower(x, y).compile()  # fails here with XlaRuntimeError (INTERNAL: RET_CHECK failure)


def fn():
    global x, y
    x, y = jit_fn(x, y)
    jax.block_until_ready((x, y))


t = timeit.Timer(fn)
print(t.timeit(10), t.repeat(5, 10))
```
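
If the failure is specific to `lax.cond` inside a `compute_on('device_host')` region, one possible workaround is to drop the conditional, compute both branches, and pick the result with `jnp.where`. This is an assumption: it has not been verified that the select-based form avoids the `RET_CHECK` on TPU. The sketch below only checks, on any backend and without `compute_on`, that the rewrite matches the `lax.cond` version; `branch_with_cond` and `branch_with_select` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def branch_with_cond(gradient, opt_state):
    # Same branch logic as host_fn in the repro, but without compute_on.
    def true_fn(gradient, opt_state):
        opt_state = opt_state + jnp.sin(gradient)
        delta = opt_state * gradient
        return delta, opt_state

    def false_fn(gradient, opt_state):
        return gradient, opt_state

    return jax.lax.cond(jnp.sum(gradient) > 0, true_fn, false_fn, gradient, opt_state)


def branch_with_select(gradient, opt_state):
    # Hypothetical cond-free rewrite: evaluate both branches, then select.
    # The scalar predicate broadcasts against the full arrays in jnp.where.
    pred = jnp.sum(gradient) > 0
    new_state = opt_state + jnp.sin(gradient)
    delta = new_state * gradient
    return jnp.where(pred, delta, gradient), jnp.where(pred, new_state, opt_state)
```

In the repro, `host_fn` would keep its `@compute_on("device_host")` and `@jax.jit` decorators and simply return `branch_with_select(gradient, opt_state)`. The trade-off is that both branches are always evaluated; for this cheap elementwise update that cost is likely small, but it would not suit branches with expensive or side-effecting work.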

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11, tested on TPU v5p

jax                            0.4.34.dev20240924
jaxlib                         0.4.34.dev20240924
libtpu-nightly                 0.1.dev20240924+nightly
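
Most of the details above can also be collected with `jax.print_environment_info()`, which prints the jax, jaxlib, numpy and Python versions along with the visible devices. It may not report the `libtpu-nightly` version, so listing that package separately (as above) is still useful.

```python
import jax

# Print the jax/jaxlib/numpy/python versions and the devices JAX can see.
jax.print_environment_info()

# Or capture the same report as a string, e.g. to paste into an issue.
info = jax.print_environment_info(return_string=True)
```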