diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.cc b/xla/service/gpu/model/gpu_indexing_performance_model.cc
index f8ac967dd1ac6..22ac7903c0bc4 100644
--- a/xla/service/gpu/model/gpu_indexing_performance_model.cc
+++ b/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -524,9 +524,16 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
 LaunchDimensions
 GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion(
     const TiledHloComputation& tiled_hlo_computation) {
-  const auto* tiled_root = tiled_hlo_computation.GetRoot();
   int64_t num_blocks = tiled_hlo_computation.num_output_tiles();
-  int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes()));
+
+  // Decide on the number of warps to use based on the largest live tile size
+  // at any given point within the computation.
+  int64_t largest_live_tile_size = 1;
+  for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) {
+    largest_live_tile_size = std::max(
+        largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes()));
+  }
+  int64_t num_warps = GetNumWarps(largest_live_tile_size);
 
   return {static_cast<uint64_t>(num_blocks),
           static_cast<uint64_t>(num_warps * WarpSize())};
diff --git a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
index f9f6b05702e79..0de3856b9864d 100644
--- a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
+++ b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
@@ -620,6 +620,54 @@ ENTRY main {
   // and corresponds to 4 warps.
   EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
 }
+
+TEST_F(GpuIndexingPerformanceModelTest,
+       NumberOfWarpsDependsOnLargestLiveTileSize) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+add {
+  param_0 = f32[] parameter(0)
+  param_1 = f32[] parameter(1)
+  ROOT add = f32[] add(param_0, param_1)
+}
+
+fusion_computation {
+  param_0 = f32[1,4096] parameter(0)
+  c0 = f32[] constant(0)
+  ROOT reduce = f32[1] reduce(param_0, c0), dimensions={1}, to_apply=add
+}
+
+ENTRY main {
+  param_0 = f32[1,4096] parameter(0)
+  ROOT fusion = f32[1] fusion(param_0), kind=kCustom,
+    calls=fusion_computation,
+    backend_config={"fusion_backend_config": {"kind":"__triton"}}
+}
+)"));
+  auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
+      module->entry_computation()->root_instruction());
+
+  SymbolicTileAnalysisOrError analysis_or_error =
+      SymbolicTileAnalysis::AnalyzeFusion(
+          *fusion_adaptor, &mlir_context_,
+          /*emitter_specific_constraints_builder=*/nullptr);
+  ASSERT_TRUE(std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      TiledHloComputation tiled_hlo_computation,
+      std::get<SymbolicTileAnalysis>(analysis_or_error)
+          .ComputeTiledHloInstructions(/*tile_parameters=*/{1}));
+
+  LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis::
+      GetLaunchDimensionsForTiledFusion(tiled_hlo_computation);
+  EXPECT_EQ(launch_dimensions.num_blocks(), 1);
+
+  // The largest tile size is 1 * 4096, for which our implementation recommends
+  // using 4 warps.
+  EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
+}
+
 class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {
  public:
   void CompareFlopsModels(absl::string_view hlo_module_string) {
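
The sketch below (not part of the patch) illustrates the idea the new logic relies on: instead of deriving the warp count from the root tile alone, take the maximum padded tile size over every tiled instruction and feed that into the warp-count heuristic. The GetNumWarps thresholds, kWarpSize value, and the simplified function signatures here are assumptions chosen only to be consistent with the test expectation that a 1 x 4096 tile maps to 4 warps; the real heuristic lives in gpu_indexing_performance_model.cc.

// Standalone C++ sketch of the "largest live tile" warp-count selection.
// The thresholds in GetNumWarps are assumed, not taken from the XLA source.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

namespace {

constexpr int64_t kWarpSize = 32;

// Assumed heuristic: larger padded tiles get more warps.
int64_t GetNumWarps(int64_t padded_tile_size) {
  if (padded_tile_size <= 512) return 1;
  if (padded_tile_size <= 1024) return 2;
  if (padded_tile_size <= 16384) return 4;
  return 8;
}

struct LaunchDimensions {
  int64_t num_blocks;
  int64_t num_threads_per_block;
};

// Mirrors the patched GetLaunchDimensionsForTiledFusion: scan all tiled
// instructions for the largest padded tile size and size the block from it.
LaunchDimensions GetLaunchDimensionsForTiledFusion(
    int64_t num_output_tiles,
    const std::vector<int64_t>& padded_tile_sizes_per_instruction) {
  int64_t largest_live_tile_size = 1;
  for (int64_t padded_tile_size : padded_tile_sizes_per_instruction) {
    largest_live_tile_size =
        std::max(largest_live_tile_size, padded_tile_size);
  }
  int64_t num_warps = GetNumWarps(largest_live_tile_size);
  return {num_output_tiles, num_warps * kWarpSize};
}

}  // namespace

int main() {
  // Rough analogue of the reduction test case: the reduce input is tiled as
  // 1 x 4096 while the output tile holds a single element, so the input tile
  // dominates and yields 4 warps (128 threads) for the single output tile.
  LaunchDimensions dims = GetLaunchDimensionsForTiledFusion(
      /*num_output_tiles=*/1,
      /*padded_tile_sizes_per_instruction=*/{4096, 1, 1});
  std::cout << "blocks=" << dims.num_blocks
            << " threads_per_block=" << dims.num_threads_per_block << "\n";
  return 0;
}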