Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aDAG] Overlap computation and communication #47586

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
393 changes: 360 additions & 33 deletions python/ray/dag/compiled_dag_node.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions python/ray/dag/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@

# Feature flag to turn on profiling.
RAY_ADAG_ENABLE_PROFILING = os.environ.get("RAY_ADAG_ENABLE_PROFILING", "0") == "1"

# Feature flag to turn on torch profiling.
RAY_ADAG_ENABLE_TORCH_PROFILING = (
os.environ.get("RAY_ADAG_ENABLE_TORCH_PROFILING", "0") == "1"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will allow dynamic profiling (e.g., profiling N iterations, with the ability to enable/disable it at runtime). I think this one is okay for now. Can you create a corresponding issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #47745

)
11 changes: 11 additions & 0 deletions python/ray/dag/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
os.environ.get("RAY_DAG_max_inflight_executions", 10)
)

# Feature flag to overlap GPU communication with computation.
# NOTE: os.environ values are strings, so compare against "1" like the other
# flags in this package — `bool(os.environ.get(..., 0))` would be True even
# for RAY_DAG_overlap_gpu_communication=0, since bool("0") is True.
DEFAULT_OVERLAP_GPU_COMMUNICATION = (
    os.environ.get("RAY_DAG_overlap_gpu_communication", "0") == "1"
)


@DeveloperAPI
@dataclass
Expand Down Expand Up @@ -58,6 +62,12 @@ class DAGContext:
executions is beyond the DAG capacity, the new execution would
be blocked in the first place; therefore, this limit is only
enforced when it is smaller than the DAG capacity.
max_inflight_executions: The maximum number of in-flight executions
that can be submitted before consuming the output.
overlap_gpu_communication: Whether to overlap GPU communication with
computation during DAG execution. If True, the communication
and computation can be overlapped, which can improve the
performance of the DAG execution.
"""

execution_timeout: int = DEFAULT_EXECUTION_TIMEOUT_S
Expand All @@ -66,6 +76,7 @@ class DAGContext:
asyncio_max_queue_size: int = DEFAULT_ASYNCIO_MAX_QUEUE_SIZE
max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS
max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS
overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION

@staticmethod
def get_current() -> "DAGContext":
Expand Down
7 changes: 7 additions & 0 deletions python/ray/dag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def experimental_compile(
_asyncio_max_queue_size: Optional[int] = None,
_max_buffered_results: Optional[int] = None,
_max_inflight_executions: Optional[int] = None,
_overlap_gpu_communication: Optional[bool] = True,
) -> "ray.dag.CompiledDAG":
"""Compile an accelerated execution path for this DAG.

