Jax[cuda] always allocates memory #23882

Open
GJBoth opened this issue Sep 24, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@GJBoth
Contributor

GJBoth commented Sep 24, 2024

Description

Hi,

I'm running Jax in combination with Ray to do some RL - we have both CPU-only nodes (actors) and GPU nodes (learners). I set the default platform to cpu using

import jax
jax.config.update("jax_platform_name", "cpu")

but Jax always allocates about 156 MB of GPU memory per process. Normally this wouldn't be an issue, but since I have 32 or 64 actors, the total amount of memory (> 5 GB) becomes quite significant. Turning off pre-allocation doesn't help.

Is this expected behaviour? Everything would be solved for my use case if those 156 MB weren't allocated when using cpu as the platform.
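For completeness, here is roughly what the configuration described above looks like on the actor side, with pre-allocation turned off as well (a sketch; how early the environment variable gets set relative to Ray's worker startup is an assumption):

import os

# Sketch of the actor-side setup described above, with pre-allocation disabled
# as well; setting the variable before the first jax import is an assumption
# about what the worker startup allows.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
jax.config.update("jax_platform_name", "cpu")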

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.24.3
python: 3.10.14 (main, Aug 14 2024, 05:11:29) [Clang 18.1.8 ]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='f14u25.int.janelia.org', release='5.14.0-362.24.1.el9_3.0.1.x86_64', version='#1 SMP PREEMPT_DYNAMIC Thu Apr 4 22:31:43 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Sep 24 15:13:53 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:1A:00.0 Off |                  N/A |
| 30%   34C    P2             54W /  250W |     158MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:1B:00.0 Off |                  N/A |
| 29%   33C    P2             54W /  250W |     158MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:60:00.0 Off |                  N/A |
| 30%   34C    P2             45W /  250W |     158MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:61:00.0 Off |                  N/A |
| 30%   35C    P2             54W /  250W |     158MiB /  11264MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:B1:00.0 Off |                  N/A |
| 29%   32C    P2             50W /  250W |     158MiB /  11264MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:B2:00.0 Off |                  N/A |
| 29%   32C    P2             53W /  250W |     158MiB /  11264MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:DA:00.0 Off |                  N/A |
| 29%   29C    P2             50W /  250W |     158MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:DB:00.0 Off |                  N/A |
| 29%   31C    P2             39W /  250W |     158MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   1677504      C   python                                        156MiB |
|    1   N/A  N/A   1677504      C   python                                        156MiB |
|    2   N/A  N/A   1677504      C   python                                        156MiB |
|    3   N/A  N/A   1677504      C   python                                        156MiB |
|    4   N/A  N/A   1677504      C   python                                        156MiB |
|    5   N/A  N/A   1677504      C   python                                        156MiB |
|    6   N/A  N/A   1677504      C   python                                        156MiB |
|    7   N/A  N/A   1677504      C   python                                        156MiB |
+-----------------------------------------------------------------------------------------+
@GJBoth added the bug label on Sep 24, 2024
@jakevdp
Collaborator

jakevdp commented Sep 24, 2024

In case you haven't seen it, check out https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation for some discussion of JAX's GPU memory allocation behavior.
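
For reference, the options discussed on that page are environment variables that need to be set before JAX initializes its backends, roughly like this (a sketch; the fraction value is just an example, and only one strategy is typically needed):

import os

# Allocator options from the GPU memory allocation page, shown as environment
# variables; values here are illustrative.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # allocate on demand instead of ~75% up front
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"    # or: cap the preallocated fraction
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # or: allocate/free exactly as needed (slower)

import jax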

@GJBoth
Contributor Author

GJBoth commented Sep 25, 2024

Thanks @jakevdp - I know that page pretty well after helping people switch to Jax from PyTorch :-). I'm more surprised that Jax always allocates 156 MB of GPU memory for each process - even with pre-allocation off and the platform set to cpu. Here's a little snippet to show what I'm talking about:

import os

import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# this allocates 156 MB of GPU memory, even though it's on CPU
print(jnp.ones((1,)).device)

That's slightly surprising to me - I would've expected no GPU memory usage when setting the platform to CPU, or am I missing something? Is this expected behaviour?
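
For comparison, here is a variant of the snippet that restricts JAX to the CPU backend via the JAX_PLATFORMS environment variable before the import, rather than via jax.config.update afterwards (a sketch; whether this avoids the 156 MB allocation in this setup is an assumption, not a confirmed result):

import os

# Restrict backend initialization before jax is imported (assumption: this is
# early enough that no CUDA context has been created yet in the process).
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

print(jnp.ones((1,)).device)  # expected: a CPU device, with no entry in nvidia-smi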
