From 6218990565e2ebbc24419ede63b35389162f5eb6 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Wed, 2 Oct 2024 11:03:57 -0700 Subject: [PATCH] [XLA:GPU] Fix comments in collective select folder PiperOrigin-RevId: 681517882 --- xla/service/copy_insertion.cc | 2 - .../gpu/transforms/collective_select_folder.h | 63 ++++++++++++------- 2 files changed, 40 insertions(+), 25 deletions(-) diff --git a/xla/service/copy_insertion.cc b/xla/service/copy_insertion.cc index 6e2fc858d0958..39eafb7886717 100644 --- a/xla/service/copy_insertion.cc +++ b/xla/service/copy_insertion.cc @@ -2293,8 +2293,6 @@ absl::Status CopyInsertion::RemoveUnnecessaryCopies( } } - std::unique_ptr call_graph = CallGraph::Build(module); - int64_t num_existing_copies = GetNumExistingCopies(module, execution_threads); bool changed = true; int64_t num_iterations = -1; diff --git a/xla/service/gpu/transforms/collective_select_folder.h b/xla/service/gpu/transforms/collective_select_folder.h index 3e14ecbf054e1..c53eb2ca508b3 100644 --- a/xla/service/gpu/transforms/collective_select_folder.h +++ b/xla/service/gpu/transforms/collective_select_folder.h @@ -24,37 +24,54 @@ limitations under the License. namespace xla { -// When collective-permute operates on a comparison to a device id -// and the senders match the condition's branch -// we can link collective-permute to the original data skipping the comparison. -// For example -// condition = broadcast(compare(replica_id, X), direction=EQ -// data_snd = select(condition, compare_true_data, compare_false_data) -// rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}} -// can be transformed to -// rcv = collective-permute(compare_true_data), pairs={{X,0}} +// If a collective-permute selects its source data based on a partition or +// replica ID and we can prove that the condition is either always true or +// always false, we can fold the redundant select op and use the correct source +// data directly. // -// The pass is *only* handling compare direction={EQ,NE}. -// The pass handles Compare with and without preceding Broadcast. +// Example: +// +// condition = compare(replica-id(), X), direction=EQ +// snd_data = select(condition, true_data, false_data) +// rcv_data = collective-permute(snd_data), source_target_pairs={{X,0}} +// +// The condition is always true for the only relevant replica X and the IR can +// be folded into +// +// rcv_data = collective-permute(true_data), source_target_pairs={{X,0}} +// +// The pass only supports simple partion/replica-based predicates, comparing +// partition/replica-id with a constant. Only comparison directions {EQ,NE} are +// supported. The predicate may be broadcasted. +// +// This pass is motivated by pipeline parallelism, where it removes undesired +// data dependencies. +// +// Example: // -// This pass is particularly useful in the pipeline parallelism generated module -// such as: // fwd_data = ... // bwd_data = // is_first_device = ... // is_last_device = ... -// data_snd = select(is_last_device, bwd_data, fwd_data) -// bwd_data_rcv = collective-permute(data_snd), pairs={{3,0}} -// fwd_data_rcv = collective-permute(data_snd), pairs={{0,1},{1,2},{2,3}} -// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) +// snd_data = select(is_last_device, bwd_data, fwd_data) +// rcv_bwd_data = collective-permute(snd_data), +// source_target_pairs={{LAST_ID,0}} +// rcv_fwd_data = collective-permute(snd_data), +// source_target_pairs={{0,1},{1,2},...,{LAST_ID,0}} +// ROOT rcv_data = select(is_first_device, rcv_bwd_data, rcv_fwd_data) // -// After the transformation, the module will become: -// fwd_data_snd = ... -// bwd_data_snd = ... +// The select can be removed on both paths resulting in +// +// fwd_data = ... +// bwd_data = // is_first_device = ... -// bwd_data_rcv = collective-permute(bwd_data_snd), pairs={{3,0}} -// fwd_data_rcv = collective-permute(fwd_data_snd), pairs={{0,1},{1,2},{2,3}} -// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) +// is_last_device = ... +// rcv_bwd_data = collective-permute(bwd_data), +// source_target_pairs={{LAST_ID,0}} +// rcv_fwd_data = collective-permute(fwd_data), +// source_target_pairs={{0,1},{1,2},...,{LAST_ID,0}} +// ROOT rcv_data = select(is_first_device, rcv_bwd_data, rcv_fwd_data) +// class CollectiveSelectFolder : public HloModulePass { public: absl::string_view name() const override { return "collective-select-folder"; }