jax-metal: GPU crash(?) with large input #23902

Open
matthewlai opened this issue Sep 25, 2024 · 1 comment
Labels
Apple GPU (Metal) plugin bug Something isn't working

Comments

@matthewlai

Description

This seems to cause the whole machine (or at least WindowServer) to lock up, after which the machine is restarted by the userspace watchdog. It also takes a very long time (~1 minute) to compile.
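The compilation and execution phases can be timed separately with JAX's ahead-of-time API (`jax.jit(...).lower(...).compile()`). This shows how much of the slowness is compilation, and whether the lock-up happens while compiling or while running. A minimal sketch, assuming the same `apply_lut` as the repro below. `time_compile_and_run` is a hypothetical helper, not part of JAX.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def apply_lut(img, lut):
    # Same computation as the repro: quantize, then 3-index gather.
    indices = jnp.floor(img * 63).astype(jnp.uint8)
    return lut[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]


def time_compile_and_run(img, lut):
    """Hypothetical helper: return (output, compile_seconds, run_seconds)."""
    t0 = time.perf_counter()
    # Trace, lower and compile for these shapes/dtypes without executing.
    compiled = apply_lut.lower(img, lut).compile()
    t1 = time.perf_counter()
    # Execute the already-compiled program and wait for the result.
    out = compiled(img, lut).block_until_ready()
    t2 = time.perf_counter()
    return out, t1 - t0, t2 - t1
```

If the machine locks up before `compile()` returns, the problem is on the compilation side; if it returns and the hang starts at execution, it is on the runtime side.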

import os
import platform

if platform.system() == 'Darwin':
    os.environ['ENABLE_PJRT_COMPATIBILITY'] = '1'

import jax
from jax import numpy as jnp
import numpy as np

@jax.jit
def apply_lut(img: jnp.ndarray, lut: jnp.ndarray) -> jnp.ndarray:
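    # Quantize each channel in [0, 1) to an integer LUT index.
    indices = jnp.floor(img * 63).astype(jnp.uint8)
    # Gather one LUT entry per pixel using the three channel indices.
    return lut[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]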

key = jax.random.key(42)
img = jax.random.uniform(key, shape=(4096, 3072, 3))
lut = jax.random.uniform(key, shape=(63, 63, 63, 3))

apply_lut(img, lut)

This is a minimal reproduction of a function that applies a 3D look-up table (LUT) to every pixel in an image.
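The same per-pixel lookup can also be written as a single flat gather. Quantize each channel, combine the three indices into one row-major index, and take rows from the LUT reshaped to (63³, 3). The sketch below uses the hypothetical name `apply_lut_flat` and computes the same values as `apply_lut`. It has not been verified to avoid the Metal crash or the long compile. It uses int32 indices because the combined index (up to 63³ − 1 = 250046) would overflow uint8. It also clamps to n − 1 so that an input of exactly 1.0 stays in range.

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_lut_flat(img: jnp.ndarray, lut: jnp.ndarray) -> jnp.ndarray:
    """Apply an (n, n, n, C) LUT to an (H, W, 3) image via one flat gather."""
    n = lut.shape[0]
    # Quantize each channel to an integer index in [0, n - 1].
    idx = jnp.clip(jnp.floor(img * n), 0, n - 1).astype(jnp.int32)
    # Row-major flat index into the LUT: (r * n + g) * n + b.
    flat_idx = (idx[..., 0] * n + idx[..., 1]) * n + idx[..., 2]
    # View the LUT as an (n**3, C) table and gather whole rows at once.
    flat_lut = lut.reshape(n * n * n, lut.shape[-1])
    # Result shape is flat_idx.shape + (C,), i.e. (H, W, C).
    return jnp.take(flat_lut, flat_idx, axis=0)
```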

Suggestions for better ways to do this on Metal would also be appreciated, though the same code works fine and is very fast on NVIDIA GPUs.
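Since the problem shows up with a large input, one possible workaround is to split the image into row tiles so that each dispatch gathers far fewer pixels. This sketch assumes the hang is tied to the size of a single kernel, which the report does not establish. `apply_lut_tile`, `apply_lut_tiled` and `rows_per_tile` are hypothetical names, and the tile size would need tuning.

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_lut_tile(img, lut):
    # Same gather as the repro, applied to one tile of rows.
    indices = jnp.floor(img * 63).astype(jnp.uint8)
    return lut[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]


def apply_lut_tiled(img, lut, rows_per_tile=512):
    """Hypothetical workaround: apply the LUT one row tile per dispatch."""
    tiles = []
    for start in range(0, img.shape[0], rows_per_tile):
        tile = img[start:start + rows_per_tile]
        tiles.append(apply_lut_tile(tile, lut))
    # Reassemble the tiles along the row axis.
    return jnp.concatenate(tiles, axis=0)
```

With the 4096-row image and the default of 512 rows, this issues eight dispatches of identically shaped tiles, so `apply_lut_tile` compiles once. A final partial tile, when the row count is not a multiple of `rows_per_tile`, would trigger one extra compilation.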

Thanks!

System info (python version, jaxlib version, accelerator, etc.)

Intel MacBook Pro with AMD Radeon Pro 5300M.

Python 3.12.6, jax 0.4.31, jax-metal 0.1.0.

>>> import jax; jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1727260830.000558   30400 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: AMD Radeon Pro 5300M

systemMemory: 16.00 GB
maxCacheSize: 1.99 GB

I0000 00:00:1727260830.028212   30400 service.cc:145] XLA service 0x6000011c4200 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727260830.028236   30400 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1727260830.030023   30400 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1727260830.030045   30400 mps_client.cc:384] XLA backend will use up to 4277645312 bytes on device 0 for SimpleAllocator.
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.12.6 (main, Sep  6 2024, 19:03:47) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='matthewlai-macbookpro3.roam.corp.google.com', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:48:44 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_X86_64', machine='x86_64')
@matthewlai matthewlai added the bug Something isn't working label Sep 25, 2024
@matthewlai
Author

This also crashes on an M1.
