From 38b61443d75e3389d05ae1461be2ef373b4c016b Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 2 Oct 2024 02:20:12 -0700 Subject: [PATCH] [XLA:GPU] Generalize the Reduce-Precision in MHLO to also work on Tensors and use it in a Triton emitter. PiperOrigin-RevId: 681364188 --- xla/mlir_hlo/BUILD | 17 ++ .../mhlo/transforms/map_mhlo_to_scalar_op.h | 115 +----------- .../mhlo/transforms/transformation_helpers.h | 174 ++++++++++++++++++ xla/service/gpu/fusions/triton/BUILD | 1 + .../fusions/triton/triton_fusion_emitter.cc | 4 + .../triton_fusion_emitter_device_test.cc | 24 +++ .../gpu/fusions/triton/triton_support.cc | 5 + 7 files changed, 230 insertions(+), 110 deletions(-) create mode 100644 xla/mlir_hlo/mhlo/transforms/transformation_helpers.h diff --git a/xla/mlir_hlo/BUILD b/xla/mlir_hlo/BUILD index 8de9f293aeb25..30eb4a8fccf14 100644 --- a/xla/mlir_hlo/BUILD +++ b/xla/mlir_hlo/BUILD @@ -495,6 +495,23 @@ cc_library( name = "map_mhlo_to_scalar_op", hdrs = ["mhlo/transforms/map_mhlo_to_scalar_op.h"], strip_include_prefix = ".", + deps = [ + ":mlir_hlo", + ":transformation_helpers", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "transformation_helpers", + hdrs = ["mhlo/transforms/transformation_helpers.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", diff --git a/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h b/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h index c3b7f7103841f..ccfdfcd9a5952 100644 --- a/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h +++ b/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/transformation_helpers.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -469,117 +470,11 @@ inline Value mapMhloOpToStdScalarOp( template <> inline Value mapMhloOpToStdScalarOp( - Location loc, ArrayRef /*resultTypes*/, ArrayRef argTypes, + Location loc, ArrayRef /*resultTypes*/, ArrayRef /*argTypes*/, mhlo::ReducePrecisionOp::Adaptor adaptor, OpBuilder* builder) { - using llvm::APInt; - mlir::ImplicitLocOpBuilder b(loc, *builder); - - // Integer and float types for casting and constant generation. - auto floatType = - mlir::cast(getElementTypeOrSelf(argTypes.front())); - int64_t nbits = floatType.getWidth(); - auto intType = mlir::IntegerType::get(loc.getContext(), nbits); - - Value xAsInt = b.create(intType, adaptor.getOperand()); - - // SignificandWidth includes the implicit extra bit. - auto srcMantissaBits = floatType.getFPMantissaWidth() - 1; - int srcExponentBits = nbits - 1 - srcMantissaBits; - - // Clear the sign bit, it does not participate in rounding and we will restore - // it later. - APInt signBitMask(nbits, 1); - signBitMask <<= nbits - 1; - - APInt expBitsMask(nbits, 1); - expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits; - - auto createConstant = [&](const APInt& v) { - return b.create(v.getZExtValue(), intType) - .getResult(); - }; - - Value xAbsBits = - b.create(xAsInt, createConstant(~signBitMask)); - Value xIsNan = b.create(arith::CmpIPredicate::ugt, xAbsBits, - createConstant(expBitsMask)); - - int destMantissaBits = adaptor.getMantissaBits(); - if (destMantissaBits < static_cast(srcMantissaBits)) { - // Last remaining mantissa bit. - APInt lastMantissaBitMask(nbits, 1); - lastMantissaBitMask <<= srcMantissaBits - destMantissaBits; - - // Compute rounding bias for round-to-nearest with ties to even. 
This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1; - - Value mantissaDiff = b.create( - srcMantissaBits - destMantissaBits, intType); - Value highestMantissaMaskVal = createConstant(lastMantissaBitMask); - Value baseRoundingBiasVal = createConstant(baseRoundingBias); - Value xLastMantissaBit = b.create( - b.create(xAsInt, highestMantissaMaskVal), mantissaDiff); - Value xRoundingBias = - b.create(xLastMantissaBit, baseRoundingBiasVal); - - // Add rounding bias, and mask out truncated bits. Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - APInt truncationMask = ~(lastMantissaBitMask - 1); - Value xRounded = b.create(xAsInt, xRoundingBias); - xAsInt = b.create(xRounded, createConstant(truncationMask)); - } - - int destExponentBits = adaptor.getExponentBits(); - if (destExponentBits < srcExponentBits) { - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. 
- APInt exponentBias(nbits, 1); - exponentBias = (exponentBias << (srcExponentBits - 1)) - 1; - - APInt reducedExponentBias(nbits, 1); - reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1; - - APInt reducedMaxExponent = exponentBias + reducedExponentBias; - APInt reducedMinExponent = exponentBias - reducedExponentBias; - - // Do we overflow or underflow? - Value xExponent = - b.create(xAsInt, createConstant(expBitsMask)); - Value xOverflows = b.create( - arith::CmpIPredicate::ugt, xExponent, - createConstant(reducedMaxExponent << srcMantissaBits)); - Value xUnderflows = b.create( - arith::CmpIPredicate::ule, xExponent, - createConstant(reducedMinExponent << srcMantissaBits)); - - // Compute appropriately-signed values of zero and infinity. - Value xSignedZero = - b.create(xAsInt, createConstant(signBitMask)); - Value xSignedInf = - b.create(xSignedZero, createConstant(expBitsMask)); - - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - xAsInt = b.create(xOverflows, xSignedInf, xAsInt); - xAsInt = b.create(xUnderflows, xSignedZero, xAsInt); - } - - Value result = b.create(floatType, xAsInt); - return b.create(xIsNan, adaptor.getOperand(), result); + return reducePrecision(loc, adaptor.getOperand(), + adaptor.getExponentBits(), + adaptor.getMantissaBits(), builder); } template <> diff --git a/xla/mlir_hlo/mhlo/transforms/transformation_helpers.h b/xla/mlir_hlo/mhlo/transforms/transformation_helpers.h new file mode 100644 index 0000000000000..211e973bab39a --- /dev/null +++ b/xla/mlir_hlo/mhlo/transforms/transformation_helpers.h @@ -0,0 +1,174 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ +#define XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ + +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::mhlo { + +// Creates an integer constant that is either a tensor (if shape is provided) or +// a scalar. +template +arith::ConstantOp createConst(ImplicitLocOpBuilder& b, + mlir::IntegerType intType, T value, + std::optional> shape) { + if (shape.has_value()) { + auto tensorType = mlir::RankedTensorType::get(shape.value(), intType); + return b.create(mlir::DenseElementsAttr::get( + tensorType, mlir::APInt(intType.getIntOrFloatBitWidth(), value))); + } + return b.create(b.getIntegerAttr(intType, value)); +} + +// Returns the input value with a reduced precision as specified by the target +// exponent and mantissa bits. This function will preserve the input shape on +// the output - i.e. it works with both scalars and tensors. +// +// The templated bitcast type allows this function to work with different kinds +// of bitcasts, e.g. `arith.bitcast` or `triton.bitcast`. 
+template +Value reducePrecision(Location loc, Value input, int destExponentBits, + int destMantissaBits, OpBuilder* builder) { + using llvm::APInt; + mlir::ImplicitLocOpBuilder b(loc, *builder); + + // Integer and float types for casting and constant generation. + auto floatType = mlir::cast(getElementTypeOrSelf(input.getType())); + int64_t nbits = floatType.getWidth(); + auto intScalarType = mlir::IntegerType::get(loc.getContext(), nbits); + + Type intType = intScalarType; + std::optional> shape; + if (auto tensorType = llvm::dyn_cast(input.getType())) { + shape = tensorType.getShape().vec(); + intType = tensorType.clone(intScalarType); + } + + Value xAsInt = b.create(intType, input); + + // SignificandWidth includes the implicit extra bit. + auto srcMantissaBits = floatType.getFPMantissaWidth() - 1; + int srcExponentBits = nbits - 1 - srcMantissaBits; + + // Clear the sign bit, it does not participate in rounding and we will restore + // it later. + APInt signBitMask(nbits, 1); + signBitMask <<= nbits - 1; + + APInt expBitsMask(nbits, 1); + expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits; + + auto createConstant = [&](const APInt& v) { + return createConst(b, intScalarType, v.getZExtValue(), shape); + }; + + Value xAbsBits = + b.create(xAsInt, createConstant(~signBitMask)); + Value xIsNan = b.create(arith::CmpIPredicate::ugt, xAbsBits, + createConstant(expBitsMask)); + + if (destMantissaBits < static_cast(srcMantissaBits)) { + // Last remaining mantissa bit. + APInt lastMantissaBitMask(nbits, 1); + lastMantissaBitMask <<= srcMantissaBits - destMantissaBits; + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. 
+ APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1; + + Value mantissaDiff = createConst(b, intScalarType, + srcMantissaBits - destMantissaBits, shape); + + Value highestMantissaMaskVal = createConstant(lastMantissaBitMask); + Value baseRoundingBiasVal = createConstant(baseRoundingBias); + Value xLastMantissaBit = b.create( + b.create(xAsInt, highestMantissaMaskVal), mantissaDiff); + Value xRoundingBias = + b.create(xLastMantissaBit, baseRoundingBiasVal); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + APInt truncationMask = ~(lastMantissaBitMask - 1); + Value xRounded = b.create(xAsInt, xRoundingBias); + xAsInt = b.create(xRounded, createConstant(truncationMask)); + } + + if (destExponentBits < srcExponentBits) { + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + APInt exponentBias(nbits, 1); + exponentBias = (exponentBias << (srcExponentBits - 1)) - 1; + + APInt reducedExponentBias(nbits, 1); + reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1; + + APInt reducedMaxExponent = exponentBias + reducedExponentBias; + APInt reducedMinExponent = exponentBias - reducedExponentBias; + + // Do we overflow or underflow? 
+ Value xExponent = + b.create(xAsInt, createConstant(expBitsMask)); + Value xOverflows = b.create( + arith::CmpIPredicate::ugt, xExponent, + createConstant(reducedMaxExponent << srcMantissaBits)); + Value xUnderflows = b.create( + arith::CmpIPredicate::ule, xExponent, + createConstant(reducedMinExponent << srcMantissaBits)); + + // Compute appropriately-signed values of zero and infinity. + Value xSignedZero = + b.create(xAsInt, createConstant(signBitMask)); + Value xSignedInf = + b.create(xSignedZero, createConstant(expBitsMask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + xAsInt = b.create(xOverflows, xSignedInf, xAsInt); + xAsInt = b.create(xUnderflows, xSignedZero, xAsInt); + } + + Value result = b.create(input.getType(), xAsInt); + return b.create(xIsNan, input, result); +} +} // namespace mlir::mhlo + +#endif // XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index 4515e19239d3d..a0ce50ae54d4f 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -51,6 +51,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/mlir_hlo:transformation_helpers", "//xla/service:algorithm_util", "//xla/service:dump", "//xla/service:hlo_module_config", diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 69bd809d345b5..126af60418587 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -106,6 +106,7 @@ limitations under the License. 
#include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" @@ -585,6 +586,9 @@ absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, Compare(b, {inputs[0], ZerosLike(b, inputs[0])}, mlir::mhlo::ComparisonDirection::NE), inputs[1], inputs[2]); + case HloOpcode::kReducePrecision: + return mlir::mhlo::reducePrecision( + b.getLoc(), inputs[0], hlo.exponent_bits(), hlo.mantissa_bits(), &b); default: return absl::InvalidArgumentError( absl::StrCat("Unsupported elementwise operation ", hlo.ToString())); diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 1cb92de9436e4..2541f74d4e257 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -1265,6 +1265,30 @@ INSTANTIATE_TEST_SUITE_P(IotaEmitterParametrizedTestSuite, ::testing::ValuesIn({S8, S16, S32, S64, BF16, F16, F32, F64})); +TEST_F(TritonEmitterTest, ReducePrecisionIsLoweredCorrectly) { + const std::string kHloText = R"( +triton_computation { + p = f32[5,7] parameter(0) + ROOT rp = f32[5,7] reduce-precision(p), exponent_bits=2, mantissa_bits=2 +} + +ENTRY entry_computation { + p = f32[5,7] parameter(0) + ROOT fusion = f32[5,7] fusion(p), kind=kCustom, calls=triton_computation, + backend_config={ + "fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{ + "output_tile_sizes":["4","4"], "num_warps":"1"}} + } +})"; + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: tt.load +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + } // namespace } // namespace gpu } // 
namespace xla diff --git a/xla/service/gpu/fusions/triton/triton_support.cc b/xla/service/gpu/fusions/triton/triton_support.cc index d006ae65fcc55..57e92d9717bc1 100644 --- a/xla/service/gpu/fusions/triton/triton_support.cc +++ b/xla/service/gpu/fusions/triton/triton_support.cc @@ -107,6 +107,11 @@ absl::flat_hash_set TritonSupportedUnaryElementwiseOps( HloOpcode::kCeil}; ret.insert(additional_opcodes.begin(), additional_opcodes.end()); } + + if (primitive_util::IsFloatingPointType(element_type)) { + ret.insert(HloOpcode::kReducePrecision); + } + return ret; }