
Commit e1b38f8

[XLA:GPU] Disable cublas for BF16_BF16_F32 on Hopper because it uses the TF32_TF32_F32 kernel instead.

We have to preserve the precision guarantees when a dot has the explicit algorithm property set to BF16_BF16_F32.

PiperOrigin-RevId: 681042944
loislo authored and Google-ML-Automation committed Oct 1, 2024
1 parent 517b3b6 commit e1b38f8
Showing 6 changed files with 107 additions and 56 deletions.
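A minimal sketch of the behavior this change enforces (not part of the commit; the headers, namespaces, and wrapper name are assumptions for illustration, while the function and enum names come from the diff below):

#include "xla/service/algorithm_util.h"
#include "xla/stream_executor/device_description.h"

namespace se = stream_executor;

// Illustration only: may the cuBLAS/cuBLASLt path serve a dot whose algorithm
// is explicitly set to BF16_BF16_F32 on the given device?
bool CublasAllowedForExplicitBf16Dot(const se::GpuComputeCapability& cc) {
  // After this change: true on Ampere (cuBLAS has BF16_BF16_F32 kernels there),
  // false on Hopper (cuBLAS would silently run the TF32_TF32_F32 kernel).
  return xla::algorithm_util::IsSupportedByCublasOrCublasLt(
      xla::PrecisionConfig::ALG_DOT_BF16_BF16_F32, cc);
}
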
15 changes: 13 additions & 2 deletions xla/service/algorithm_util.cc
@@ -97,15 +97,26 @@ bool HasFastAccum(PrecisionConfig::Algorithm algorithm) {
return algorithm == PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM;
}

bool IsAmpere(stream_executor::GpuComputeCapability gpu_compute_capability) {
return std::holds_alternative<se::CudaComputeCapability>(
gpu_compute_capability) &&
std::get<se::CudaComputeCapability>(gpu_compute_capability).major ==
stream_executor::CudaComputeCapability::AMPERE;
}

// It's clear that those libraries could support more, but we only list the ones
// which we explicitly test for now.
bool IsSupportedByCublasOrCublasLt(PrecisionConfig::Algorithm algorithm) {
bool IsSupportedByCublasOrCublasLt(
PrecisionConfig::Algorithm algorithm,
stream_executor::GpuComputeCapability gpu_compute_capability) {
switch (algorithm) {
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
// Hopper does not have kernels for the algorithm but Ampere does.
return IsAmpere(gpu_compute_capability);
case PrecisionConfig::ALG_UNSET:
case PrecisionConfig::ALG_DOT_F16_F16_F32:
case PrecisionConfig::ALG_DOT_F32_F32_F32:
case PrecisionConfig::ALG_DOT_F64_F64_F64:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
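A quick usage sketch of the new overload (not part of the diff; it reuses the `se` alias from the sketch above, and the capability constructions mirror the updated test further down):

#include "absl/log/check.h"  // for CHECK

// Hypothetical check, for illustration only.
void CheckBf16CublasSupport() {
  const se::GpuComputeCapability ampere =
      se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, /*minor=*/0};
  const se::GpuComputeCapability hopper =
      se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, /*minor=*/0};
  // Ampere keeps the cuBLAS reference config for an explicit BF16_BF16_F32 dot.
  CHECK(xla::algorithm_util::IsSupportedByCublasOrCublasLt(
      xla::PrecisionConfig::ALG_DOT_BF16_BF16_F32, ampere));
  // Hopper does not, because cuBLAS would use the TF32_TF32_F32 kernel instead.
  CHECK(!xla::algorithm_util::IsSupportedByCublasOrCublasLt(
      xla::PrecisionConfig::ALG_DOT_BF16_BF16_F32, hopper));
}
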
4 changes: 3 additions & 1 deletion xla/service/algorithm_util.h
@@ -52,7 +52,9 @@ bool HasFastAccum(PrecisionConfig::Algorithm algorithm);
//
// We may want to also check storage types, but for now those are checked in
// IsSupportedDotAlgorithmOnGpu.
bool IsSupportedByCublasOrCublasLt(PrecisionConfig::Algorithm algorithm);
bool IsSupportedByCublasOrCublasLt(
PrecisionConfig::Algorithm algorithm,
stream_executor::GpuComputeCapability gpu_compute_capability);

// Checks if we support the given algorithm using cuDNN.
bool IsSupportedByCudnn(PrecisionConfig::Algorithm algorithm);
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/BUILD
@@ -141,7 +141,6 @@ xla_test(
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
@@ -151,6 +150,7 @@ xla_test(
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:env",
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -780,7 +780,7 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) {
if (!debug_options_.xla_gpu_experimental_disable_binary_libraries()) {
// Add cuBLAS reference config, if available.
if (algorithm_util::IsSupportedByCublasOrCublasLt(
dot->precision_config().algorithm()) &&
dot->precision_config().algorithm(), GetComputeCapability()) &&
!dot->sparse_operands() && IsAutotuningEnabled()) {
configs.push_back(CuBlasConfig{});
}
137 changes: 87 additions & 50 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -25,7 +25,9 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "xla/autotuning.pb.h"
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
@@ -197,59 +199,101 @@ class StatelessAutotunerTest : public HloTestBase {
debug_options, nullptr);
return autotuner.GenerateConfigs(fusion);
}

se::CudaComputeCapability GetCudaComputeCapability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
}

// Returns the config for the current device.
absl::StatusOr<std::vector<GemmFusionAutotunerImpl::BackendConfig>>
GetPossibleMatmulAutotuneConfigs(const HloModule& module) {
DeviceConfig device_config{backend().default_stream_executor(),
backend().memory_allocator()};
AutotuneConfig autotune_config{device_config, GetDebugOptionsForTest()};
GemmFusionAutotunerImpl autotuner(autotune_config, GetToolkitVersion(),
GetDebugOptionsForTest(), nullptr);
const HloFusionInstruction& fusion = *Cast<HloFusionInstruction>(
module.entry_computation()->root_instruction());
return autotuner.GenerateConfigs(fusion);
}

bool hasCublasConfig(
const std::vector<GemmFusionAutotunerImpl::BackendConfig>& configs) {
return std::any_of(
configs.begin(), configs.end(),
[](const GemmFusionAutotunerImpl::BackendConfig& config) {
return std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(
config);
});
}
};

constexpr absl::string_view kHloDotFusionWithAlgorithm = R"(
HloModule module
computation {
p0 = f32[1024,1024] parameter(0)
p1 = f32[1024,1024] parameter(1)
ROOT r = f32[1024,1024] dot(p0, p1),
algorithm=$0,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
ENTRY main {
p0 = f32[1024,1024] parameter(0)
p1 = f32[1024,1024] parameter(1)
ROOT computation = f32[1024,1024] fusion(f32[1024,1024] p0,f32[1024,1024] p1),
kind=kCustom,
calls=computation
}
)";

TEST_F(StatelessAutotunerTest, NoCublasFallbackForTf32Tf32F32X3Algorithm) {
// There is no cublas implementation for dot_tf32_tf32_f32_x3 at the moment.
// At the same time cublas f32 is faster than triton for this algorithm.
// But we don't want to fall back to cuBLAS in this case because we would lose
// the precision guarantees.
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule module
computation {
p0 = f32[1024,1024] parameter(0)
p1 = f32[1024,1024] parameter(1)
ROOT r = f32[1024,1024] dot(p0, p1),
algorithm=dot_tf32_tf32_f32_x3,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
ENTRY main {
p0 = f32[1024,1024] parameter(0)
p1 = f32[1024,1024] parameter(1)
ROOT computation = f32[1024,1024] fusion(f32[1024,1024] p0,f32[1024,1024] p1),
kind=kCustom,
calls=computation
}
)"));

const se::CudaComputeCapability ampere{se::CudaComputeCapability::AMPERE,
/*minor=*/0};
TF_ASSERT_OK_AND_ASSIGN(
auto ampere_configs,
GetPossibleMatmulAutotuneConfigs(*module, ampere, GetToolkitVersion(),
GetDebugOptionsForTest()));
EXPECT_FALSE(std::any_of(
ampere_configs.begin(), ampere_configs.end(),
[](const GemmFusionAutotunerImpl::BackendConfig& config) {
return std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(
config);
}));
auto module, ParseAndReturnVerifiedModule(absl::Substitute(
kHloDotFusionWithAlgorithm, "dot_tf32_tf32_f32_x3")));

TF_ASSERT_OK_AND_ASSIGN(auto configs,
GetPossibleMatmulAutotuneConfigs(*module));
EXPECT_FALSE(hasCublasConfig(configs))
<< "There is no cublas implementation for dot_tf32_tf32_f32_x3. That is "
"why we don't want to fallback to cublas.";
}

const se::CudaComputeCapability hopper{se::CudaComputeCapability::HOPPER,
/*minor=*/0};
TEST_F(StatelessAutotunerTest,
NoCublasFallbackForBf16Bf16F32AlgorithmOnHopper) {
// cuBLAS has BF16_BF16_F32 kernels on Ampere but not on Hopper, where it
// would silently use the TF32_TF32_F32 kernel instead. We therefore don't
// want to fall back to cuBLAS on Hopper, because we would lose the precision
// guarantees of the explicit algorithm.
TF_ASSERT_OK_AND_ASSIGN(
auto hopper_configs,
GetPossibleMatmulAutotuneConfigs(*module, hopper, GetToolkitVersion(),
GetDebugOptionsForTest()));
EXPECT_FALSE(std::any_of(
hopper_configs.begin(), hopper_configs.end(),
[](const GemmFusionAutotunerImpl::BackendConfig& config) {
return std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(
config);
}));
auto module, ParseAndReturnVerifiedModule(absl::Substitute(
kHloDotFusionWithAlgorithm, "dot_bf16_bf16_f32")));

TF_ASSERT_OK_AND_ASSIGN(auto configs,
GetPossibleMatmulAutotuneConfigs(*module));
switch (GetCudaComputeCapability().major) {
case se::CudaComputeCapability::AMPERE:
EXPECT_TRUE(hasCublasConfig(configs))
<< "There is a cublas implementation for dot_bf16_bf16_f32 on Ampere";
break;
case se::CudaComputeCapability::HOPPER:
EXPECT_FALSE(hasCublasConfig(configs))
<< "There is no cublas implementation for dot_bf16_bf16_f32 on "
"Hopper. That is why we don't want to fallback to cublas.";
break;
default:
// We don't know what to expect for other compute capabilities.
EXPECT_FALSE(hasCublasConfig(configs));
}
}

class GemmFusionAutotunerTest : public StatelessAutotunerTest {
@@ -263,13 +307,6 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest {
return debug_options;
}

se::CudaComputeCapability GetCudaComputeCapability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
}

void CheckTritonAutotuning(absl::string_view hlo,
absl::string_view expected) {
HloPassPipeline pipeline("gemm_rewrite");
3 changes: 2 additions & 1 deletion xla/service/gpu/transforms/gemm_rewriter.cc
@@ -1991,7 +1991,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
backend_config.precision_config().operand_precision());
const PrecisionConfig::Algorithm algorithm =
backend_config.precision_config().algorithm();
if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm)) return false;
if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm, gpu_version_))
return false;

TF_ASSIGN_OR_RETURN(
const se::blas::ComputationType compute_type,
