Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
Signed-off-by: Rui Qiao <[email protected]>
  • Loading branch information
ruisearch42 committed Sep 17, 2024
1 parent f40313b commit ccb561c
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 43 deletions.
101 changes: 70 additions & 31 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from collections import defaultdict, deque
import copy
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Tuple, Union, Optional, Set
import logging
Expand Down Expand Up @@ -107,6 +108,16 @@ def do_exec_tasks(
schedule: A list of _DAGNodeOperation that should be executed in order.
"""
try:
import cupy as cp
import torch
from ray.air._internal import torch_utils

device = torch_utils.get_devices()[0]
exec_stream = torch.cuda.Stream()
self.exec_stream = cp.cuda.ExternalStream(
exec_stream.cuda_stream, device_id=device.index
)

for task in tasks:
task.prepare()

Expand All @@ -116,9 +127,7 @@ def do_exec_tasks(
break
for operation in schedule:
print("SANG-TODO operation: ", operation)
done = tasks[operation.exec_task_idx].exec_operation(
self, operation.type
)
done = tasks[operation.exec_task_idx].exec_operation(self, operation)
if done:
break
except Exception:
Expand Down Expand Up @@ -404,7 +413,7 @@ def __init__(
# Store the intermediate result of a READ or COMPUTE operation.
# The result of a READ operation will be used by a COMPUTE operation,
# and the result of a COMPUTE operation will be used by a WRITE operation.
self._intermediate_buffer: Any = None
self._intermediate_buffer: Dict[_DAGNodeOperation, Any] = {}

def cancel(self):
"""
Expand All @@ -426,46 +435,53 @@ def prepare(self):
self.input_reader.start()
self.output_writer.start()

def set_intermediate_buffer(self, data: Any):
def set_intermediate_buffer(self, op: _DAGNodeOperation, data: Any):
"""
Store the intermediate result of a READ or COMPUTE operation.
Args:
data: The intermediate result of a READ or COMPUTE operation.
"""
assert self._intermediate_buffer is None
self._intermediate_buffer = data
if op.type == _DAGNodeOperationType.READ:
key = copy.deepcopy(op)
key.type = _DAGNodeOperationType.COMPUTE
elif op.type == _DAGNodeOperationType.COMPUTE:
key = copy.deepcopy(op)
key.type = _DAGNodeOperationType.WRITE
else:
assert False, f"Invalid operation type: {op.type}"
self._intermediate_buffer[key] = data
# logger.info(f"Setting {key} to {data}", stack_info=True)

def reset_intermediate_buffer(self) -> Any:
def reset_intermediate_buffer(self, op: _DAGNodeOperation) -> Any:
"""
Retrieve the intermediate result of a READ or COMPUTE operation,
and reset the intermediate buffer to None.
Returns:
The intermediate result of a READ or COMPUTE operation.
"""
data = self._intermediate_buffer
self._intermediate_buffer = None
return data
logger.info(f"{self._intermediate_buffer=}")
return self._intermediate_buffer.pop(op)

def _read(self) -> bool:
def _read(self, op: _DAGNodeOperation) -> bool:
"""
Read input data from upstream DAG nodes and cache the intermediate result.
Returns:
True if system error occurs and exit the loop; otherwise, False.
"""
assert self._intermediate_buffer is None
# assert self._intermediate_buffer is None
exit = False
try:
input_data = self.input_reader.read()
self.set_intermediate_buffer(input_data)
self.set_intermediate_buffer(op, input_data)
except RayChannelError:
# Channel closed. Exit the loop.
exit = True
return exit

def _compute(self, class_handle) -> bool:
def _compute(self, op: _DAGNodeOperation, class_handle) -> bool:
"""
Retrieve the intermediate result from the READ operation and perform the
computation. Then, cache the new intermediate result. The caller must ensure
Expand All @@ -480,7 +496,7 @@ def _compute(self, class_handle) -> bool:
Returns:
True if system error occurs and exit the loop; otherwise, False.
"""
input_data = self.reset_intermediate_buffer()
input_data = self.reset_intermediate_buffer(op)
method = getattr(class_handle, self.method_name)
try:
_process_return_vals(input_data, return_single_output=False)
Expand All @@ -492,18 +508,36 @@ def _compute(self, class_handle) -> bool:
self.set_intermediate_buffer(exc)
return False

channel_results = []
for entry in input_data:
if isinstance(entry, tuple):
channel_result, event = entry
if event:
event.synchronize()
else:
channel_result = entry
channel_results.append(channel_result)

resolved_inputs = []
for task_input in self.task_inputs:
resolved_inputs.append(task_input.resolve(input_data))
resolved_inputs.append(task_input.resolve(channel_results))

import cupy as cp

exec_event = cp.cuda.Event()
# TODO: run on exec_stream
# with self.exec_stream:
logger.info(f"{resolved_inputs=}")
try:
output_val = method(*resolved_inputs, **self.resolved_kwargs)
except Exception as exc:
output_val = _wrap_exception(exc)
self.set_intermediate_buffer(output_val)
exec_event.record()

self.set_intermediate_buffer(op, (output_val, exec_event))
return False

def _write(self) -> bool:
def _write(self, op: _DAGNodeOperation) -> bool:
"""
Retrieve the intermediate result from the COMPUTE operation and write to its
downstream DAG nodes. The caller must ensure that the last operation executed
Expand All @@ -512,20 +546,17 @@ def _write(self) -> bool:
Returns:
True if system error occurs and exit the loop; otherwise, False.
"""
output_val = self.reset_intermediate_buffer()
output_val, exec_event = self.reset_intermediate_buffer(op)
exit = False
exec_event.synchronize()
try:
self.output_writer.write(output_val)
except RayChannelError:
# Channel closed. Exit the loop.
exit = True
return exit

def exec_operation(
self,
class_handle,
op_type: _DAGNodeOperationType,
) -> bool:
def exec_operation(self, class_handle, op: _DAGNodeOperation) -> bool:
"""
An ExecutableTask corresponds to a DAGNode. It consists of three
operations: READ, COMPUTE, and WRITE, which should be executed in
Expand All @@ -540,12 +571,13 @@ def exec_operation(
Returns:
True if the next operation should not be executed; otherwise, False.
"""
op_type: _DAGNodeOperationType = op.type
if op_type == _DAGNodeOperationType.READ:
return self._read()
return self._read(op)
elif op_type == _DAGNodeOperationType.COMPUTE:
return self._compute(class_handle)
return self._compute(op, class_handle)
elif op_type == _DAGNodeOperationType.WRITE:
return self._write()
return self._write(op)


@dataclass
Expand Down Expand Up @@ -1527,23 +1559,30 @@ def _generate_dag_operation_graph_node(
# and WRITE. Each _DAGOperationGraphNode has a _DAGNodeOperation.
task_index = exec_task.task_idx
dag_node = self.idx_to_task[task_index].dag_node
method_name = exec_task.method_name
actor_handle = dag_node._get_actor_handle()
requires_nccl = dag_node.type_hint.requires_nccl()

read_node = _DAGOperationGraphNode(
_DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.READ),
_DAGNodeOperation(
exec_task_idx, _DAGNodeOperationType.READ, method_name
),
task_index,
actor_handle,
requires_nccl,
)
compute_node = _DAGOperationGraphNode(
_DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.COMPUTE),
_DAGNodeOperation(
exec_task_idx, _DAGNodeOperationType.COMPUTE, method_name
),
task_index,
actor_handle,
requires_nccl,
)
write_node = _DAGOperationGraphNode(
_DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.WRITE),
_DAGNodeOperation(
exec_task_idx, _DAGNodeOperationType.WRITE, method_name
),
task_index,
actor_handle,
requires_nccl,
Expand Down
78 changes: 75 additions & 3 deletions python/ray/dag/dag_node_operation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import total_ordering
from enum import Enum
from typing import Set, Tuple, List, Dict
from typing import Optional, Set, Tuple, List, Dict
import ray
import heapq
from collections import defaultdict
Expand All @@ -18,12 +18,22 @@ class _DAGNodeOperationType(Enum):
COMPUTE = "COMPUTE"
WRITE = "WRITE"

