Skip to content

Commit

Permalink
[XLA:GPU] Add sub-byte normalization after TransposeDimensionGrouper
Browse files Browse the repository at this point in the history
TransposeDimensionGrouper inserts bitcasts, and XLA requires subbyte types (int4 in this case) to have explicit bit witdth.

PiperOrigin-RevId: 680620568
  • Loading branch information
mooskagh authored and Google-ML-Automation committed Sep 30, 2024
1 parent d2f9134 commit 27dffd6
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,10 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
!debug_options.xla_gpu_enable_priority_fusion();
pipeline.AddPass<HloPassFix<ReductionSplitter>>(ignore_small_reduce_dims);
pipeline.AddPass<HloPassFix<TreeReductionRewriter>>(gpu_version);
// Normalization passes might have introduced s4 tensors without bit width
// annotations, this pass will add the annotations.
pipeline.AddPass<SubByteNormalization>(
SubByteNormalization::SET_ELEMENT_SIZE);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}

Expand Down

0 comments on commit 27dffd6

Please sign in to comment.