
WIP
Signed-off-by: Rui Qiao <[email protected]>
ruisearch42 committed Sep 18, 2024
1 parent 17f4f38 commit 4e752b2
Showing 9 changed files with 505 additions and 20 deletions.
294 changes: 286 additions & 8 deletions python/ray/dag/compiled_dag_node.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions python/ray/dag/constants.py
@@ -16,3 +16,8 @@

# Feature flag to turn on profiling.
RAY_ADAG_ENABLE_PROFILING = os.environ.get("RAY_ADAG_ENABLE_PROFILING", "0") == "1"

# Feature flag to turn on torch profiling.
RAY_ADAG_ENABLE_TORCH_PROFILING = (
os.environ.get("RAY_ADAG_ENABLE_TORCH_PROFILING", "0") == "1"
)
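Both profiling flags are read once at import time, so they must be set in the environment before ray.dag is imported. A minimal sketch (illustrative, not part of this change):

import os

# Enable aDAG torch profiling; this must happen before `ray.dag` is imported,
# because the flag is captured at module import time.
os.environ["RAY_ADAG_ENABLE_TORCH_PROFILING"] = "1"

import ray.dag  # noqa: E402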
7 changes: 7 additions & 0 deletions python/ray/dag/context.py
@@ -25,6 +25,8 @@
os.environ.get("RAY_DAG_max_inflight_executions", 10)
)

DEFAULT_OVERLAPPING_FACTOR = int(os.environ.get("RAY_DAG_overlapping_factor", 0))


@DeveloperAPI
@dataclass
@@ -58,6 +60,10 @@ class DAGContext:
executions is beyond the DAG capacity, the new execution would
be blocked in the first place; therefore, this limit is only
enforced when it is smaller than the DAG capacity.
max_inflight_executions: The maximum number of in-flight executions
that can be submitted before consuming the output.
overlapping_factor: Determines the degree to which the DAG execution
can overlap communication and computation.
"""

execution_timeout: int = DEFAULT_EXECUTION_TIMEOUT_S
@@ -66,6 +72,7 @@ class DAGContext:
asyncio_max_queue_size: int = DEFAULT_ASYNCIO_MAX_QUEUE_SIZE
max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS
max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS
overlapping_factor: int = DEFAULT_OVERLAPPING_FACTOR

@staticmethod
def get_current() -> "DAGContext":
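The new overlapping_factor defaults from the RAY_DAG_overlapping_factor environment variable, which is also captured at import time. A small sketch of how it could be inspected (the value shown is illustrative):

import os

# Must be set before ray.dag.context is imported, since the default is
# read once at module import time.
os.environ["RAY_DAG_overlapping_factor"] = "1"

from ray.dag.context import DAGContext

ctx = DAGContext.get_current()
print(ctx.overlapping_factor)  # -> 1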
7 changes: 7 additions & 0 deletions python/ray/dag/dag_node.py
@@ -167,6 +167,7 @@ def experimental_compile(
_asyncio_max_queue_size: Optional[int] = None,
_max_buffered_results: Optional[int] = None,
_max_inflight_executions: Optional[int] = None,
_overlapping_factor: Optional[int] = None,
) -> "ray.dag.CompiledDAG":
"""Compile an accelerated execution path for this DAG.
@@ -191,6 +192,11 @@ def experimental_compile(
are allowed to be sent to this DAG. Before submitting more requests,
the caller is responsible for calling ray.get to clear finished
in-flight requests.
_overlapping_factor: Controls the degree of overlapping computation and
communication in aDAG execution. If None, the default value is used.
If 0, no overlapping is allowed. If 1, communication and
computation are overlapped to a minimal degree. No other values
are supported at the moment.
Returns:
A compiled DAG.
@@ -224,6 +230,7 @@ def experimental_compile(
_asyncio_max_queue_size,
_max_buffered_results,
_max_inflight_executions,
_overlapping_factor,
)

def execute(
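For reference, a minimal usage sketch of the new _overlapping_factor argument; the Worker actor and trivial DAG below are illustrative and not part of this commit:

import ray
from ray.dag import InputNode


@ray.remote
class Worker:
    def double(self, x):
        return 2 * x


w = Worker.remote()
with InputNode() as inp:
    dag = w.double.bind(inp)

# 0 disables overlapping; 1 enables the minimal overlapping currently supported.
compiled_dag = dag.experimental_compile(_overlapping_factor=1)
ref = compiled_dag.execute(3)
assert ray.get(ref) == 6
compiled_dag.teardown()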
96 changes: 94 additions & 2 deletions python/ray/dag/dag_node_operation.py
@@ -1,6 +1,6 @@
from functools import total_ordering
from enum import Enum
from typing import Set, Tuple, List, Dict
from typing import Optional, Set, Tuple, List, Dict
import ray
import heapq
from collections import defaultdict
@@ -18,12 +18,22 @@ class _DAGNodeOperationType(Enum):
COMPUTE = "COMPUTE"
WRITE = "WRITE"

def __str__(self):
if self == _DAGNodeOperationType.READ:
return "R"
elif self == _DAGNodeOperationType.COMPUTE:
return "C"
elif self == _DAGNodeOperationType.WRITE:
return "W"
assert False, f"Unknown operation type: {self}"


class _DAGNodeOperation:
def __init__(
self,
exec_task_idx: int,
operation_type: _DAGNodeOperationType,
method_name: Optional[str] = None,
):
"""
Args:
@@ -32,12 +42,42 @@ def __init__(
as bind_index because there may be more tasks bound to an actor
than tasks that appear in the current compiled DAG.
operation_type: The type of operation to perform.
method_name: The name of the method that this operation originates
from. This is only for debugging purposes.
"""
self.exec_task_idx = exec_task_idx
self.type = operation_type
self.method_name = method_name

def next_operation(self):
if self.type == _DAGNodeOperationType.READ:
return _DAGNodeOperation(
self.exec_task_idx, _DAGNodeOperationType.COMPUTE, self.method_name
)
elif self.type == _DAGNodeOperationType.COMPUTE:
return _DAGNodeOperation(
self.exec_task_idx, _DAGNodeOperationType.WRITE, self.method_name
)
else:
raise ValueError(
"Cannot only get next operation for READ or COMPUTE type, "
f"{self.type} is provided."
)

def __repr__(self):
return f"([{self.exec_task_idx}] {self.method_name} {self.type})"

def __str__(self):
return f"([{self.exec_task_idx}] {self.method_name} {self.type})"

def __hash__(self):
return hash((self.exec_task_idx, self.type))

def __eq__(self, other):
# An operation is uniquely identified by its `exec_task_idx` and type.
# `method_name` is only for debugging purposes.
return self.exec_task_idx == other.exec_task_idx and self.type == other.type


@total_ordering
@@ -367,4 +407,56 @@ def _generate_actor_to_execution_schedule(
)
for _, candidates in actor_to_candidates.items():
assert len(candidates) == 0
for actor_handle, execution_schedule in actor_to_execution_schedule.items():
print(f"Actor {actor_handle._ray_actor_id} schedule: {execution_schedule}")
return actor_to_execution_schedule


def _optimize_execution_schedule(
actor_to_execution_schedule: Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]],
out_of_order_limit: int = 1,
):
"""
Optimize the execution schedule by overlapping computation and communication.
Args:
actor_to_execution_schedule: A dictionary that maps an actor handle to
the execution schedule which is a list of operations to be executed.
out_of_order_limit: The maximum number of out-of-order READ operations
allowed.
"""
# TODO: analyze the DAG and turn off overlap optimization when it is
# not supported (yet). For example, currently if a channel requires
# both NCCL and shared memory transport, overlap optimization cannot
# be applied.
if out_of_order_limit == 0:
return actor_to_execution_schedule

actor_to_optimized_schedule: Dict[
"ray.actor.ActorHandle", List[_DAGNodeOperation]
] = defaultdict(list)
for actor, execution_schedule in actor_to_execution_schedule.items():
read_queue = []
other_queue = []
optimized_schedule = []
for operation in execution_schedule:
if operation.type == _DAGNodeOperationType.READ:
read_queue.append(operation)
else:
other_queue.append(operation)
out_of_order_quota = out_of_order_limit + 1
while other_queue:
other_op = other_queue[0]
if read_queue:
if out_of_order_quota > 0:
optimized_schedule.append(read_queue.pop(0))
out_of_order_quota -= 1
else:
optimized_schedule.append(other_queue.pop(0))
if other_op.type == _DAGNodeOperationType.WRITE:
out_of_order_quota += 1
else:
optimized_schedule.append(other_queue.pop(0))
actor_to_optimized_schedule[actor] = optimized_schedule
print(f"Actor {actor._ray_actor_id} optimized schedule:", optimized_schedule)
return actor_to_optimized_schedule
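To illustrate the reordering that _optimize_execution_schedule performs, here is a standalone sketch of the same quota-based idea on plain strings (not the Ray implementation): READ operations may be issued ahead of the COMPUTE/WRITE stream, and each completed WRITE frees another slot for an early READ.

def interleave(reads, others, limit=1):
    # `reads` and `others` are one actor's operations in original order.
    schedule, quota = [], limit + 1
    while others:
        head = others[0]
        if reads and quota > 0:
            # Pull the next READ ahead of the pending COMPUTE/WRITE.
            schedule.append(reads.pop(0))
            quota -= 1
        else:
            schedule.append(others.pop(0))
            if head.endswith("W"):
                # A finished WRITE allows one more READ to run early.
                quota += 1
    schedule.extend(reads)  # safety net; normally empty at this point
    return schedule


# Two tasks on one actor: the READ (NCCL recv) of task 1 is issued while
# task 0 is still computing and writing.
print(interleave(["0R", "1R"], ["0C", "0W", "1C", "1W"]))
# -> ['0R', '1R', '0C', '0W', '1C', '1W']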
48 changes: 48 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -4,6 +4,7 @@
import re
import sys
from typing import List, Optional, Tuple
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.gpu_communicator import (
GPUCommunicator,
TorchTensorAllocator,
@@ -68,6 +69,7 @@ def send_with_tuple_args(self, args):
return torch.ones(shape, dtype=dtype, device=self.device) * value

def recv(self, tensor):
# print(f"{tensor=}")
# Check that tensor got loaded to the correct device.
assert tensor.device == self.device
return (tensor[0].item(), tensor.shape, tensor.dtype)
@@ -269,6 +271,52 @@ def test_torch_tensor_nccl(ray_start_regular):
# ray.get(receiver.ping.remote())


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_overlap(ray_start_regular):
if not USE_GPU:
pytest.skip("NCCL tests require GPUs")

assert (
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 2
), "This test requires at least 3 GPUs"

actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

sender1 = actor_cls.remote()
sender2 = actor_cls.remote()
receiver = actor_cls.remote()
print(f"{receiver=}")

shape = (100000,)
dtype = torch.float16

with InputNode() as inp:
branch1 = sender1.send.bind(shape, dtype, inp)

branch1 = branch1.with_type_hint(
TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
)
branch1 = receiver.recv.bind(branch1)

branch2 = sender2.send.bind(shape, dtype, inp)
branch2 = branch2.with_type_hint(
TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
)
branch2 = receiver.recv.bind(branch2)
dag = MultiOutputNode([branch1, branch2])

# Test normal execution.
compiled_dag = dag.experimental_compile(_overlapping_factor=1)

for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
# print(f"{result=}")
assert result == [(i, shape, dtype), (i, shape, dtype)]

compiled_dag.teardown()


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_dynamic(ray_start_regular):
if not USE_GPU:
13 changes: 12 additions & 1 deletion python/ray/experimental/channel/gpu_communicator.py
@@ -13,6 +13,17 @@
TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"]


@DeveloperAPI
class Event(ABC):
@abstractmethod
def record(self):
raise NotImplementedError

@abstractmethod
def synchronize(self):
raise NotImplementedError


@DeveloperAPI
class GPUCommunicator(ABC):
"""
@@ -88,7 +99,7 @@ def recv(
dtype: "torch.dtype",
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
) -> "torch.Tensor":
) -> Tuple["torch.Tensor", Optional[Event]]:
"""
Receive a torch.Tensor from a peer and synchronize.
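As an illustration of the Event interface above, a hypothetical subclass backed by torch.cuda.Event could look like the sketch below (this helper is not part of the commit; the stream argument to record mirrors the _NcclEvent added in nccl_group.py):

import torch

from ray.experimental.channel.gpu_communicator import Event


class _TorchCudaEvent(Event):
    def __init__(self):
        self._event = torch.cuda.Event()

    def record(self, stream: "torch.cuda.Stream"):
        # Mark the point on `stream` after which the received data is ready.
        self._event.record(stream)

    def synchronize(self):
        # Block the host until the recorded work has finished on the GPU.
        self._event.synchronize()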
43 changes: 38 additions & 5 deletions python/ray/experimental/channel/nccl_group.py
@@ -5,6 +5,7 @@
import ray
from ray.exceptions import RayChannelError
from ray.experimental.channel.gpu_communicator import (
Event,
GPUCommunicator,
TorchTensorAllocator,
)
@@ -20,6 +21,20 @@
logger = logging.getLogger(__name__)


class _NcclEvent(Event):
def __init__(self):
import cupy as cp

self._event = cp.cuda.Event()

def record(self, stream):
self._event.record(stream)

def synchronize(self):
# Block the host until the work recorded on this event has completed.
self._event.synchronize()


class _NcclGroup(GPUCommunicator):
"""
Represents an actor's NCCL communicator. This is the default NCCL communicator
@@ -94,6 +109,7 @@ def __init__(
assert rank is not None, "NCCL actor has no rank assigned"

import cupy as cp
import torch

from ray.air._internal import torch_utils

@@ -103,6 +119,19 @@
cuda_stream, device_id=device.index
)

send_stream = torch.cuda.Stream()
recv_stream = torch.cuda.Stream()
self._send_stream = cp.cuda.ExternalStream(
send_stream.cuda_stream, device_id=device.index
)
self._recv_stream = cp.cuda.ExternalStream(
recv_stream.cuda_stream, device_id=device.index
)
logger.debug(
f"Initialized NCCL streams, send={send_stream.cuda_stream}, "
f"recv={recv_stream.cuda_stream}"
)

self._closed = False

def initialize(self, rank: int) -> None:
@@ -163,7 +192,7 @@ def send(self, value: "torch.Tensor", peer_rank: int) -> None:
value.numel(),
self.nccl_util.get_nccl_tensor_dtype(value),
peer_rank,
self._cuda_stream.ptr,
self._send_stream.ptr,
)

def recv(
@@ -172,7 +201,7 @@ def recv(
dtype: "torch.dtype",
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
) -> "torch.Tensor":
) -> Tuple["torch.Tensor", "cp.cuda.Event"]:
"""
Receive a torch.Tensor from a peer and synchronize the current stream.
@@ -188,22 +217,26 @@
raise RayChannelError("NCCL group has been destroyed.")
assert allocator is not None, "NCCL group requires a tensor allocator"
buf = allocator(shape, dtype)
import cupy as cp

event = cp.cuda.Event()
self._comm.recv(
self.nccl_util.get_tensor_ptr(buf),
buf.numel(),
self.nccl_util.get_nccl_tensor_dtype(buf),
peer_rank,
self._cuda_stream.ptr,
self._recv_stream.ptr,
)
event.record(self._recv_stream)

# Buffer values are undefined if NCCL ops are aborted. The caller must
# synchronize on the returned event before using the buffer, and we check
# here that the channel is still open.
# TODO(swang): Avoid CUDA synchronization.
if self._closed:
raise RayChannelError("NCCL group has been destroyed.")
return buf
return buf, event

def destroy(self) -> None:
"""
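The recv change above returns the tensor together with an event recorded on the dedicated recv stream, so the consumer can defer synchronization. A standalone sketch of that record/wait pattern in plain torch (illustrative and GPU-only, not the Ray code path):

import torch

recv_stream = torch.cuda.Stream()
event = torch.cuda.Event()

buf = torch.empty(1 << 20, device="cuda")
with torch.cuda.stream(recv_stream):
    buf.normal_()  # stands in for the NCCL recv filling the buffer
    event.record(recv_stream)

# Other kernels can be launched on the default stream here; nothing has
# blocked on the recv yet.

# Only wait when the buffer is actually needed (event.synchronize() would
# block the host instead of just ordering the streams).
torch.cuda.current_stream().wait_event(event)
print(buf.sum())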