I'm running JAX in combination with Ray to do some RL: we have both CPU-only nodes (actors) and GPU nodes (learners). I set the default device to CPU using:
but JAX always allocates about 156 MB of GPU memory per process. Normally this wouldn't be an issue, but since I have 32 or 64 actors, the total (roughly 5 GB with 32 actors and 10 GB with 64) becomes significant. Turning off preallocation doesn't help.
Is this expected behaviour? Everything would be solved for my use case if those 156 MB weren't allocated when using CPU as the platform.
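For scale, this overhead grows linearly with the number of actor processes. A quick back-of-the-envelope check, assuming the ~156 MB figure is a constant per-process cost (the helper name below is purely illustrative):

```python
# Assumption: a fixed per-process GPU overhead of ~156 MB, as observed above.
PER_PROCESS_MB = 156


def total_gpu_overhead_mb(num_processes: int, per_process_mb: int = PER_PROCESS_MB) -> int:
    """Total GPU memory (MB) consumed by the per-process overhead alone."""
    return num_processes * per_process_mb


for actors in (32, 64):
    total = total_gpu_overhead_mb(actors)
    # Divide by 1000 for decimal GB; nvidia-smi reports MiB, so numbers differ slightly.
    print(f"{actors} actors -> {total} MB (~{total / 1000:.1f} GB)")
```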
System info (python version, jaxlib version, accelerator, etc.)
Thanks @jakevdp. I know that page pretty well after helping people switch to JAX from PyTorch :-). What surprises me more is that JAX always allocates 156 MB of GPU memory for each process, even with preallocation off and the platform set to CPU. Here's a small snippet showing what I mean:
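```python
import os

# Set XLA client options before JAX initializes its backends (safest: before importing jax).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

# Make CPU the default platform for new arrays and computations.
jax.config.update("jax_platform_name", "cpu")

# This still allocates ~156 MB of GPU memory, even though the array lives on the CPU.
print(jnp.ones((1,)).device)
```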
That's slightly surprising to me. I would have expected no GPU memory usage when setting the platform to CPU. Am I missing something, or is this expected behaviour?
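One possible workaround, not confirmed in this thread: as far as I understand, `jax_platform_name` only selects the *default* device, while JAX still initializes every backend it can find. Initializing the CUDA backend creates a CUDA context, which likely accounts for the per-process GPU memory; preallocation settings don't govern that context. Setting the `JAX_PLATFORMS` environment variable (or the `jax_platforms` config option) restricts which backends are initialized at all. A minimal sketch, assuming a JAX version that supports `JAX_PLATFORMS`:

```python
import os

# Assumption: JAX_PLATFORMS must be set before JAX initializes its backends,
# so set it before importing jax. With "cpu", only the CPU backend is
# initialized, so no CUDA context (and no GPU memory) should be created.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

x = jnp.ones((1,))
print(jax.default_backend())  # name of the default backend; should be "cpu"
print(jax.devices())          # only CPU devices should be listed
```

For the Ray actors, the variable could be passed through the actor's runtime environment, e.g. `runtime_env={"env_vars": {"JAX_PLATFORMS": "cpu"}}`, so it is in place before the worker imports JAX. Hiding the GPUs from the actor processes with `CUDA_VISIBLE_DEVICES=""` is another common option.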