diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index fe7b693f7787b..3f0f453ddce42 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -45,6 +45,7 @@ _DAGOperationGraphNode, _build_dag_node_operation_graph, _generate_actor_to_execution_schedule, + _optimize_execution_schedule, ) from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -98,7 +99,6 @@ def do_exec_tasks( actor. This runs an infinite loop to execute each _DAGNodeOperation in the order specified by the schedule. It exits only if the actor dies or an exception is thrown. - Args: tasks: the executable tasks corresponding to the actor methods. schedule: A list of _DAGNodeOperation that should be executed in order. @@ -173,6 +173,108 @@ def do_profile_tasks( raise +@DeveloperAPI +def do_stream_tasks( + self, + tasks: List["ExecutableTask"], + schedule: List[_DAGNodeOperation], +) -> None: + """A generic actor method to begin executing the operations belonging to an + actor. This runs an infinite loop to execute each _DAGNodeOperation in the + order specified by the schedule. It exits only if the actor dies or an + exception is thrown. + + Args: + tasks: the executable tasks corresponding to the actor methods. + schedule: A list of _DAGNodeOperation that should be executed in order. + """ + try: + import cupy as cp + import torch + from ray.air._internal import torch_utils + + device = torch_utils.get_devices()[0] + exec_stream = cp.cuda.ExternalStream( + torch.cuda.Stream().cuda_stream, device_id=device.index + ) + + for task in tasks: + task.prepare() + + done = False + while True: + if done: + break + for operation in schedule: + done = tasks[operation.exec_task_idx].stream_operation( + self, operation, exec_stream + ) + if done: + break + except Exception: + logging.exception("Compiled DAG task exited with exception") + raise + + +@DeveloperAPI +def do_profile_stream_tasks( + self, + tasks: List["ExecutableTask"], + schedule: List[_DAGNodeOperation], +) -> None: + """A generic actor method to begin executing the operations belonging to an + actor. This runs an infinite loop to execute each _DAGNodeOperation in the + order specified by the schedule. It exits only if the actor dies or an + exception is thrown. + + Args: + tasks: the executable tasks corresponding to the actor methods. + schedule: A list of _DAGNodeOperation that should be executed in order. 
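+
+    When profiling is enabled via RAY_ADAG_ENABLE_TORCH_PROFILING, the collected
+    torch profiler trace is exported to adag-proc-<pid>.json when the loop exits.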
+ """ + try: + import cupy as cp + import torch + from ray.air._internal import torch_utils + + device = torch_utils.get_devices()[0] + exec_stream = cp.cuda.ExternalStream( + torch.cuda.Stream().cuda_stream, device_id=device.index + ) + + for task in tasks: + task.prepare() + + import os + + pid = os.getpid() + + done = False + import torch.profiler + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as profiler: + while True: + if done: + break + for operation in schedule: + done = tasks[operation.exec_task_idx].stream_operation( + self, operation, exec_stream + ) + if done: + break + profiler.export_chrome_trace(f"adag-proc-{pid}.json") + except Exception: + logging.exception("Compiled DAG task exited with exception") + raise + + @DeveloperAPI def do_cancel_executable_tasks(self, tasks: List["ExecutableTask"]) -> None: for task in tasks: @@ -355,6 +457,14 @@ def __init__( # and the result of a COMPUTE operation will be used by a WRITE operation. self._intermediate_buffer: Any = None + # Store the intermediate result for streamed execution. + # The key is the operation, and the value is the input to this operation, + # i.e., the result from the last operation. There are only two cases: + # if the operation is COMPUTE, the value is the result from the READ with + # the same exec_task_idx; if the operation is WRITE, the value is the result + # from the COMPUTE with the same exec_task_idx. + self._stream_buffer: Dict[_DAGNodeOperation, Any] = {} + def cancel(self): """ Close all the input channels and the output channel. The exact behavior @@ -425,7 +535,6 @@ def _compute(self, class_handle) -> bool: class_handle: An instance of the class to which the actor belongs. For example, the type of `class_handle` is if the actor belongs to the `class Worker` class. - Returns: True if system error occurs and exit the loop; otherwise, False. """ @@ -485,7 +594,6 @@ def exec_operation( class_handle: The handle of the class to which the actor belongs. op_type: The type of the operation. Possible types are READ, COMPUTE, and WRITE. - Returns: True if the next operation should not be executed; otherwise, False. """ @@ -496,6 +604,145 @@ def exec_operation( elif op_type == _DAGNodeOperationType.WRITE: return self._write() + def set_stream_buffer(self, op: _DAGNodeOperation, data: Any): + """ + Store the intermediate result of a READ or COMPUTE operation + in the stream buffer. + + Args: + op: The operation that generates the intermediate result. + data: The intermediate result of a READ or COMPUTE operation. + """ + self._stream_buffer[op.next_operation()] = data + + def reset_stream_buffer(self, op: _DAGNodeOperation) -> Any: + """ + Retrieve the intermediate result of a READ or COMPUTE operation, + and clear the entry from the buffer. + + Returns: + The intermediate result of a READ or COMPUTE operation. + """ + return self._stream_buffer.pop(op) + + def _stream_read(self, op: _DAGNodeOperation) -> bool: + """ + Stream read input data from upstream DAG nodes and cache the + intermediate result. + + Args: + op: The READ operation. + + Returns: + True if system error occurs and exit the loop; otherwise, False. + """ + exit = False + try: + input_data = self.input_reader.read() + self.set_stream_buffer(op, input_data) + except RayChannelError: + # Channel closed. Exit the loop. 
+            exit = True
+        return exit
+
+    def _stream_compute(self, op: _DAGNodeOperation, exec_stream, class_handle) -> bool:
+        """
+        Retrieve the intermediate result from the READ operation and perform the
+        computation. Then, cache the new intermediate result.
+
+        Args:
+            op: The COMPUTE operation.
+            exec_stream: The CUDA stream on which to execute the compute operation.
+            class_handle: An instance of the class to which the actor belongs. For
+                example, the type of `class_handle` is `Worker` if the
+                actor belongs to the `class Worker` class.
+
+        Returns:
+            True if system error occurs and exit the loop; otherwise, False.
+        """
+        input_data = self.reset_stream_buffer(op)
+        method = getattr(class_handle, self.method_name)
+        try:
+            _process_return_vals(input_data, return_single_output=False)
+        except Exception as exc:
+            # Previous task raised an application-level exception.
+            # Propagate it and skip the actual task. We don't need to wrap the
+            # exception in a RayTaskError here because it has already been wrapped
+            # by the previous task.
+            self.set_stream_buffer(op, exc)
+            return False
+
+        # Values read from NCCL channels are (tensor, event) tuples; wait on the
+        # event so that the receive has completed before the tensor is used.
+        channel_results = []
+        for entry in input_data:
+            if isinstance(entry, tuple):
+                channel_result, event = entry
+                if event:
+                    event.synchronize()
+            else:
+                channel_result = entry
+            channel_results.append(channel_result)
+
+        resolved_inputs = []
+        for task_input in self.task_inputs:
+            resolved_inputs.append(task_input.resolve(channel_results))
+
+        import cupy as cp
+
+        exec_event = cp.cuda.Event()
+        with exec_stream:
+            try:
+                output_val = method(*resolved_inputs, **self.resolved_kwargs)
+            except Exception as exc:
+                output_val = _wrap_exception(exc)
+            # Record an event on the execution stream so that the downstream
+            # WRITE operation can wait for this computation to finish.
+            exec_event.record()
+
+        self.set_stream_buffer(op, (output_val, exec_event))
+        return False
+
+    def _stream_write(self, op: _DAGNodeOperation) -> bool:
+        """
+        Retrieve the intermediate result of the corresponding COMPUTE operation
+        and write it to the downstream DAG nodes. The COMPUTE operation with the
+        same exec_task_idx must have already executed so that the intermediate
+        result is available.
+
+        Args:
+            op: The WRITE operation.
+
+        Returns:
+            True if system error occurs and exit the loop; otherwise, False.
+        """
+        buffered = self.reset_stream_buffer(op)
+        exit = False
+        if isinstance(buffered, tuple):
+            output_val, exec_event = buffered
+            # Wait for the COMPUTE operation on the execution stream to finish
+            # before writing its result.
+            exec_event.synchronize()
+        else:
+            # The COMPUTE operation buffered an exception rather than a
+            # (value, event) pair; propagate it as-is.
+            output_val = buffered
+        try:
+            self.output_writer.write(output_val)
+        except RayChannelError:
+            # Channel closed. Exit the loop.
+            exit = True
+        return exit
+
+    def stream_operation(
+        self, class_handle, op: _DAGNodeOperation, exec_stream
+    ) -> bool:
+        """
+        Execute the given operation for this task. An ExecutableTask corresponds
+        to a DAGNode and consists of three operations: READ, COMPUTE, and WRITE,
+        which should be executed in order to ensure that each operation can read
+        the correct intermediate result.
+
+        Args:
+            class_handle: The handle of the class to which the actor belongs.
+            op: The operation to execute. Possible types are READ,
+                COMPUTE, and WRITE.
+            exec_stream: The CUDA stream on which COMPUTE operations are executed.
+
+        Returns:
+            True if the next operation should not be executed; otherwise, False.
+ """ + op_type: _DAGNodeOperationType = op.type + if op_type == _DAGNodeOperationType.READ: + return self._stream_read(op) + elif op_type == _DAGNodeOperationType.COMPUTE: + return self._stream_compute(op, exec_stream, class_handle) + elif op_type == _DAGNodeOperationType.WRITE: + return self._stream_write(op) + @dataclass class _ExecutableTaskRecord: @@ -548,6 +795,7 @@ def __init__( asyncio_max_queue_size: Optional[int] = None, max_buffered_results: Optional[int] = None, max_inflight_executions: Optional[int] = None, + overlapping_factor: Optional[float] = None, ): """ Args: @@ -580,6 +828,11 @@ def __init__( are allowed to be sent to this DAG. Before submitting more requests, the caller is responsible for calling ray.get to get the result, otherwise, RayAdagCapacityExceeded is raised. + overlapping_factor: Controls the degree of overlapping computation and + communication in aDAG execution. If None, the default value is used. + If 0, no overlapping is allowed. If 1, the communication and + computation are overlapped with the minimal degree. No other values + are supported at the moment. Returns: Channel: A wrapper around ray.ObjectRef. @@ -605,6 +858,9 @@ def __init__( self._buffer_size_bytes: Optional[int] = buffer_size_bytes if self._buffer_size_bytes is None: self._buffer_size_bytes = ctx.buffer_size_bytes + self._overlapping_factor: Optional[float] = overlapping_factor + if self._overlapping_factor is None: + self._overlapping_factor = ctx.overlapping_factor self._default_type_hint: ChannelOutputType = SharedMemoryType( self._buffer_size_bytes, @@ -1335,10 +1591,18 @@ def _get_or_compile( self.actor_to_executable_tasks[actor_handle] = executable_tasks # Build an execution schedule for each actor - from ray.dag.constants import RAY_ADAG_ENABLE_PROFILING + from ray.dag.constants import ( + RAY_ADAG_ENABLE_PROFILING, + RAY_ADAG_ENABLE_TORCH_PROFILING, + ) if RAY_ADAG_ENABLE_PROFILING: exec_task_func = do_profile_tasks + elif self._overlapping_factor: + if RAY_ADAG_ENABLE_TORCH_PROFILING: + exec_task_func = do_profile_stream_tasks + else: + exec_task_func = do_stream_tasks else: exec_task_func = do_exec_tasks @@ -1432,23 +1696,30 @@ def _generate_dag_operation_graph_node( # and WRITE. Each _DAGOperationGraphNode has a _DAGNodeOperation. 
task_index = exec_task.task_idx dag_node = self.idx_to_task[task_index].dag_node + method_name = exec_task.method_name actor_handle = dag_node._get_actor_handle() requires_nccl = dag_node.type_hint.requires_nccl() read_node = _DAGOperationGraphNode( - _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.READ), + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.READ, method_name + ), task_index, actor_handle, requires_nccl, ) compute_node = _DAGOperationGraphNode( - _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.COMPUTE), + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.COMPUTE, method_name + ), task_index, actor_handle, requires_nccl, ) write_node = _DAGOperationGraphNode( - _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.WRITE), + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.WRITE, method_name + ), task_index, actor_handle, requires_nccl, @@ -1494,7 +1765,12 @@ def _build_execution_schedule( ) # Step 2: Generate an execution schedule for each actor using topological sort actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) - return actor_to_execution_schedule + + # Step 3: Optimize the execution schedule based on overlapping factor + actor_to_optimized_schedule = _optimize_execution_schedule( + actor_to_execution_schedule, self._overlapping_factor + ) + return actor_to_optimized_schedule def _detect_deadlock(self) -> bool: """ @@ -2073,6 +2349,7 @@ def build_compiled_dag_from_ray_dag( asyncio_max_queue_size: Optional[int] = None, max_buffered_results: Optional[int] = None, max_inflight_executions: Optional[int] = None, + overlapping_factor: Optional[int] = None, ) -> "CompiledDAG": compiled_dag = CompiledDAG( execution_timeout, @@ -2081,6 +2358,7 @@ def build_compiled_dag_from_ray_dag( asyncio_max_queue_size, max_buffered_results, max_inflight_executions, + overlapping_factor, ) def _build_compiled_dag(node): diff --git a/python/ray/dag/constants.py b/python/ray/dag/constants.py index f1077c7b51040..268c3a1253b88 100644 --- a/python/ray/dag/constants.py +++ b/python/ray/dag/constants.py @@ -16,3 +16,8 @@ # Feature flag to turn on profiling. RAY_ADAG_ENABLE_PROFILING = os.environ.get("RAY_ADAG_ENABLE_PROFILING", "0") == "1" + +# Feature flag to turn on torch profiling. +RAY_ADAG_ENABLE_TORCH_PROFILING = ( + os.environ.get("RAY_ADAG_ENABLE_TORCH_PROFILING", "0") == "1" +) diff --git a/python/ray/dag/context.py b/python/ray/dag/context.py index 03eb8915d13b3..4bd0261a7c1eb 100644 --- a/python/ray/dag/context.py +++ b/python/ray/dag/context.py @@ -25,6 +25,8 @@ os.environ.get("RAY_DAG_max_inflight_executions", 10) ) +DEFAULT_OVERLAPPING_FACTOR = int(os.environ.get("RAY_DAG_overlapping_factor", 0)) + @DeveloperAPI @dataclass @@ -58,6 +60,10 @@ class DAGContext: executions is beyond the DAG capacity, the new execution would be blocked in the first place; therefore, this limit is only enforced when it is smaller than the DAG capacity. + max_inflight_executions: The maximum number of in-flight executions + that can be submitted before consuming the output. + overlapping_factor: Determines the degree to which the DAG execution + can overlap communication and computation. 
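+            A value of 0 (the default) disables overlapping; a value of 1
+            overlaps communication and computation to a minimal degree. No
+            other values are currently supported.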
""" execution_timeout: int = DEFAULT_EXECUTION_TIMEOUT_S @@ -66,6 +72,7 @@ class DAGContext: asyncio_max_queue_size: int = DEFAULT_ASYNCIO_MAX_QUEUE_SIZE max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS + overlapping_factor: int = DEFAULT_OVERLAPPING_FACTOR @staticmethod def get_current() -> "DAGContext": diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 031709382edaa..9ec93266ed784 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -167,6 +167,7 @@ def experimental_compile( _asyncio_max_queue_size: Optional[int] = None, _max_buffered_results: Optional[int] = None, _max_inflight_executions: Optional[int] = None, + _overlapping_factor: Optional[int] = None, ) -> "ray.dag.CompiledDAG": """Compile an accelerated execution path for this DAG. @@ -191,6 +192,11 @@ def experimental_compile( are allowed to be sent to this DAG. Before submitting more requests, the caller is responsible for calling ray.get to clear finished in-flight requests. + _overlapping_factor: Controls the degree of overlapping computation and + communication in aDAG execution. If None, the default value is used. + If 0, no overlapping is allowed. If 1, the communication and + computation are overlapped with the minimal degree. No other values + are supported at the moment. Returns: A compiled DAG. @@ -224,6 +230,7 @@ def experimental_compile( _asyncio_max_queue_size, _max_buffered_results, _max_inflight_executions, + _overlapping_factor, ) def execute( diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index dcd55ae2c702c..bf9019b1bd45c 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -1,6 +1,6 @@ from functools import total_ordering from enum import Enum -from typing import Set, Tuple, List, Dict +from typing import Optional, Set, Tuple, List, Dict import ray import heapq from collections import defaultdict @@ -18,12 +18,22 @@ class _DAGNodeOperationType(Enum): COMPUTE = "COMPUTE" WRITE = "WRITE" + def __str__(self): + if self == _DAGNodeOperationType.READ: + return "R" + elif self == _DAGNodeOperationType.COMPUTE: + return "C" + elif self == _DAGNodeOperationType.WRITE: + return "W" + assert False, f"Unknown operation type: {self}" + class _DAGNodeOperation: def __init__( self, exec_task_idx: int, operation_type: _DAGNodeOperationType, + method_name: Optional[str] = None, ): """ Args: @@ -32,12 +42,42 @@ def __init__( as bind_index because there may be more tasks bound to an actor than tasks that appear in the current compiled DAG. operation_type: The type of operation to perform. + method_name: The name of the method that this operation originates + from. This is only for debugging purposes. """ self.exec_task_idx = exec_task_idx self.type = operation_type + self.method_name = method_name + + def next_operation(self): + if self.type == _DAGNodeOperationType.READ: + return _DAGNodeOperation( + self.exec_task_idx, _DAGNodeOperationType.COMPUTE, self.method_name + ) + elif self.type == _DAGNodeOperationType.COMPUTE: + return _DAGNodeOperation( + self.exec_task_idx, _DAGNodeOperationType.WRITE, self.method_name + ) + else: + raise ValueError( + "Cannot only get next operation for READ or COMPUTE type, " + f"{self.type} is provided." 
+ ) def __repr__(self): - return f"(Task idx: {self.exec_task_idx}, Type: {self.type})" + return f"([{self.exec_task_idx}] {self.method_name} {self.type})" + # return f"(Task idx: {self.exec_task_idx}, Type: {self.type})" + + def __str__(self): + return f"([{self.exec_task_idx}] {self.method_name} {self.type})" + + def __hash__(self): + return hash((self.exec_task_idx, self.type)) + + def __eq__(self, other): + # An operation is uniquely identified by its `exec_task_idx` and type. + # `func_name` is only for debugging purposes. + return self.exec_task_idx == other.exec_task_idx and self.type == other.type @total_ordering @@ -367,4 +407,56 @@ def _generate_actor_to_execution_schedule( ) for _, candidates in actor_to_candidates.items(): assert len(candidates) == 0 + for actor_handle, execution_schedule in actor_to_execution_schedule.items(): + print(f"Actor {actor_handle._ray_actor_id} schedule: {execution_schedule}") return actor_to_execution_schedule + + +def _optimize_execution_schedule( + actor_to_execution_schedule: Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]], + out_of_order_limit: int = 1, +): + """ + Optimize the execution schedule by overlapping computation and communication. + + Args: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the execution schedule which is a list of operations to be executed. + out_of_order_limit: The maximum number of out-of-order `receive` operations + allowed. + """ + # TODO: analyze the DAG and turn off overlap optimization when it is + # not supported (yet). For example, currently if a channel requires + # both NCCL and shared memory transport, overlap optimization cannot + # be applied. + if out_of_order_limit == 0: + return actor_to_execution_schedule + + actor_to_optimized_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGNodeOperation] + ] = defaultdict(list) + for actor, execution_schedule in actor_to_execution_schedule.items(): + read_queue = [] + other_queue = [] + optimized_schedule = [] + for operation in execution_schedule: + if operation.type == _DAGNodeOperationType.READ: + read_queue.append(operation) + else: + other_queue.append(operation) + out_of_order_quota = out_of_order_limit + 1 + while other_queue: + other_op = other_queue[0] + if read_queue: + if out_of_order_quota > 0: + optimized_schedule.append(read_queue.pop(0)) + out_of_order_quota -= 1 + else: + optimized_schedule.append(other_queue.pop(0)) + if other_op.type == _DAGNodeOperationType.WRITE: + out_of_order_quota += 1 + else: + optimized_schedule.append(other_queue.pop(0)) + actor_to_optimized_schedule[actor] = optimized_schedule + print(f"Actor {actor._ray_actor_id} optimized schedule:", optimized_schedule) + return actor_to_optimized_schedule diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index edb089440d8da..c4314e4bd6afe 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -4,6 +4,7 @@ import re import sys from typing import List, Optional, Tuple +from ray.dag.output_node import MultiOutputNode from ray.experimental.channel.gpu_communicator import ( GPUCommunicator, TorchTensorAllocator, @@ -68,6 +69,7 @@ def send_with_tuple_args(self, args): return torch.ones(shape, dtype=dtype, device=self.device) * value def recv(self, tensor): + # print(f"{tensor=}") # Check that tensor got loaded to the correct device. 
assert tensor.device == self.device return (tensor[0].item(), tensor.shape, tensor.dtype) @@ -269,6 +271,52 @@ def test_torch_tensor_nccl(ray_start_regular): # ray.get(receiver.ping.remote()) +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_nccl_overlap(ray_start_regular): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 2 + ), "This test requires at least 3 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + sender1 = actor_cls.remote() + sender2 = actor_cls.remote() + receiver = actor_cls.remote() + print(f"{receiver=}") + + shape = (100000,) + dtype = torch.float16 + + with InputNode() as inp: + branch1 = sender1.send.bind(shape, dtype, inp) + + branch1 = branch1.with_type_hint( + TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + ) + branch1 = receiver.recv.bind(branch1) + + branch2 = sender2.send.bind(shape, dtype, inp) + branch2 = branch2.with_type_hint( + TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + ) + branch2 = receiver.recv.bind(branch2) + dag = MultiOutputNode([branch1, branch2]) + + # Test normal execution. + compiled_dag = dag.experimental_compile(_overlapping_factor=1) + + for i in range(3): + ref = compiled_dag.execute(i) + result = ray.get(ref) + # print(f"{result=}") + assert result == [(i, shape, dtype), (i, shape, dtype)] + + compiled_dag.teardown() + + @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_nccl_dynamic(ray_start_regular): if not USE_GPU: diff --git a/python/ray/experimental/channel/gpu_communicator.py b/python/ray/experimental/channel/gpu_communicator.py index e6bc2fccdb2db..321fca118e507 100644 --- a/python/ray/experimental/channel/gpu_communicator.py +++ b/python/ray/experimental/channel/gpu_communicator.py @@ -13,6 +13,17 @@ TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"] +@DeveloperAPI +class Event(ABC): + @abstractmethod + def record(): + raise NotImplementedError + + @abstractmethod + def synchronize(): + raise NotImplementedError + + @DeveloperAPI class GPUCommunicator(ABC): """ @@ -88,7 +99,7 @@ def recv( dtype: "torch.dtype", peer_rank: int, allocator: Optional[TorchTensorAllocator] = None, - ) -> "torch.Tensor": + ) -> Tuple["torch.Tensor", Optional[Event]]: """ Receive a torch.Tensor from a peer and synchronize. diff --git a/python/ray/experimental/channel/nccl_group.py b/python/ray/experimental/channel/nccl_group.py index dcdfef10f163d..c58f90fba0003 100644 --- a/python/ray/experimental/channel/nccl_group.py +++ b/python/ray/experimental/channel/nccl_group.py @@ -5,6 +5,7 @@ import ray from ray.exceptions import RayChannelError from ray.experimental.channel.gpu_communicator import ( + Event, GPUCommunicator, TorchTensorAllocator, ) @@ -20,6 +21,20 @@ logger = logging.getLogger(__name__) +class _NcclEvent(Event): + def __init__(self): + import cupy as cp + + self._event = cp.cuda.Event() + + def record(self, stream): + self._event.record(stream) + + def synchronize(self): + # TODO + pass + + class _NcclGroup(GPUCommunicator): """ Represents an actor's NCCL communicator. 
This is the default NCCL communicator @@ -94,6 +109,7 @@ def __init__( assert rank is not None, "NCCL actor has no rank assigned" import cupy as cp + import torch from ray.air._internal import torch_utils @@ -103,6 +119,19 @@ def __init__( cuda_stream, device_id=device.index ) + send_stream = torch.cuda.Stream() + recv_stream = torch.cuda.Stream() + self._send_stream = cp.cuda.ExternalStream( + send_stream.cuda_stream, device_id=device.index + ) + self._recv_stream = cp.cuda.ExternalStream( + recv_stream.cuda_stream, device_id=device.index + ) + print( + f"Inited streams, send={send_stream.cuda_stream}, " + f"recv={recv_stream.cuda_stream}" + ) + self._closed = False def initialize(self, rank: int) -> None: @@ -163,7 +192,7 @@ def send(self, value: "torch.Tensor", peer_rank: int) -> None: value.numel(), self.nccl_util.get_nccl_tensor_dtype(value), peer_rank, - self._cuda_stream.ptr, + self._send_stream.ptr, ) def recv( @@ -172,7 +201,7 @@ def recv( dtype: "torch.dtype", peer_rank: int, allocator=Optional[TorchTensorAllocator], - ) -> "torch.Tensor": + ) -> Tuple["torch.Tensor", "cp.cuda.Event"]: """ Receive a torch.Tensor from a peer and synchronize the current stream. @@ -188,22 +217,26 @@ def recv( raise RayChannelError("NCCL group has been destroyed.") assert allocator is not None, "NCCL group requires a tensor allocator" buf = allocator(shape, dtype) + import cupy as cp + + event = cp.cuda.Event() self._comm.recv( self.nccl_util.get_tensor_ptr(buf), buf.numel(), self.nccl_util.get_nccl_tensor_dtype(buf), peer_rank, - self._cuda_stream.ptr, + self._recv_stream.ptr, ) + event.record(self._recv_stream) # Buffer values are undefined if NCCL ops are aborted. Therefore, we # need to synchronize here and check that the channel is still open to # ensure that the receive buffer is valid. # TODO(swang): Avoid CUDA synchronization. - self._cuda_stream.synchronize() + # self._cuda_stream.synchronize() if self._closed: raise RayChannelError("NCCL group has been destroyed.") - return buf + return buf, event def destroy(self) -> None: """ diff --git a/python/ray/experimental/channel/torch_tensor_nccl_channel.py b/python/ray/experimental/channel/torch_tensor_nccl_channel.py index c97f9913838e8..5eacc0d7e6090 100644 --- a/python/ray/experimental/channel/torch_tensor_nccl_channel.py +++ b/python/ray/experimental/channel/torch_tensor_nccl_channel.py @@ -4,6 +4,8 @@ from types import ModuleType from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +import cupy as cp + import ray import ray.util.serialization from ray.experimental.channel import ChannelContext @@ -383,7 +385,10 @@ def write( def read( self, timeout: Optional[float] = None - ) -> Union["torch.Tensor", List["torch.Tensor"]]: + ) -> Union[ + Tuple["torch.Tensor", "cp.cuda.Event"], + List[Tuple["torch.Tensor", "cp.cuda.Event"]], + ]: if self._meta_channel is not None: meta = self._meta_channel.read() else: @@ -397,7 +402,7 @@ def read( self._torch_tensor_allocator, ) - bufs: List["torch.Tensor"] = [] + bufs: List[Tuple["torch.Tensor", "cp.cuda.Event"]] = [] for typ in meta: buf = self._nccl_group.recv( typ._shape, @@ -406,8 +411,6 @@ def read( self._torch_tensor_allocator, ) bufs.append(buf) - # TODO: Sync CUDA stream after receiving all tensors, instead of after - # each tensor. 
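+        # Synchronization is deferred: each recv() above returns a
+        # (tensor, event) tuple, and the consumer is expected to synchronize on
+        # the event before using the tensor (see ExecutableTask._stream_compute).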
         return bufs
 
     def close(self) -> None:
@@ -444,6 +447,7 @@ def _do_init_nccl_group(
     ), "Actors participating in NCCL group must have at least one GPU assigned"
 
     ctx = ChannelContext.get_current()
+    logger.debug(f"_do_init_nccl_group() with {custom_nccl_group=}")
    if custom_nccl_group is not None:
         custom_nccl_group.initialize(rank)
         ctx.nccl_groups[group_id] = custom_nccl_group
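
A minimal usage sketch of the new _overlapping_factor flag, modeled on
test_torch_tensor_nccl_overlap above. The Worker actor below mirrors
TorchTensorWorker from that test, and the shape/dtype come from it; the single
sender/receiver topology and the 2-GPU setup are illustrative assumptions only.

    import ray
    import torch
    from ray.dag import InputNode
    from ray.experimental.channel.torch_tensor_type import TorchTensorType

    @ray.remote(num_cpus=0, num_gpus=1)
    class Worker:
        def __init__(self):
            # Each actor is assigned one GPU, so cuda:0 is its visible device.
            self.device = torch.device("cuda:0")

        def send(self, shape, dtype, value: int):
            return torch.ones(shape, dtype=dtype, device=self.device) * value

        def recv(self, tensor):
            return (tensor[0].item(), tensor.shape, tensor.dtype)

    sender = Worker.remote()
    receiver = Worker.remote()

    shape, dtype = (100000,), torch.float16
    with InputNode() as inp:
        t = sender.send.bind(shape, dtype, inp)
        t = t.with_type_hint(TorchTensorType(shape, dtype, transport="nccl"))
        dag = receiver.recv.bind(t)

    # _overlapping_factor=1 selects do_stream_tasks, so NCCL receives and
    # compute run on separate CUDA streams and are ordered via CUDA events.
    compiled_dag = dag.experimental_compile(_overlapping_factor=1)
    ref = compiled_dag.execute(1)
    assert ray.get(ref) == (1, shape, dtype)
    compiled_dag.teardown()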