Expand All @@ -191,6 +192,11 @@ def experimental_compile(
are allowed to be sent to this DAG. Before submitting more requests,
the caller is responsible for calling ray.get to clear finished
in-flight requests.
overlap_gpu_communication: Whether to overlap GPU communication with
computation during DAG execution. If True, the communication
and computation can be overlapped, which can improve the
performance of the DAG execution. If None, the default value
will be used.

Returns:
A compiled DAG.
Expand Down Expand Up @@ -224,6 +230,7 @@ def experimental_compile(
_asyncio_max_queue_size,
_max_buffered_results,
_max_inflight_executions,
_overlap_gpu_communication,
)

def execute(
Expand Down
235 changes: 225 additions & 10 deletions python/ray/dag/dag_node_operation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import total_ordering
from enum import Enum
from typing import Set, Tuple, List, Dict
from typing import Optional, Tuple, List, Dict
import graphviz
import ray
import heapq
from collections import defaultdict
Expand All @@ -18,12 +19,22 @@ class _DAGNodeOperationType(Enum):
COMPUTE = "COMPUTE"
WRITE = "WRITE"

def __str__(self):
    """Return the one-letter abbreviation (R/C/W) used in schedule printouts."""
    abbreviations = {
        _DAGNodeOperationType.READ: "R",
        _DAGNodeOperationType.COMPUTE: "C",
        _DAGNodeOperationType.WRITE: "W",
    }
    assert self in abbreviations, f"Unknown operation type: {self}"
    return abbreviations[self]


class _DAGNodeOperation:
    """One READ/COMPUTE/WRITE step of a single task in the execution schedule."""

    def __init__(
        self,
        exec_task_idx: int,
        operation_type: _DAGNodeOperationType,
        method_name: Optional[str] = None,
    ):
        """
        Args:
            exec_task_idx: The index of the task in the actor's list of
                executable tasks. This is not the same
                as bind_index because there may be more tasks bound to an actor
                than tasks that appear in the current compiled DAG.
            operation_type: The type of operation to perform.
            method_name: The name of the method that this operation originates
                from. This is only for debugging purposes.
        """
        self.exec_task_idx = exec_task_idx
        self.type = operation_type
        self.method_name = method_name

    def next_operation(self) -> "_DAGNodeOperation":
        """Return the operation that follows this one for the *same* task.

        A task progresses READ -> COMPUTE -> WRITE (same `exec_task_idx`);
        this is not the next operation in the actor's overall schedule.

        Raises:
            ValueError: If called on a WRITE operation, which has no successor.
        """
        if self.type == _DAGNodeOperationType.READ:
            return _DAGNodeOperation(
                self.exec_task_idx, _DAGNodeOperationType.COMPUTE, self.method_name
            )
        elif self.type == _DAGNodeOperationType.COMPUTE:
            return _DAGNodeOperation(
                self.exec_task_idx, _DAGNodeOperationType.WRITE, self.method_name
            )
        else:
            raise ValueError(
                "Can only get the next operation for READ or COMPUTE type, "
                f"{self.type} is provided."
            )

    def __repr__(self):
        # Delegate to __str__ so the two representations never drift apart.
        return self.__str__()

    def __str__(self):
        return f"([{self.exec_task_idx}] {self.method_name} {self.type})"

    def __hash__(self):
        return hash((self.exec_task_idx, self.type))

    def __eq__(self, other):
        # An operation is uniquely identified by its `exec_task_idx` and type.
        # `method_name` is only for debugging purposes.
        return self.exec_task_idx == other.exec_task_idx and self.type == other.type

@total_ordering
Expand Down Expand Up @@ -70,8 +111,8 @@ def __init__(
# an integer `task_idx`, which can be used to index into `idx_to_task`
# to get the corresponding task, and a `_DAGNodeOperationType`, which can
# be READ, COMPUTE, or WRITE.
self.in_edges: Set[Tuple[int, _DAGNodeOperationType]] = set()
self.out_edges: Set[Tuple[int, _DAGNodeOperationType]] = set()
self.in_edges: Dict[Tuple[int, _DAGNodeOperationType]] = {}
self.out_edges: Dict[Tuple[int, _DAGNodeOperationType]] = {}

@property
def in_degree(self) -> int:
Expand Down Expand Up @@ -122,14 +163,32 @@ def __hash__(self):
"""
return hash((self.operation, self.task_idx))

def __str__(self):
    """Readable label: <ActorClass>_<actor id prefix> [task idx] <method> <op type>."""
    class_name = (
        self.actor_handle._ray_actor_creation_function_descriptor.class_name
    )
    # Only the first 4 hex chars of the actor id, to keep labels short.
    short_actor_id = self.actor_handle._actor_id.hex()[:4]
    op = self.operation
    return (
        f"{class_name}_{short_actor_id}"
        f" [{op.exec_task_idx}] {op.method_name} {op.type}"
    )

def _get_actor_id(self):
    """Return the hex string id of the actor this operation runs on."""
    actor_id = self.actor_handle._ray_actor_id
    return actor_id.hex()

def _add_edge(from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode):

def _add_edge(
    from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode, label=""
):
    """
    Add a directed edge from `from_node` to `to_node`.

    Edges are stored in each node's `out_edges`/`in_edges` dict, keyed by the
    peer operation's `(task_idx, operation type)` tuple. The value is `label`
    (e.g. "nccl", "shm", "next"), used for graph visualization/debugging.
    """
    from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = label
    to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = label


def _select_next_nodes(
Expand Down Expand Up @@ -204,6 +263,44 @@ def _select_next_nodes(
return next_nodes


def _visualize_graph(
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]
):
    """Render the operation graph to `dag_graph.png` for debugging.

    Nodes of the same actor are grouped into one graphviz cluster. Edges
    labeled "nccl" are drawn blue; all other edges are black.
    """
    dot = graphviz.Digraph(comment="DAG")

    actor_to_nodes = defaultdict(list)

    # Add nodes and edges to the graph.
    for task_idx, type_to_node in graph.items():
        for node in type_to_node.values():
            node_label = str(node)
            dot.node(node_label, node_label)

            actor_to_nodes[node._get_actor_id()].append(node)

            # Only out_edges are traversed; adding in_edges as well would
            # draw every edge twice.
            for out_edge, label in node.out_edges.items():
                out_task_idx, out_op_type = out_edge
                out_node = graph[out_task_idx][out_op_type]
                color = "blue" if label == "nccl" else "black"
                dot.edge(node_label, str(out_node), label=label, color=color)

    # Group each actor's nodes into a cluster.
    for actor_id, nodes in actor_to_nodes.items():
        with dot.subgraph(name=f"cluster_{actor_id}") as subgraph:
            subgraph.attr(rank=nodes[0]._get_actor_id())
            for node in nodes:
                subgraph.node(str(node), str(node))

    # Render the graph to a file and open it.
    dot.render("dag_graph", format="png", view=True)


def _build_dag_node_operation_graph(
idx_to_task: Dict[int, "ray.dag.compiled_dag_node.CompiledTask"],
actor_to_operation_nodes: Dict[
Expand Down Expand Up @@ -258,7 +355,7 @@ def _build_dag_node_operation_graph(
# Add an edge from COMPUTE with `bind_index` i to COMPUTE with
# `bind_index` i+1 if they belong to the same actor.
if prev_compute_node is not None:
_add_edge(prev_compute_node, compute_node)
_add_edge(prev_compute_node, compute_node, "next")
prev_compute_node = compute_node
assert task_idx not in graph
graph[task_idx] = {
Expand Down Expand Up @@ -297,10 +394,59 @@ def _build_dag_node_operation_graph(
_add_edge(
graph[task_idx][_DAGNodeOperationType.WRITE],
graph[downstream_task_idx][_DAGNodeOperationType.READ],
"nccl"
if graph[task_idx][_DAGNodeOperationType.WRITE].requires_nccl
else "shm",
)
# _visualize_graph(graph)
return graph


def _node_repr(node: _DAGOperationGraphNode, idx: int, optimized_index):
    """Label for a scheduled node: its str plus execution and optimized indices."""
    suffix = f" {idx},{optimized_index}"
    return str(node) + suffix


def _visualize_graph_ordered(
    actor_to_execution_nodes: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ],
    actor_to_optimized_nodes: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ],
    graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
):
    """Render the execution order vs. the optimized order to `dag_schedule.png`.

    Each actor's nodes are grouped into one cluster. A node is drawn red when
    its position in the optimized schedule differs from its original position.
    Edges labeled "nccl" are drawn blue; all other edges are black.
    """
    dot = graphviz.Digraph(comment="DAG")
    node_to_node_repr = {}

    # First pass: emit one cluster per actor and remember each node's label.
    for actor, execution_nodes in actor_to_execution_nodes.items():
        optimized_nodes = actor_to_optimized_nodes[actor]
        node_to_optimized_index = {
            node: i for i, node in enumerate(optimized_nodes)
        }

        actor_id = execution_nodes[0]._get_actor_id()
        with dot.subgraph(name=f"cluster_{actor_id}") as subgraph:
            subgraph.attr(rank=actor_id)
            for i, node in enumerate(execution_nodes):
                optimized_index = node_to_optimized_index.get(node)
                node_repr = _node_repr(node, i, optimized_index)
                # Highlight nodes whose position changed after optimization.
                color = "black" if optimized_index == i else "red"
                subgraph.node(node_repr, node_repr, color=color)
                node_to_node_repr[node] = node_repr

    # Second pass: draw edges between the already-labeled nodes.
    for actor, execution_nodes in actor_to_execution_nodes.items():
        for node in execution_nodes:
            node_repr = node_to_node_repr[node]
            for (out_task_idx, out_op_type), label in node.out_edges.items():
                out_node = graph[out_task_idx][out_op_type]
                color = "blue" if label == "nccl" else "black"
                dot.edge(
                    node_repr, node_to_node_repr[out_node], label=label, color=color
                )

    # Render the graph to a file and open it.
    dot.render("dag_schedule", format="png", view=True)


def _generate_actor_to_execution_schedule(
graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]
) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]:
Expand All @@ -326,6 +472,9 @@ def _generate_actor_to_execution_schedule(
actor_to_execution_schedule: Dict[
"ray.actor.ActorHandle", List[_DAGNodeOperation]
] = defaultdict(list)
actor_to_execution_nodes: Dict[
"ray.actor.ActorHandle", List[_DAGOperationGraphNode]
] = defaultdict(list)

# A dictionary mapping an actor id to a list of candidate nodes. The list
# is maintained as a priority queue, so the head of the queue, i.e.,
Expand Down Expand Up @@ -356,15 +505,81 @@ def _generate_actor_to_execution_schedule(
if node in visited_nodes:
continue
actor_to_execution_schedule[node.actor_handle].append(node.operation)
actor_to_execution_nodes[node.actor_handle].append(node)
visited_nodes.add(node)
for out_node_task_idx, out_node_type in node.out_edges:
out_node = graph[out_node_task_idx][out_node_type]
out_node.in_edges.remove((node.task_idx, node.operation.type))
out_node.in_edges.pop((node.task_idx, node.operation.type))
if out_node.in_degree == 0:
heapq.heappush(
actor_to_candidates[out_node.actor_handle._actor_id],
out_node,
)
for _, candidates in actor_to_candidates.items():
assert len(candidates) == 0
return actor_to_execution_schedule
for actor_handle, execution_schedule in actor_to_execution_schedule.items():
print(f"Actor {actor_handle._ray_actor_id} schedule: {execution_schedule}")
# _visualize_graph_ordered(actor_to_nodes, graph)
return actor_to_execution_schedule, actor_to_execution_nodes


def _optimize_execution_schedule(
    actor_to_execution_schedule: Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]],
    actor_to_execution_nodes: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ],
    out_of_order_limit: int = 1,
):
    """
    Optimize the execution schedule by overlapping computation and communication.

    For each NCCL READ, scan backwards for the closest preceding COMPUTE and
    move the READ to just before it, so the receive can overlap with that
    computation. The scan stops at any earlier NCCL READ or NCCL WRITE so the
    relative order of NCCL operations is preserved.

    Args:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the execution schedule which is a list of operations to be executed.
        actor_to_execution_nodes: A dictionary that maps an actor handle to the
            graph nodes of its execution schedule, in the same order as
            `actor_to_execution_schedule`.
        out_of_order_limit: The maximum number of out-of-order `receive`
            operations allowed.
            NOTE(review): currently only 0 (disabled) is honored specially;
            any value >= 1 behaves the same as 1 — confirm intended semantics.

    Returns:
        A tuple `(actor_to_optimized_schedule, actor_to_optimized_nodes)` with
        the same keys as the inputs. The inputs are not mutated.
    """
    if out_of_order_limit == 0:
        return actor_to_execution_schedule, actor_to_execution_nodes

    # Work on deep copies so the original schedules stay intact.
    actor_to_optimized_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGNodeOperation]
    ] = copy.deepcopy(actor_to_execution_schedule)
    actor_to_optimized_nodes: Dict[
        "ray.actor.ActorHandle", List[_DAGOperationGraphNode]
    ] = copy.deepcopy(actor_to_execution_nodes)
    for actor, optimized_nodes in actor_to_optimized_nodes.items():
        optimized_schedule = actor_to_optimized_schedule[actor]
        for i in range(len(optimized_nodes)):
            if (
                optimized_nodes[i].operation.type == _DAGNodeOperationType.READ
                and optimized_nodes[i].requires_nccl
            ):
                for j in range(i - 1, -1, -1):
                    if (
                        optimized_nodes[j].operation.type
                        == _DAGNodeOperationType.COMPUTE
                    ):
                        # Rotate the NCCL READ to position j and shift the
                        # operations in [j, i) one slot to the right.
                        nccl_read_node = optimized_nodes[i]
                        nccl_read_op = optimized_schedule[i]
                        sublist = optimized_nodes[j:i]
                        sublist_op = optimized_schedule[j:i]
                        optimized_nodes[j + 1 : i + 1] = sublist
                        optimized_schedule[j + 1 : i + 1] = sublist_op
                        optimized_nodes[j] = nccl_read_node
                        optimized_schedule[j] = nccl_read_op
                        break
                    if (
                        optimized_nodes[j].operation.type == _DAGNodeOperationType.READ
                    ) and (optimized_nodes[j].requires_nccl):
                        # Keep relative order of nccl reads.
                        break
                    if (
                        optimized_nodes[j].operation.type == _DAGNodeOperationType.WRITE
                    ) and (optimized_nodes[j].requires_nccl):
                        # Keep relative order of nccl reads and writes.
                        break
        print(f"Actor {actor._ray_actor_id} optimized schedule:", optimized_schedule)
    return actor_to_optimized_schedule, actor_to_optimized_nodes
Loading