lax.cond inside compute_on('device_host') fails to compile
Description
Error message:
```
E0924 21:59:40.491062 22297 status_macros.cc:56] INTERNAL: RET_CHECK failure (platforms/xla/service/ba16c7433/target.cc:3042) TransferSizeUtil::HasSparseCoreLayout(*topology_, leaf_shape)
=== Source Location Trace: ===
third_party/tensorflow/compiler/xla/status_macros.cc:79

Stack trace:
@ 0x7b88d43e9982 (unknown)
@ 0x7b88d43e974e (unknown)
@ 0x7b88d0a17407 (unknown)
@ 0x7b88cda6f868 (unknown)
@ 0x7b88cda6f562 (unknown)
@ 0x7b88cda6f60b (unknown)
@ 0x7b88cda6f4e5 (unknown)
@ 0x7b88d424d51f (unknown)
@ 0x7b88d424b2d6 (unknown)
@ 0x7b88c95421aa (unknown)
@ 0x7b88cda6dfc0 (unknown)
@ 0x7b88cbe5e0e3 (unknown)
@ 0x7b88cbeb5ad4 (unknown)
@ 0x7b88cd7ee9b3 (unknown)
@ 0x7b88d4776f9e (unknown)
@ 0x7b88d477d276 (unknown)
@ 0x7b88d4785da5 (unknown)
@ 0x7b88d4a27ae3 (unknown)
@ 0x7b8985b71134 (unknown)

Traceback (most recent call last):
  File "/root/lax_cond_host_fn.py", line 35, in <module>
    jit_fn = jit_fn.lower(x, y).compile()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/stages.py", line 660, in compile
    self._lowering.compile(**kw),  # pytype: disable=wrong-keyword-args
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2315, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2829, in from_hlo
    xla_executable = _cached_compilation(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2641, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 424, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 652, in _compile_and_write_cache
    executable = backend_compile(
                 ^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 270, in backend_compile
    raise e
  File "/root/venv-311/lib/python3.11/site-packages/jax/_src/compiler.py", line 264, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (platforms/xla/service/ba16c7433/target.cc:3042) TransferSizeUtil::HasSparseCoreLayout(*topology_, leaf_shape)
I0924 21:59:40.563519 21980 allocator_stats_reporter.cc:147] Stopping AllocatorStatsReporter. Reporting one last time first.
W0924 21:59:40.665513 23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665563 23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665574 23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665580 23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:40.665587 23332 manuallysampledmetrics.cc:278] ManuallySampledMetrics couldn't flush samples within the 10s deadline: monitoring_streamz::ERROR_STREAMZ_UNAVAILABLE: Streamz wasn't configured to discover the manual sampling servers because you didn't link the streamz library in your BUILD rule. See go/streamz-force-collection#code-samples.
W0924 21:59:47.322119 22811 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
I0924 21:59:47.453375 22049 uptime_telemetry.cc:139] Successfully updated uptime metric
W0924 21:59:47.476233 22809 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
I0924 21:59:49.310701 22809 async_driver.cc:847] [/dev/vfio/2 tpu540:pe0:2] Driver closed.
W0924 21:59:49.453362 22809 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
W0924 21:59:49.463711 22755 firmware_indirect_registers.cc:80] Released last reference with existing Open. Performing implicit close-on-destruction.
I0924 21:59:50.395870 22811 async_driver.cc:847] [/dev/vfio/3 tpu540:pe0:3] Driver closed.
I0924 21:59:50.638586 22809 async_driver.cc:847] [/dev/vfio/1 tpu540:pe0:1] Driver closed.
I0924 21:59:50.697901 22755 async_driver.cc:847] [/dev/vfio/0 tpu540:pe0:0] Driver closed.
```
Simple repro:
```python
import os
import timeit

import jax
import jax.sharding
from jax.experimental.compute_on import compute_on
import jax.numpy as jnp

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
# Same device, but the array is placed in pinned host memory.
p_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")


# Host-offloaded function that contains the lax.cond.
@compute_on("device_host")
@jax.jit
def host_fn(gradient, opt_state):
  def true_fn(gradient, opt_state):
    opt_state = opt_state + jnp.sin(gradient)
    delta = opt_state * gradient
    return delta, opt_state

  def false_fn(gradient, opt_state):
    return gradient, opt_state

  return jax.lax.cond(jnp.sum(gradient) > 0, true_fn, false_fn, gradient, opt_state)


def test_fn(gradient, opt_state):
  gradient = jnp.log(gradient)
  return host_fn(gradient, opt_state)


x = jnp.arange(0, 1024*1024, dtype=jnp.float32)
y = jnp.arange(0, 1024*1024, dtype=jnp.float32)
y = jax.device_put(y, p_sharding)

jit_fn = jax.jit(test_fn, in_shardings=(sharding, p_sharding), out_shardings=(sharding, p_sharding), donate_argnums=(0, 1))
# Compilation fails here with the RET_CHECK error shown above.
jit_fn = jit_fn.lower(x, y).compile()


def fn():
  global x, y
  x, y = jit_fn(x, y)
  jax.block_until_ready((x, y))


t = timeit.Timer(fn)
print(t.timeit(10), t.repeat(5, 10))
```
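To separate the host-offloading failure from the cond itself, the same branch logic can be run on the default device with no compute_on and no pinned_host memory, and checked against plain NumPy. This is a minimal sketch that is not part of the original report: reference_host_fn and numpy_reference are hypothetical names, the inputs are small positive values chosen for illustration, and it has not been checked on v5p.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def reference_host_fn(gradient, opt_state):
  # Same branches as host_fn, but without compute_on and without pinned_host
  # memory, so the whole computation stays on the default device.
  def true_fn(gradient, opt_state):
    opt_state = opt_state + jnp.sin(gradient)
    delta = opt_state * gradient
    return delta, opt_state

  def false_fn(gradient, opt_state):
    return gradient, opt_state

  return jax.lax.cond(jnp.sum(gradient) > 0, true_fn, false_fn, gradient, opt_state)


def numpy_reference(gradient, opt_state):
  # Plain NumPy version of the same computation, used to check the JAX result.
  gradient = np.asarray(gradient, dtype=np.float32)
  opt_state = np.asarray(opt_state, dtype=np.float32)
  if gradient.sum() > 0:
    opt_state = opt_state + np.sin(gradient)
    return opt_state * gradient, opt_state
  return gradient, opt_state


g = jnp.arange(1, 9, dtype=jnp.float32)  # positive sum, so true_fn is taken
s = jnp.ones(8, dtype=jnp.float32)

delta, new_state = reference_host_fn(g, s)
exp_delta, exp_state = numpy_reference(g, s)
print(np.allclose(delta, exp_delta, rtol=1e-5, atol=1e-6),
      np.allclose(new_state, exp_state, rtol=1e-5, atol=1e-6))
```

If this variant compiles and matches on the same v5p machine, that would suggest the problem lies in placing lax.cond under compute_on('device_host') rather than in the cond itself.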
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11, tested on v5p
jax 0.4.34.dev20240924
jaxlib 0.4.34.dev20240924
libtpu-nightly 0.1.dev20240924+nightly
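The version details above can also be collected from Python. A minimal sketch, assuming a jax release that provides jax.print_environment_info with its return_string argument; its report may not include the libtpu-nightly version, which would still need to be read from pip.

```python
import jax

# Report the jax/jaxlib/numpy/python versions and the visible devices.
report = jax.print_environment_info(return_string=True)
print(report)

# Platform JAX selected by default ("tpu" on a TPU host).
print("default backend:", jax.default_backend())
```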