def __str__(self):
    """Return the one-letter abbreviation of this operation type.

    Returns:
        "R" for READ, "C" for COMPUTE, or "W" for WRITE.

    Raises:
        AssertionError: If this member is none of the three types above.
    """
    if self == _DAGNodeOperationType.READ:
        return "R"
    elif self == _DAGNodeOperationType.COMPUTE:
        return "C"
    elif self == _DAGNodeOperationType.WRITE:
        return "W"
    # Format with repr(): formatting `self` with str() would call this very
    # method again and recurse instead of producing the error message.
    # Raise explicitly so the check is not stripped under `python -O`.
    raise AssertionError(f"Unknown operation type: {self!r}")


class _DAGNodeOperation:
def __init__(
self,
exec_task_idx: int,
operation_type: _DAGNodeOperationType,
func_name: Optional[str] = None,
):
"""
Args:
Expand All @@ -35,9 +45,24 @@ def __init__(
"""
self.exec_task_idx = exec_task_idx
self.type = operation_type
self.func_name = func_name

def __repr__(self):
return f"(Task idx: {self.exec_task_idx}, Type: {self.type})"
return f"([{self.exec_task_idx}] {self.func_name} {self.type})"
# return f"(Task idx: {self.exec_task_idx}, Type: {self.type})"

