[XLA:GPU] Fix the derivation of the number of warps for tiled HLO computations.

The number of warps used to process a computation determines how many
registers are available for concurrent use. It therefore makes sense to derive
it from the largest (padded) tile size anywhere in the computation, since that
size determines the minimum number of elements that must be live at the same
time.

Previously, the logic erroneously looked only at the tile sizes of the output
(root) instruction.

This approach is not perfect and could be improved further, e.g. by performing
a live range analysis on the tiles of the computation.
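
As a rough illustration of the heuristic (a minimal sketch only: PaddedTileSize
and the warp-count thresholds below are illustrative stand-ins, not the actual
GetPaddedTileSize/GetNumWarps implementations), the padded tile size rounds
each tile dimension up to a power of two, and the warp count grows with the
largest such size seen anywhere in the computation:

    #include <cstdint>
    #include <vector>

    // Illustrative stand-in: pad each tile dimension up to the next power of
    // two and take the product, giving the number of elements a tile occupies.
    int64_t PaddedTileSize(const std::vector<int64_t>& tile_sizes) {
      int64_t padded = 1;
      for (int64_t dim : tile_sizes) {
        int64_t pow2 = 1;
        while (pow2 < dim) pow2 *= 2;
        padded *= pow2;
      }
      return padded;
    }

    // Illustrative stand-in: map the largest live (padded) tile size to a warp
    // count. The thresholds are assumptions, chosen so that a 1x4096 tile maps
    // to 4 warps, matching the expectation in the new test below.
    int64_t NumWarpsFor(int64_t largest_live_tile_size) {
      if (largest_live_tile_size <= 512) return 1;
      if (largest_live_tile_size <= 1024) return 2;
      if (largest_live_tile_size <= 16384) return 4;
      return 8;
    }

Taking the maximum of PaddedTileSize over every tiled instruction, rather than
over the root alone, is what the change below implements. For example, for the
1x4096 reduce input in the new test, 4 warps give 4 * 32 = 128 threads, i.e.
roughly 4096 / 128 = 32 elements per thread to keep live.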

PiperOrigin-RevId: 680668856
bchetioui authored and Google-ML-Automation committed Sep 30, 2024
1 parent 12d351d commit 9e9b500
Showing 2 changed files with 57 additions and 2 deletions.
11 changes: 9 additions & 2 deletions xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -524,9 +524,16 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
 LaunchDimensions
 GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion(
     const TiledHloComputation& tiled_hlo_computation) {
-  const auto* tiled_root = tiled_hlo_computation.GetRoot();
   int64_t num_blocks = tiled_hlo_computation.num_output_tiles();
-  int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes()));
+
+  // Decide on the number of warps to use based on the largest live tile size
+  // at any given point within the computation.
+  int64_t largest_live_tile_size = 1;
+  for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) {
+    largest_live_tile_size = std::max(
+        largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes()));
+  }
+  int64_t num_warps = GetNumWarps(largest_live_tile_size);
 
   return {static_cast<uint64_t>(num_blocks),
           static_cast<uint64_t>(num_warps * WarpSize())};
48 changes: 48 additions & 0 deletions xla/service/gpu/model/gpu_indexing_performance_model_test.cc
@@ -620,6 +620,54 @@ ENTRY main {
  // and corresponds to 4 warps.
  EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
}

TEST_F(GpuIndexingPerformanceModelTest,
       NumberOfWarpsDependsOnLargestLiveTileSize) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m

add {
  param_0 = f32[] parameter(0)
  param_1 = f32[] parameter(1)
  ROOT add = f32[] add(param_0, param_1)
}

fusion_computation {
  param_0 = f32[1,4096] parameter(0)
  c0 = f32[] constant(0)
  ROOT reduce = f32[1] reduce(param_0, c0), dimensions={1}, to_apply=add
}

ENTRY main {
  param_0 = f32[1,4096] parameter(0)
  ROOT fusion = f32[1] fusion(param_0), kind=kCustom,
    calls=fusion_computation,
    backend_config={"fusion_backend_config": {"kind":"__triton"}}
}
)"));
  auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
      module->entry_computation()->root_instruction());

  SymbolicTileAnalysisOrError analysis_or_error =
      SymbolicTileAnalysis::AnalyzeFusion(
          *fusion_adaptor, &mlir_context_,
          /*emitter_specific_constraints_builder=*/nullptr);
  ASSERT_TRUE(std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error));

  TF_ASSERT_OK_AND_ASSIGN(
      TiledHloComputation tiled_hlo_computation,
      std::get<SymbolicTileAnalysis>(analysis_or_error)
          .ComputeTiledHloInstructions(/*tile_parameters=*/{1}));

  LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis::
      GetLaunchDimensionsForTiledFusion(tiled_hlo_computation);
  EXPECT_EQ(launch_dimensions.num_blocks(), 1);

  // The largest tile size is 1 * 4096, for which our implementation recommends
  // using 4 warps.
  EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
}

class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {
 public:
  void CompareFlopsModels(absl::string_view hlo_module_string) {
