From 9e9b5005edb8a243c8964c931af8f251b862961b Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 30 Sep 2024 12:07:45 -0700 Subject: [PATCH] [XLA:GPU] Fix the derivation for the number of warps for tiled HLO computations. The number of warps used to process a computation determines how many registers we are able to use concurrently. Therefore, looking at the largest (padded) tile size makes sense, since it determines the minimum number of elements that must be live concurrently. Previously, the logic erroneously only looked at the output tile sizes. This approach is not perfect, and may be further improved by e.g. doing a live range analysis on the tiles of the computation. PiperOrigin-RevId: 680668856 --- .../model/gpu_indexing_performance_model.cc | 11 ++++- .../gpu_indexing_performance_model_test.cc | 48 +++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.cc b/xla/service/gpu/model/gpu_indexing_performance_model.cc index f8ac967dd1ac6..22ac7903c0bc4 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -524,9 +524,16 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( LaunchDimensions GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion( const TiledHloComputation& tiled_hlo_computation) { - const auto* tiled_root = tiled_hlo_computation.GetRoot(); int64_t num_blocks = tiled_hlo_computation.num_output_tiles(); - int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes())); + + // Decide on the number of warps to use based on the largest live tile size + // at any given point within the computation. + int64_t largest_live_tile_size = 1; + for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) { + largest_live_tile_size = std::max( + largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes())); + } + int64_t num_warps = GetNumWarps(largest_live_tile_size); return {static_cast(num_blocks), static_cast(num_warps * WarpSize())}; diff --git a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index f9f6b05702e79..0de3856b9864d 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -620,6 +620,54 @@ ENTRY main { // and corresponds to 4 warps. EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); } + +TEST_F(GpuIndexingPerformanceModelTest, + NumberOfWarpsDependsOnLargestLiveTileSize) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add = f32[] add(param_0, param_1) +} + +fusion_computation { + param_0 = f32[1,4096] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[1] reduce(param_0, c0), dimensions={1}, to_apply=add +} + +ENTRY main { + param_0 = f32[1,4096] parameter(0) + ROOT fusion = f32[1] fusion(param_0), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_, + /*emitter_specific_constraints_builder=*/nullptr); + ASSERT_TRUE(std::holds_alternative(analysis_or_error)); + + TF_ASSERT_OK_AND_ASSIGN( + TiledHloComputation tiled_hlo_computation, + std::get(analysis_or_error) + .ComputeTiledHloInstructions(/*tile_parameters=*/{1})); + + LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis:: + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + EXPECT_EQ(launch_dimensions.num_blocks(), 1); + + // The largest tile size is 1 * 4096, for which our implementation recommends + // using 4 warps. + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); +} + class FlopsPerElementTest : public GpuIndexingPerformanceModelTest { public: void CompareFlopsModels(absl::string_view hlo_module_string) {