PR #17453: Reorder Collective Optimization Passes
Imported from GitHub PR #17453

Moves the collective quantizer pass ahead of the collective pipeliner. This preserves FP8 quantization and dequantization patterns that are preceded or followed by collectives, without requiring the collective pipeliner to run after layout assignment. See #12866 and #15292.
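For intuition, a minimal HLO sketch of the quantization direction of this rewrite (illustrative only, not part of this change; it assumes four partitions gathering a bf16[32,128] operand x along dimension 0):

// Before: the all-gather transfers bf16 (2 bytes per element).
ag = bf16[128,128] all-gather(x), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
q = f8e4m3fn[128,128] convert(ag)

// After: the quantizing convert moves ahead of the collective, which now transfers f8 (1 byte per element).
q = f8e4m3fn[32,128] convert(x)
ag = f8e4m3fn[128,128] all-gather(q), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}

Halving the element width halves the bytes moved by the collective, and the result is identical because all-gather only moves data.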
Copybara import of the project:

--
d8e8f63 by Philipp Hack <[email protected]>:

Moves the collective quantizer pass before the collective pipeliner.

--
b34f203 by Philipp Hack <[email protected]>:

Moves the collective quantizer pass before the collective pipeliner.

Merging this change closes #17453

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17453 from philipphack:u_collective_passes_xla b34f203
PiperOrigin-RevId: 681670419
philipphack authored and Google-ML-Automation committed Oct 3, 2024
1 parent da687fc commit 1ad143f
Showing 2 changed files with 47 additions and 64 deletions.
12 changes: 6 additions & 6 deletions xla/service/gpu/gpu_compiler.cc
@@ -921,6 +921,12 @@ absl::Status RunCollectiveOptimizationPasses(
/*enable_reduce_scatter=*/debug_options
.xla_gpu_enable_while_loop_reduce_scatter_code_motion());

+// Moves collectives' subsequent quantization before the collective to
+// minimize data transfers.
+collectives_pipeline.AddPass<CollectiveQuantizer>();
+// Remove dead computations after collective quantization.
+collectives_pipeline.AddPass<HloDCE>();
+
if (!debug_options.xla_gpu_run_post_layout_collective_pipeliner()) {
TF_RETURN_IF_ERROR(
AddCollectivePipelinerPasses(debug_options, collectives_pipeline));
@@ -959,12 +965,6 @@ absl::Status RunCollectiveOptimizationPasses(
// Remove dead computations left over after ar/rs promotion.
collectives_pipeline.AddPass<HloDCE>();

-// Moves collectives' subsequent quantization before the collective to
-// minimize data transfers.
-collectives_pipeline.AddPass<CollectiveQuantizer>();
-// Remove dead computations after collective quantization.
-collectives_pipeline.AddPass<HloDCE>();
-
// Run WhileLoopTripCountAnnotator after collective pipelining and before
// layout assignment and fusion. This pass does some pattern-matching on
// while bodies/conditions, and this is where the HLO is "nicest".
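The reordering matters for the dequantization direction too, which is what the updated test below exercises: a dequantization that feeds a collective is deferred until after it. A hedged sketch under the same assumptions as above, with q an f8e4m3fn[32,128] operand (the scaled variant with a broadcast multiply is handled analogously):

// Before: the dequantized bf16 tensor is gathered (2 bytes per element).
dq = bf16[32,128] convert(q)
ag = bf16[128,128] all-gather(dq), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}

// After: the f8 tensor is gathered (1 byte per element), then dequantized.
ag = f8e4m3fn[128,128] all-gather(q), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
dq = bf16[128,128] convert(ag)

Running CollectiveQuantizer before the pipeliner lets these rewrites fire while the convert and the collective are still adjacent; the pipeliner can then rotate the loop around the collective without breaking the FP8 pattern the GEMM rewriter needs.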
99 changes: 41 additions & 58 deletions xla/tests/collective_ops_e2e_test.cc
@@ -1093,73 +1093,57 @@ ENTRY main.9_spmd {
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
}

-TEST_F(CollectiveOpsTestE2E, PostLayoutCollectivePipeliner) {
-// We need fp8 support to test the post-layout collective pipeliner. This will
-// preserve the desired fp8 patterns and so the gemm rewriter can correctly
-// recognize them and rewrite to custom fp8 gemm calls.
+TEST_F(CollectiveOpsTestE2E, CollectivePipelinerF8) {
+// Verify that FP8 patterns are preserved when collectives are pipelined so
+// the GEMM rewriter can create FP8 matmuls.
if (!HasFp8Support()) {
GTEST_SKIP() << "Test requires a post-Ada GPU.";
GTEST_SKIP() << "Test requires Hopper or newer architecture.";
}

absl::string_view kModuleReplicatedStr = R"(
-HloModule module, entry_computation_layout={(bf16[384,128], bf16[96,128], bf16[], bf16[])->bf16[384,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
-add {
-lhs = bf16[] parameter(0)
-rhs = bf16[] parameter(1)
-ROOT add = bf16[] add(lhs, rhs)
-}
+HloModule module, entry_computation_layout={(bf16[128,128], bf16[32,128], bf16[], bf16[])->bf16[512,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
while_cond {
-param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0)
-gte = s32[] get-tuple-element(param), index=0
-constant.1 = s32[] constant(3)
-ROOT cmp = pred[] compare(gte, constant.1), direction=LT
+input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) parameter(0)
+loop_counter = s32[] get-tuple-element(input), index=0
+c4 = s32[] constant(4)
+ROOT compare = pred[] compare(loop_counter, c4), direction=LT
}
while_body {
-param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0)
-get-tuple-element.394 = s32[] get-tuple-element(param), index=0
-get-tuple-element.395 = bf16[384,128] get-tuple-element(param), index=1
-get-tuple-element.k = bf16[96,128] get-tuple-element(param), index=2
-constant.2561 = s32[] constant(0)
-constant.2557 = s32[] constant(1)
-add.230 = s32[] add(get-tuple-element.394, constant.2557)
-constant.2559 = s32[] constant(3)
-subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
-constant.2560 = s32[] constant(-1)
-add.231 = s32[] add(subtract.139, constant.2560)
-compare.747 = pred[] compare(add.231, constant.2561), direction=LT
-constant.2562 = s32[] constant(2)
-add.232 = s32[] add(subtract.139, constant.2562)
-select.1348 = s32[] select(compare.747, add.232, add.231)
-dynamic-slice.k = bf16[32,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561), dynamic_slice_sizes={32,128}
-r = bf16[32,128] bitcast(dynamic-slice.k)
-a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
-// A fp8 pattern of quant-dequant before the collective AG.
-qa = <<F8E4M3>>[32,128] convert(a)
-dqa = bf16[32,128] convert(qa)
-a_scale = bf16[] get-tuple-element(param), index=3
-a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
-dqa_unscaled = bf16[32,128] multiply(dqa, a_scales)
-mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
-ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}
-qma = <<F8E4M3>>[128,128] convert(ma)
-dqma = bf16[128,128] convert(qma)
-ma_scale = bf16[] get-tuple-element(param), index=4
-ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
-dqma_unscaled = bf16[128,128] multiply(dqma, ma_scales)
-mc = bf16[128,128] dot(dqma_unscaled, mb), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-dynamic-update-slice.35 = bf16[384,128] dynamic-update-slice(get-tuple-element.395, mc, select.1348, constant.2561)
-ROOT tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, a_scale, ma_scale), control-predecessors={a}
+input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) parameter(0)
+loop_counter = s32[] get-tuple-element(input), index=0
+lhs = bf16[128,128] get-tuple-element(input), index=1
+rhs = bf16[32,128] get-tuple-element(input), index=2
+partial_dot_output = bf16[512,128] get-tuple-element(input), index=5
+lhs_f8 = f8e4m3fn[128,128] convert(lhs)
+rhs_f8 = f8e4m3fn[32,128] convert(rhs)
+lhs_bf16 = bf16[128,128] convert(lhs_f8)
+rhs_bf16 = bf16[32,128] convert(rhs_f8)
+scale_lhs = bf16[] get-tuple-element(input), index=3
+scale_rhs = bf16[] get-tuple-element(input), index=4
+scale_lhs_bcast = bf16[128,128] broadcast(scale_lhs), dimensions={}
+scale_rhs_bcast = bf16[32,128] broadcast(scale_rhs), dimensions={}
+lhs_scaled = bf16[128,128] multiply(lhs_bf16, scale_lhs_bcast)
+rhs_scaled = bf16[32,128] multiply(rhs_bf16, scale_rhs_bcast)
+rhs_scaled_all_gathered = bf16[128,128] all-gather(rhs_scaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
+dot = bf16[128,128] dot(lhs_scaled, rhs_scaled_all_gathered), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+c0 = s32[] constant(0)
+size = s32[] constant(128)
+iteration_offset = s32[] multiply(loop_counter, size)
+updated_dot_output = bf16[512,128] dynamic-update-slice(partial_dot_output, dot, iteration_offset, c0)
+c1 = s32[] constant(1)
+loop_counter_plus_one = s32[] add(loop_counter, c1)
+ROOT tuple = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) tuple(loop_counter_plus_one, lhs, rhs, scale_lhs, scale_rhs, updated_dot_output)
}
ENTRY entry {
c0 = s32[] constant(0)
-p0 = bf16[384,128] parameter(0)
-p1 = bf16[96,128] parameter(1)
-s0 = bf16[] parameter(2)
-s1 = bf16[] parameter(3)
-tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(c0, p0, p1, s0, s1)
-while = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) while(tuple), condition=while_cond, body=while_body
-ROOT gte1 = bf16[384,128] get-tuple-element(while), index=1
+lhs = bf16[128,128] parameter(0)
+rhs = bf16[32,128] parameter(1)
+scale_lhs = bf16[] parameter(2)
+scale_rhs = bf16[] parameter(3)
+result_buffer = bf16[512,128] constant(0.)
+while_input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) tuple(c0, lhs, rhs, scale_lhs, scale_rhs, result_buffer)
+while = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) while(while_input), condition=while_cond, body=while_body
+ROOT dot_output = bf16[512,128] get-tuple-element(while), index=5
}
)";

@@ -1169,7 +1153,6 @@
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
auto opts = GetDebugOptionsForTest();
-opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
opts.set_xla_gpu_enable_pipelined_collectives(true);
opts.set_xla_gpu_enable_triton_gemm(false);
CollectiveOpsVerifyF8Matmul(
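For reference, a hypothetical sketch of the while-body shape the FP8 verification in CollectiveOpsVerifyF8Matmul should find after optimization (instruction names invented here; exact operands, shapes, and transposes depend on the GEMM rewriter): the all-gather now moves FP8 data and the dot has been rewritten to a cuBLASLt FP8 matmul custom call.

// Hypothetical optimized while-body excerpt:
rhs_ag = f8e4m3fn[128,128] all-gather(rhs_f8), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
gemm = bf16[128,128] custom-call(lhs_f8, rhs_ag, scale_lhs, scale_rhs), custom_call_target="__cublas$lt$matmul$f8"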
