-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[aDAG] Overlap computation and communication #47586
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from functools import total_ordering | ||
from enum import Enum | ||
from typing import Set, Tuple, List, Dict | ||
from typing import Optional, Tuple, List, Dict | ||
import graphviz | ||
import ray | ||
import heapq | ||
from collections import defaultdict | ||
|
@@ -18,12 +19,22 @@ class _DAGNodeOperationType(Enum): | |
COMPUTE = "COMPUTE" | ||
WRITE = "WRITE" | ||
|
||
def __str__(self): | ||
if self == _DAGNodeOperationType.READ: | ||
return "R" | ||
elif self == _DAGNodeOperationType.COMPUTE: | ||
return "C" | ||
elif self == _DAGNodeOperationType.WRITE: | ||
return "W" | ||
assert False, f"Unknown operation type: {self}" | ||
|
||
|
||
class _DAGNodeOperation: | ||
def __init__( | ||
self, | ||
exec_task_idx: int, | ||
operation_type: _DAGNodeOperationType, | ||
method_name: Optional[str] = None, | ||
): | ||
""" | ||
Args: | ||
|
@@ -32,12 +43,42 @@ def __init__( | |
as bind_index because there may be more tasks bound to an actor | ||
than tasks that appear in the current compiled DAG. | ||
operation_type: The type of operation to perform. | ||
method_name: The name of the method that this operation originates | ||
from. This is only for debugging purposes. | ||
""" | ||
self.exec_task_idx = exec_task_idx | ||
self.type = operation_type | ||
self.method_name = method_name | ||
|
||
def next_operation(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can you add docstring to clarify the definition? I am asking it because "next operation" can have 2 meanings. 1. the literal next op in the scheduling; 2. the next operation for the same bind index |
||
if self.type == _DAGNodeOperationType.READ: | ||
return _DAGNodeOperation( | ||
self.exec_task_idx, _DAGNodeOperationType.COMPUTE, self.method_name | ||
) | ||
elif self.type == _DAGNodeOperationType.COMPUTE: | ||
return _DAGNodeOperation( | ||
self.exec_task_idx, _DAGNodeOperationType.WRITE, self.method_name | ||
) | ||
else: | ||
raise ValueError( | ||
"Can only get next operation for READ or COMPUTE type, " | ||
f"{self.type} is provided." | ||
) | ||
|
||
def __repr__(self): | ||
return f"(Task idx: {self.exec_task_idx}, Type: {self.type})" | ||
return f"([{self.exec_task_idx}] {self.method_name} {self.type})" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just return |
||
# return f"(Task idx: {self.exec_task_idx}, Type: {self.type})" | ||
|
||
def __str__(self): | ||
return f"([{self.exec_task_idx}] {self.method_name} {self.type})" | ||
|
||
def __hash__(self): | ||
return hash((self.exec_task_idx, self.type)) | ||
|
||
def __eq__(self, other): | ||
# An operation is uniquely identified by its `exec_task_idx` and type. | ||
# `method_name` is only for debugging purposes. | ||
return self.exec_task_idx == other.exec_task_idx and self.type == other.type | ||
|
||
|
||
@total_ordering | ||
|
@@ -70,8 +111,8 @@ def __init__( | |
# an integer `task_idx`, which can be used to index into `idx_to_task` | ||
# to get the corresponding task, and a `_DAGNodeOperationType`, which can | ||
# be READ, COMPUTE, or WRITE. | ||
self.in_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() | ||
self.out_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() | ||
self.in_edges: Dict[Tuple[int, _DAGNodeOperationType]] = {} | ||
self.out_edges: Dict[Tuple[int, _DAGNodeOperationType]] = {} | ||
|
||
@property | ||
def in_degree(self) -> int: | ||
|
@@ -122,14 +163,32 @@ def __hash__(self): | |
""" | ||
return hash((self.operation, self.task_idx)) | ||
|
||
def __str__(self): | ||
class_name = ( | ||
self.actor_handle._ray_actor_creation_function_descriptor.class_name | ||
) | ||
actor_id = self.actor_handle._actor_id.hex() | ||
return ( | ||
class_name | ||
+ "_" | ||
+ actor_id[:4] | ||
+ f" [{self.operation.exec_task_idx}] " | ||
+ f"{self.operation.method_name} {self.operation.type}" | ||
) | ||
|
||
def _get_actor_id(self): | ||
return self.actor_handle._ray_actor_id.hex() | ||
|
||
def _add_edge(from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode): | ||
|
||
def _add_edge( | ||
from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode, label="" | ||
): | ||
""" | ||
Add an edge from `from_node` to `to_node`. An edge is a tuple of | ||
the operation's `task_idx` and type. | ||
""" | ||
from_node.out_edges.add((to_node.task_idx, to_node.operation.type)) | ||
to_node.in_edges.add((from_node.task_idx, from_node.operation.type)) | ||
from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = label | ||
to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = label | ||
|
||
|
||
def _select_next_nodes( | ||
|
@@ -204,6 +263,44 @@ def _select_next_nodes( | |
return next_nodes | ||
|
||
|
||
def _visualize_graph( | ||
graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] | ||
): | ||
dot = graphviz.Digraph(comment="DAG") | ||
|
||
actor_to_nodes = defaultdict(list) | ||
|
||
# Add nodes and edges to the graph | ||
for task_idx, dict in graph.items(): | ||
for node in dict.values(): | ||
node_label = str(node) | ||
dot.node(node_label, node_label) | ||
|
||
actor_to_nodes[node._get_actor_id()].append(node) | ||
|
||
# # Add in_edges | ||
# for in_edge, label in node.in_edges.items(): | ||
# in_task_idx, in_op_type = in_edge | ||
# in_node = graph[in_task_idx][in_op_type] | ||
# dot.edge(str(in_node), str(node), label="") | ||
|
||
# Add out_edges | ||
for out_edge, label in node.out_edges.items(): | ||
out_task_idx, out_op_type = out_edge | ||
out_node = graph[out_task_idx][out_op_type] | ||
color = "blue" if label == "nccl" else "black" | ||
dot.edge(node_label, str(out_node), label=label, color=color) | ||
|
||
for actor_id, nodes in actor_to_nodes.items(): | ||
with dot.subgraph(name=f"cluster_{actor_id}") as subgraph: | ||
subgraph.attr(rank=nodes[0]._get_actor_id()) | ||
for node in nodes: | ||
subgraph.node(str(node), str(node)) | ||
|
||
# Render the graph to a file or display it | ||
dot.render("dag_graph", format="png", view=True) | ||
|
||
|
||
def _build_dag_node_operation_graph( | ||
idx_to_task: Dict[int, "ray.dag.compiled_dag_node.CompiledTask"], | ||
actor_to_operation_nodes: Dict[ | ||
|
@@ -258,7 +355,7 @@ def _build_dag_node_operation_graph( | |
# Add an edge from COMPUTE with `bind_index` i to COMPUTE with | ||
# `bind_index` i+1 if they belong to the same actor. | ||
if prev_compute_node is not None: | ||
_add_edge(prev_compute_node, compute_node) | ||
_add_edge(prev_compute_node, compute_node, "next") | ||
prev_compute_node = compute_node | ||
assert task_idx not in graph | ||
graph[task_idx] = { | ||
|
@@ -297,10 +394,59 @@ def _build_dag_node_operation_graph( | |
_add_edge( | ||
graph[task_idx][_DAGNodeOperationType.WRITE], | ||
graph[downstream_task_idx][_DAGNodeOperationType.READ], | ||
"nccl" | ||
if graph[task_idx][_DAGNodeOperationType.WRITE].requires_nccl | ||
else "shm", | ||
) | ||
# _visualize_graph(graph) | ||
return graph | ||
|
||
|
||
def _node_repr(node: _DAGOperationGraphNode, idx: int, optimized_index): | ||
return str(node) + f" {idx},{optimized_index}" | ||
|
||
|
||
def _visualize_graph_ordered( | ||
actor_to_execution_nodes: Dict[ | ||
"ray.actor.ActorHandle", List[_DAGOperationGraphNode] | ||
], | ||
actor_to_optimized_nodes: Dict[ | ||
"ray.actor.ActorHandle", List[_DAGOperationGraphNode] | ||
], | ||
graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], | ||
): | ||
dot = graphviz.Digraph(comment="DAG") | ||
node_to_node_repr = {} | ||
|
||
for actor, execution_nodes in actor_to_execution_nodes.items(): | ||
optimized_nodes = actor_to_optimized_nodes[actor] | ||
node_to_optimized_index = {node: i for i, node in enumerate(optimized_nodes)} | ||
|
||
with dot.subgraph( | ||
name=f"cluster_{execution_nodes[0]._get_actor_id()}" | ||
) as subgraph: | ||
subgraph.attr(rank=execution_nodes[0]._get_actor_id()) | ||
for i, node in enumerate(execution_nodes): | ||
optimized_index = node_to_optimized_index.get(node) | ||
node_repr = _node_repr(node, i, optimized_index) | ||
color = "red" if optimized_index != i else "black" | ||
subgraph.node(node_repr, node_repr, color=color) | ||
node_to_node_repr[node] = node_repr | ||
|
||
for actor, execution_nodes in actor_to_execution_nodes.items(): | ||
for i, node in enumerate(execution_nodes): | ||
node_repr = node_to_node_repr[node] | ||
for out_edge, label in node.out_edges.items(): | ||
out_task_idx, out_op_type = out_edge | ||
out_node = graph[out_task_idx][out_op_type] | ||
out_node_repr = node_to_node_repr[out_node] | ||
color = "blue" if label == "nccl" else "black" | ||
dot.edge(node_repr, out_node_repr, label=label, color=color) | ||
|
||
# Render the graph to a file or display it | ||
dot.render("dag_schedule", format="png", view=True) | ||
|
||
|
||
def _generate_actor_to_execution_schedule( | ||
graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] | ||
) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]: | ||
|
@@ -326,6 +472,9 @@ def _generate_actor_to_execution_schedule( | |
actor_to_execution_schedule: Dict[ | ||
"ray.actor.ActorHandle", List[_DAGNodeOperation] | ||
] = defaultdict(list) | ||
actor_to_execution_nodes: Dict[ | ||
"ray.actor.ActorHandle", List[_DAGOperationGraphNode] | ||
] = defaultdict(list) | ||
|
||
# A dictionary mapping an actor id to a list of candidate nodes. The list | ||
# is maintained as a priority queue, so the head of the queue, i.e., | ||
|
@@ -356,15 +505,81 @@ def _generate_actor_to_execution_schedule( | |
if node in visited_nodes: | ||
continue | ||
actor_to_execution_schedule[node.actor_handle].append(node.operation) | ||
actor_to_execution_nodes[node.actor_handle].append(node) | ||
visited_nodes.add(node) | ||
for out_node_task_idx, out_node_type in node.out_edges: | ||
out_node = graph[out_node_task_idx][out_node_type] | ||
out_node.in_edges.remove((node.task_idx, node.operation.type)) | ||
out_node.in_edges.pop((node.task_idx, node.operation.type)) | ||
if out_node.in_degree == 0: | ||
heapq.heappush( | ||
actor_to_candidates[out_node.actor_handle._actor_id], | ||
out_node, | ||
) | ||
for _, candidates in actor_to_candidates.items(): | ||
assert len(candidates) == 0 | ||
return actor_to_execution_schedule | ||
for actor_handle, execution_schedule in actor_to_execution_schedule.items(): | ||
print(f"Actor {actor_handle._ray_actor_id} schedule: {execution_schedule}") | ||
# _visualize_graph_ordered(actor_to_nodes, graph) | ||
return actor_to_execution_schedule, actor_to_execution_nodes | ||
|
||
|
||
def _optimize_execution_schedule( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add unit tests like @kevin85421 did before? (not e2e, but unit level testing) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, will do. |
||
actor_to_execution_schedule: Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]], | ||
actor_to_execution_nodes: Dict[ | ||
"ray.actor.ActorHandle", List[_DAGOperationGraphNode] | ||
], | ||
out_of_order_limit: int = 1, | ||
): | ||
""" | ||
Optimize the execution schedule by overlapping computation and communication. | ||
|
||
Args: | ||
actor_to_execution_schedule: A dictionary that maps an actor handle to | ||
the execution schedule which is a list of operations to be executed. | ||
out_of_order_limit: The maximum number of out-of-order `receive` operations | ||
allowed. | ||
""" | ||
if out_of_order_limit == 0: | ||
return actor_to_execution_schedule, actor_to_execution_nodes | ||
|
||
import copy | ||
|
||
actor_to_optimized_schedule: Dict[ | ||
"ray.actor.ActorHandle", List[_DAGNodeOperation] | ||
] = copy.deepcopy(actor_to_execution_schedule) | ||
actor_to_optimized_nodes: Dict[ | ||
"ray.actor.ActorHandle", List[_DAGOperationGraphNode] | ||
] = copy.deepcopy(actor_to_execution_nodes) | ||
for actor, optimized_nodes in actor_to_optimized_nodes.items(): | ||
optimized_schedule = actor_to_optimized_schedule[actor] | ||
for i in range(len(optimized_nodes)): | ||
if ( | ||
optimized_nodes[i].operation.type == _DAGNodeOperationType.READ | ||
and optimized_nodes[i].requires_nccl | ||
): | ||
for j in range(i - 1, -1, -1): | ||
if ( | ||
optimized_nodes[j].operation.type | ||
== _DAGNodeOperationType.COMPUTE | ||
): | ||
nccl_read_node = optimized_nodes[i] | ||
nccl_read_op = optimized_schedule[i] | ||
sublist = optimized_nodes[j:i] | ||
sublist_op = optimized_schedule[j:i] | ||
optimized_nodes[j + 1 : i + 1] = sublist | ||
optimized_schedule[j + 1 : i + 1] = sublist_op | ||
optimized_nodes[j] = nccl_read_node | ||
optimized_schedule[j] = nccl_read_op | ||
break | ||
if ( | ||
optimized_nodes[j].operation.type == _DAGNodeOperationType.READ | ||
) and (optimized_nodes[j].requires_nccl): | ||
# Keep relative order of nccl reads | ||
break | ||
if ( | ||
optimized_nodes[j].operation.type == _DAGNodeOperationType.WRITE | ||
) and (optimized_nodes[j].requires_nccl): | ||
# Keep relative order of nccl reads and writes | ||
break | ||
print(f"Actor {actor._ray_actor_id} optimized schedule:", optimized_schedule) | ||
return actor_to_optimized_schedule, actor_to_optimized_nodes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we will allow dynamic profiling (like profiling N iterations, and you can enable/disable at runtime). I think this one is okay for now. Can you create a corresponding issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Created #47745