def __str__(self):
    """Return the compact form "([exec_task_idx] func_name type)".

    The string is built from exec_task_idx, func_name and type, which are
    the same three fields that __eq__ compares. __hash__ hashes this exact
    string, so the format must keep covering all of them.
    """
    return f"([{self.exec_task_idx}] {self.func_name} {self.type})"

def __hash__(self):
    """Hash the string form of the operation.

    str(self) is derived from exec_task_idx, func_name and type, the same
    fields __eq__ compares, so operations that compare equal hash equally.
    """
    return hash(str(self))

def __eq__(self, other):
    """Compare two operations field by field.

    Args:
        other: The object to compare against.

    Returns:
        True if `other` is a _DAGNodeOperation with the same exec_task_idx,
        type and func_name; False if any of them differ; NotImplemented if
        `other` is not a _DAGNodeOperation.
    """
    # Without this guard, comparing against an unrelated object (e.g. None)
    # raised AttributeError instead of letting Python fall back to its
    # default handling, under which `==` evaluates to False.
    if not isinstance(other, _DAGNodeOperation):
        return NotImplemented
    return (
        self.exec_task_idx == other.exec_task_idx
        and self.type == other.type
        and self.func_name == other.func_name
    )


@total_ordering
Expand Down Expand Up @@ -367,4 +392,51 @@ def _generate_actor_to_execution_schedule(
)
for _, candidates in actor_to_candidates.items():
assert len(candidates) == 0
return actor_to_execution_schedule
for actor_handle, execution_schedule in actor_to_execution_schedule.items():
print(f"Actor {actor_handle._ray_actor_id} schedule: {execution_schedule}")
actor_to_optimized_schedule: Dict[
"ray.actor.ActorHandle", List[_DAGNodeOperation]
] = _optimize_execution_schedule(graph, actor_to_execution_schedule)
return actor_to_optimized_schedule


def _optimize_execution_schedule(
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
    actor_to_execution_schedule: Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]],
    out_of_order_limit: int = 1,
):
    """
    Optimize the execution schedule by overlapping computation and communication.

    For each actor, READ operations are pulled ahead of the COMPUTE/WRITE
    operations. The relative order among the READs, and among the non-READ
    operations, is preserved.

    Args:
        graph: The DAG operation graph. It is not read by this function.
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the execution schedule which is a list of operations to be executed.
        out_of_order_limit: Controls how many READs may run ahead. The quota
            starts at `out_of_order_limit + 1`; each READ moved forward uses
            one unit, and one unit is returned whenever a WRITE is scheduled
            while READs are still waiting.

    Returns:
        A dictionary that maps each actor handle to its reordered schedule.
    """
    actor_to_optimized_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGNodeOperation]
    ] = defaultdict(list)
    for actor, execution_schedule in actor_to_execution_schedule.items():
        # Split the schedule into READs and everything else, keeping order.
        reads = []
        others = []
        for operation in execution_schedule:
            if operation.type == _DAGNodeOperationType.READ:
                reads.append(operation)
            else:
                others.append(operation)

        # Walk both lists with cursors instead of list.pop(0), which is O(n)
        # per call.
        optimized_schedule = []
        next_read = 0
        next_other = 0
        out_of_order_quota = out_of_order_limit + 1
        while next_other < len(others):
            reads_pending = next_read < len(reads)
            if reads_pending and out_of_order_quota > 0:
                optimized_schedule.append(reads[next_read])
                next_read += 1
                out_of_order_quota -= 1
                continue
            other_op = others[next_other]
            next_other += 1
            optimized_schedule.append(other_op)
            # A WRITE scheduled while READs are still waiting frees up one
            # slot for the next READ.
            if reads_pending and other_op.type == _DAGNodeOperationType.WRITE:
                out_of_order_quota += 1

        # READs still pending once the non-READ operations run out used to
        # be dropped silently; keep them so no operation is lost.
        optimized_schedule.extend(reads[next_read:])

        actor_to_optimized_schedule[actor] = optimized_schedule
        print(f"Actor {actor._ray_actor_id} optimized schedule:", optimized_schedule)
    return actor_to_optimized_schedule
1 change: 1 addition & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def send_with_tuple_args(self, args):
return torch.ones(shape, dtype=dtype, device=self.device) * value

def recv(self, tensor):
    """Check that the received tensor is on this actor's device and summarize it.

    Args:
        tensor: The received tensor.

    Returns:
        A tuple of (tensor[0].item(), tensor.shape, tensor.dtype).
    """
    print(f"{tensor=}")
    # Check that tensor got loaded to the correct device.
    assert tensor.device == self.device
    return (tensor[0].item(), tensor.shape, tensor.dtype)
Expand Down
Loading

0 comments on commit ccb561c

Please sign in to comment.