JAX cannot find TPU metadata inside a container #10923

Open
pawalt opened this issue Sep 18, 2024 · 5 comments
Labels: area: tpu (Issues related to TPU access); type: bug (Something isn't working)

pawalt commented Sep 18, 2024

Description

When a TPU container is initialized, it is missing some environment variables that JAX needs in order to initialize. When these variables are absent, JAX's TPU runtime tries to look up their values over the network from the GCE instance metadata server (the tpu-env attribute). This fails because the container may not have direct network access.

I have also tried this with --network=host, to no avail.
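For context, the fallback lookup amounts to an HTTP GET against the GCE metadata server for the tpu-env attribute named in the error message in the reproduction below. The Python sketch here is a hypothetical illustration, not the runtime's actual code: the helper names are invented, and it assumes tpu-env is returned as one KEY: 'value' pair per line (the issue does not show the raw format).

```python
import urllib.request
from typing import Dict

# Standard GCE metadata path for a custom instance attribute; "tpu-env" is the
# attribute named in the error message. Only reachable with host network access.
TPU_ENV_URL = (
    "http://metadata.google.internal/computeMetadata/v1/"
    "instance/attributes/tpu-env"
)


def parse_tpu_env(text: str) -> Dict[str, str]:
    """Hypothetical helper: parse assumed `KEY: 'value'` lines into a dict."""
    env = {}
    for line in text.splitlines():
        key, sep, value = line.partition(":")  # split on the first colon only
        if not sep or not key.strip():
            continue  # skip blank or malformed lines
        env[key.strip()] = value.strip().strip("'\"")
    return env


def fetch_tpu_env(timeout: float = 2.0) -> Dict[str, str]:
    """Hypothetical helper: fetch and parse tpu-env from the metadata server."""
    req = urllib.request.Request(TPU_ENV_URL, headers={"Metadata-Flavor": "Google"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return parse_tpu_env(resp.read().decode())
```

Inside a sandbox with no route to the metadata server, fetch_tpu_env would be expected to fail with an exception such as urllib.error.URLError, which is the Python analogue of the "Couldn't connect to server" error below.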

Steps to reproduce

Run a JAX image with --tpuproxy:

runsc command:

sudo runsc --debug \
    --debug-log=/home/peyton/tputesting/logs/ \
    --strace \
    --root=/home/peyton/tputesting/runroot \
    --tpuproxy \
    run \
    --bundle=/home/peyton/tputesting \
    my-jax-container

Start the container:

peyton@t1v-n-1f714773-w-0:~/tputesting$ ./start.sh 
# python
Python 3.11.9 (main, Aug 13 2024, 02:18:20) [GCC 12.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.device_count()
Failed to get TPU metadata (tpu-env) from instance metadata for variable CHIPS_PER_HOST_BOUNDS: INTERNAL: Couldn't connect to server
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/gcp_metadata_utils.cc:99
learning/45eac/tfrc/runtime/env_var_utils.cc:50

I built this image by exporting the following Dockerfile:

FROM python:3.11

RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

runsc version

runsc -version
runsc version release-20240826.0-81-g4bcbb55fcba5
spec: 1.1.0-rc.1

docker version (if using docker)

I'm not using docker

uname

$ uname -a
Linux t1v-n-1f714773-w-0 5.19.0-1022-gcp #24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux

kubectl (if using Kubernetes)

No response

repo state (if built from source)

$ git describe
release-20240826.0-81-g4bcbb55fc

runsc debug logs (if available)

No response

pawalt commented Sep 18, 2024

None of these environment variables are present on the host either; they are populated dynamically.

manninglucas self-assigned this Sep 18, 2024

manninglucas (contributor) commented Sep 18, 2024

It seems the sandbox cannot reach the GCE metadata server for some reason. I'm surprised that --network=host does not fix this. I will investigate and try to reproduce it myself. If you can share your strace logs, that will also help me debug the issue.

pawalt commented Sep 18, 2024

runsc.log.20240918-170153.986977.boot.txt
Here are the strace logs; I had to upload them as a file because there are quite a lot of them.

manninglucas (contributor) commented Sep 18, 2024

It looks like --network=host was not actually set in the run that produced the logs you sent.

From runsc.log.20240918-170153.986977.boot.txt:

D0918 17:01:54.094908       1 config.go:439] Config.Network (--network): sandbox
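One quick way to check which network mode a sandbox actually used is to search its boot log for the config line quoted above. The sketch below is a hypothetical helper; its regular expression is an assumption based only on that single log line, not on a documented log format.

```python
import re
from typing import Optional

# Matches the config dump line quoted above, e.g.
# "... config.go:439] Config.Network (--network): sandbox"
_NETWORK_RE = re.compile(r"Config\.Network \(--network\): (\S+)")


def network_mode(log_text: str) -> Optional[str]:
    """Return the --network value recorded in a runsc debug log, or None."""
    match = _NETWORK_RE.search(log_text)
    return match.group(1) if match else None
```

For the log above, this returns "sandbox"; a run started with --network=host should report "host" instead.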

You'll also want to run runsc in the host's network namespace to allow proper host network access. Right now it looks like you're running in a separate network namespace.

From runsc.log.20240918-170153.986977.boot.txt:

  "linux": {
    "namespaces": [
      {
        "type": "pid"
      },
      {
        "type": "network"
      },
      {
        "type": "ipc"
      },
      {
        "type": "uts"
      },
      {
        "type": "mount"
      }
    ]
  }
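Under the OCI runtime spec, a namespace type that is omitted from linux.namespaces is inherited from the runtime, so removing the "network" entry shown above lets the sandbox start in the network namespace runsc itself runs in (the host's, when runsc is launched from the host). A minimal Python sketch of that edit, assuming the bundle's config.json is edited in place before `runsc run`; the function names are hypothetical, and this is meant to be combined with --network=host rather than replace it.

```python
import json


def use_host_network_namespace(spec: dict) -> dict:
    """Drop the "network" entry from linux.namespaces so the container
    inherits the network namespace of the process that runs runsc."""
    linux = spec.setdefault("linux", {})
    linux["namespaces"] = [
        ns for ns in linux.get("namespaces", []) if ns.get("type") != "network"
    ]
    return spec


def patch_config_file(path: str) -> None:
    """Rewrite an OCI bundle's config.json in place (path is up to the caller)."""
    with open(path) as f:
        spec = json.load(f)
    with open(path, "w") as f:
        json.dump(use_host_network_namespace(spec), f, indent=2)
```

Applied to the namespaces list above, this leaves pid, ipc, uts, and mount in place and removes only network.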

FWIW, I tested this on my own TPU v5 GCE VM and it worked.

Here was my /dev/vfio directory:

$ stat /dev/vfio/*
  File: /dev/vfio/0
  Size: 0               Blocks: 0          IO Block: 4096   character special file
Device: 5h/5d   Inode: 334         Links: 1     Device type: ef,0
Access: (0666/crw-rw-rw-)  Uid: (    0/    root)   Gid: (    0/    root)
Access: 2024-09-18 21:11:48.884000071 +0000
Modify: 2024-09-18 21:11:48.884000071 +0000
Change: 2024-09-18 21:11:48.888000071 +0000
 Birth: 2024-09-18 21:11:48.884000071 +0000
  File: /dev/vfio/vfio
  Size: 0               Blocks: 0          IO Block: 4096   character special file
Device: 5h/5d   Inode: 141         Links: 1     Device type: a,c4
Access: (0666/crw-rw-rw-)  Uid: (    0/    root)   Gid: (    0/    root)
Access: 2024-09-18 21:11:48.832000070 +0000
Modify: 2024-09-18 21:11:48.832000070 +0000
Change: 2024-09-18 21:11:48.832000070 +0000
 Birth: 2024-09-18 21:11:45.235863372 +0000

Here was my config.json: https://gist.github.com/manninglucas/14de68aab7abaab02cf41553f900782e
My command was: sudo ./runsc --debug --debug-log=debug.txt --network=host --tpuproxy run bash
In the container, I ran:

$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ python
>>> import jax
>>> jax.device_count()

manninglucas added the area: tpu label Sep 18, 2024

manninglucas (contributor) commented:
As a note, GKE injects these environment variables into the container configs automatically. On GCE, you'll either have to add them to your OCI spec yourself (for example, by fetching them from the metadata server before starting the sandbox) or give the sandbox at least some host network access.
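The first option could look roughly like the sketch below: on the host, which can reach the metadata server, obtain the TPU variables (for example from the tpu-env attribute) and merge them into process.env of the bundle's config.json before running runsc. This is a hypothetical sketch, not a gVisor tool: the function names are invented, the env mapping is assumed to be fetched separately, and any values in the usage are placeholders.

```python
import json
from typing import Dict


def add_env_to_spec(spec: dict, env: Dict[str, str]) -> dict:
    """Merge env into spec["process"]["env"] (OCI "KEY=value" strings),
    replacing any existing entries with the same key."""
    process = spec.setdefault("process", {})
    kept = [e for e in process.get("env", []) if e.partition("=")[0] not in env]
    process["env"] = kept + [f"{k}={v}" for k, v in env.items()]
    return spec


def patch_config_env(path: str, env: Dict[str, str]) -> None:
    """Rewrite an OCI bundle's config.json in place with the extra variables."""
    with open(path) as f:
        spec = json.load(f)
    with open(path, "w") as f:
        json.dump(add_env_to_spec(spec, env), f, indent=2)
```

Existing entries such as PATH are preserved, so the container keeps its normal environment and only gains (or overrides) the TPU variables.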
