Skip to content

Commit

Permalink
[XLA:GPU] Fix comments in collective select folder
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681517882
  • Loading branch information
frgossen authored and Google-ML-Automation committed Oct 2, 2024
1 parent 1a1bfe0 commit 71626ec
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
2 changes: 0 additions & 2 deletions xla/service/copy_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2293,8 +2293,6 @@ absl::Status CopyInsertion::RemoveUnnecessaryCopies(
}
}

std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);

int64_t num_existing_copies = GetNumExistingCopies(module, execution_threads);
bool changed = true;
int64_t num_iterations = -1;
Expand Down
63 changes: 40 additions & 23 deletions xla/service/gpu/transforms/collective_select_folder.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,54 @@ limitations under the License.

namespace xla {

// When collective-permute operates on a comparison to a device id
// and the senders match the condition's branch
// we can link collective-permute to the original data skipping the comparison.
// For example
// condition = broadcast(compare(replica_id, X), direction=EQ)
// data_snd = select(condition, compare_true_data, compare_false_data)
// rcv = collective-permute(data_snd), pairs={{X,0}}
// can be transformed to
// rcv = collective-permute(compare_true_data), pairs={{X,0}}
// If a collective-permute selects its source data based on a partition or
// replica ID and we can prove that the condition is either always true or
// always false, we can fold the redundant select op and use the correct source
// data directly.
//
// The pass is *only* handling compare direction={EQ,NE}.
// The pass handles Compare with and without preceding Broadcast.
// Example:
//
// condition = compare(replica-id(), X), direction=EQ
// snd_data = select(condition, true_data, false_data)
// rcv_data = collective-permute(snd_data), source_target_pairs={{X,0}}
//
// The condition is always true for the only relevant replica X and the IR can
// be folded into
//
// rcv_data = collective-permute(true_data), source_target_pairs={{X,0}}
//
// The pass only supports simple partition/replica-based predicates, comparing
// partition/replica-id with a constant. Only comparison directions {EQ,NE} are
// supported. The predicate may be broadcasted.
//
// This pass is motivated by pipeline parallelism, where it removes undesired
// data dependencies.
//
// Example:
//
// This pass is particularly useful in the pipeline parallelism generated module
// such as:
// fwd_data = ...
// bwd_data = ...
// is_first_device = ...
// is_last_device = ...
// data_snd = select(is_last_device, bwd_data, fwd_data)
// bwd_data_rcv = collective-permute(data_snd), pairs={{3,0}}
// fwd_data_rcv = collective-permute(data_snd), pairs={{0,1},{1,2},{2,3}}
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv)
// snd_data = select(is_last_device, bwd_data, fwd_data)
// rcv_bwd_data = collective-permute(snd_data),
// source_target_pairs={{LAST_ID,0}}
// rcv_fwd_data = collective-permute(snd_data),
// source_target_pairs={{0,1},{1,2},...,{LAST_ID-1,LAST_ID}}
// ROOT rcv_data = select(is_first_device, rcv_bwd_data, rcv_fwd_data)
//
// After the transformation, the module will become:
// fwd_data_snd = ...
// bwd_data_snd = ...
// The select can be removed on both paths resulting in
//
// fwd_data = ...
// bwd_data = ...
// is_first_device = ...
// bwd_data_rcv = collective-permute(bwd_data_snd), pairs={{3,0}}
// fwd_data_rcv = collective-permute(fwd_data_snd), pairs={{0,1},{1,2},{2,3}}
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv)
// is_last_device = ...
// rcv_bwd_data = collective-permute(bwd_data),
// source_target_pairs={{LAST_ID,0}}
// rcv_fwd_data = collective-permute(fwd_data),
// source_target_pairs={{0,1},{1,2},...,{LAST_ID-1,LAST_ID}}
// ROOT rcv_data = select(is_first_device, rcv_bwd_data, rcv_fwd_data)
//
class CollectiveSelectFolder : public HloModulePass {
public:
absl::string_view name() const override { return "collective-select-folder"; }
Expand Down

0 comments on commit 71626ec

Please sign in to comment.