[XLA:GPU] Generalize the Reduce-Precision in MHLO to also work on Tensors and use it in a Triton emitter.

PiperOrigin-RevId: 681364188
dimitar-asenov authored and Google-ML-Automation committed Oct 2, 2024
1 parent a4cb02b commit 38b6144
Showing 7 changed files with 230 additions and 110 deletions.
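
At a high level, the commit factors the previously scalar-only lowering of mhlo.reduce_precision out of map_mhlo_to_scalar_op.h into a templated helper, mlir::mhlo::reducePrecision, which also accepts tensor-typed inputs and is parameterized over the bitcast op so the Triton emitter can reuse it with the Triton dialect's bitcast (mt::BitcastOp). As a rough orientation for the bit manipulation the helper emits, the following standalone C++ sketch applies the same rounding and clamping to a single f32 value; it is purely illustrative, not part of the commit, and the name ReducePrecisionF32 is invented here.

// Standalone illustration (not part of the commit): apply the same mantissa
// rounding and exponent clamping as the MLIR helper to one f32 value.
// Assumes 1 <= dest_exponent_bits <= 8 and 0 <= dest_mantissa_bits <= 23.
#include <cstdint>
#include <cstdio>
#include <cstring>

float ReducePrecisionF32(float x, int dest_exponent_bits, int dest_mantissa_bits) {
  constexpr int kSrcMantissaBits = 23;  // explicit bits, excluding the implicit one
  constexpr int kSrcExponentBits = 8;
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));

  const uint32_t sign_mask = 1u << 31;
  const uint32_t exp_mask = ((1u << kSrcExponentBits) - 1) << kSrcMantissaBits;
  const bool is_nan = (bits & ~sign_mask) > exp_mask;  // checked on the original bits

  if (dest_mantissa_bits < kSrcMantissaBits) {
    // Round to nearest, ties to even: add 0111... (plus one if the last kept
    // mantissa bit is set), then truncate the dropped bits.
    const uint32_t last_kept_bit = 1u << (kSrcMantissaBits - dest_mantissa_bits);
    const uint32_t base_bias = (last_kept_bit >> 1) - 1;
    const uint32_t rounding_bias = base_bias + ((bits & last_kept_bit) ? 1u : 0u);
    bits = (bits + rounding_bias) & ~(last_kept_bit - 1);
  }

  if (dest_exponent_bits < kSrcExponentBits) {
    const uint32_t exp_bias = (1u << (kSrcExponentBits - 1)) - 1;        // 127
    const uint32_t reduced_bias = (1u << (dest_exponent_bits - 1)) - 1;
    const uint32_t max_exp = (exp_bias + reduced_bias) << kSrcMantissaBits;
    const uint32_t min_exp = (exp_bias - reduced_bias) << kSrcMantissaBits;
    const uint32_t exponent = bits & exp_mask;
    const uint32_t signed_zero = bits & sign_mask;
    if (exponent > max_exp) bits = signed_zero | exp_mask;  // overflow -> +/-inf
    if (exponent <= min_exp) bits = signed_zero;            // underflow -> +/-0
  }

  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return is_nan ? x : result;  // NaN operands pass through unchanged
}

int main() {
  // With exponent_bits=2, mantissa_bits=2 (the configuration used by the new
  // Triton test in this commit): 1.3f rounds to 1.25f, 0.4f flushes to 0.0f,
  // and 5.0f saturates to +inf.
  std::printf("%g %g %g\n", ReducePrecisionF32(1.3f, 2, 2),
              ReducePrecisionF32(0.4f, 2, 2), ReducePrecisionF32(5.0f, 2, 2));
  return 0;
}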
17 changes: 17 additions & 0 deletions xla/mlir_hlo/BUILD
@@ -495,6 +495,23 @@ cc_library(
name = "map_mhlo_to_scalar_op",
hdrs = ["mhlo/transforms/map_mhlo_to_scalar_op.h"],
strip_include_prefix = ".",
deps = [
":mlir_hlo",
":transformation_helpers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "transformation_helpers",
hdrs = ["mhlo/transforms/transformation_helpers.h"],
strip_include_prefix = ".",
deps = [
":mlir_hlo",
"@llvm-project//llvm:Support",
115 changes: 5 additions & 110 deletions xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/transformation_helpers.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
@@ -469,117 +470,11 @@ inline Value mapMhloOpToStdScalarOp<mhlo::CompareOp>(

template <>
inline Value mapMhloOpToStdScalarOp<mhlo::ReducePrecisionOp>(
Location loc, ArrayRef<Type> /*resultTypes*/, ArrayRef<Type> argTypes,
Location loc, ArrayRef<Type> /*resultTypes*/, ArrayRef<Type> /*argTypes*/,
mhlo::ReducePrecisionOp::Adaptor adaptor, OpBuilder* builder) {
using llvm::APInt;
mlir::ImplicitLocOpBuilder b(loc, *builder);

// Integer and float types for casting and constant generation.
auto floatType =
mlir::cast<FloatType>(getElementTypeOrSelf(argTypes.front()));
int64_t nbits = floatType.getWidth();
auto intType = mlir::IntegerType::get(loc.getContext(), nbits);

Value xAsInt = b.create<arith::BitcastOp>(intType, adaptor.getOperand());

// SignificandWidth includes the implicit extra bit.
auto srcMantissaBits = floatType.getFPMantissaWidth() - 1;
int srcExponentBits = nbits - 1 - srcMantissaBits;

// Clear the sign bit; it does not participate in rounding, and we will
// restore it later.
APInt signBitMask(nbits, 1);
signBitMask <<= nbits - 1;

APInt expBitsMask(nbits, 1);
expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits;

auto createConstant = [&](const APInt& v) {
return b.create<arith::ConstantIntOp>(v.getZExtValue(), intType)
.getResult();
};

Value xAbsBits =
b.create<arith::AndIOp>(xAsInt, createConstant(~signBitMask));
Value xIsNan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, xAbsBits,
createConstant(expBitsMask));

int destMantissaBits = adaptor.getMantissaBits();
if (destMantissaBits < static_cast<int>(srcMantissaBits)) {
// Last remaining mantissa bit.
APInt lastMantissaBitMask(nbits, 1);
lastMantissaBitMask <<= srcMantissaBits - destMantissaBits;

// Compute rounding bias for round-to-nearest with ties to even. This is
// equal to a base value of 0111... plus one bit if the last remaining
// mantissa bit is 1.
APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1;

Value mantissaDiff = b.create<arith::ConstantIntOp>(
srcMantissaBits - destMantissaBits, intType);
Value highestMantissaMaskVal = createConstant(lastMantissaBitMask);
Value baseRoundingBiasVal = createConstant(baseRoundingBias);
Value xLastMantissaBit = b.create<arith::ShRUIOp>(
b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff);
Value xRoundingBias =
b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal);

// Add rounding bias, and mask out truncated bits. Note that the case
// where adding the rounding bias overflows into the exponent bits is
// correct; the non-masked mantissa bits will all be zero, and the
// exponent will be incremented by one.
APInt truncationMask = ~(lastMantissaBitMask - 1);
Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias);
xAsInt = b.create<arith::AndIOp>(xRounded, createConstant(truncationMask));
}

int destExponentBits = adaptor.getExponentBits();
if (destExponentBits < srcExponentBits) {
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
// 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
// size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
// exponent (corresponding to 0.0f).
//
// Thus, the f32 exponent corresponding to the highest non-infinite
// exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
// exponent corresponding to the lowest exponent for a bit size of n is
// (2^7-1) - 2^(n-1)-1.
//
// Note that we have already checked that exponent_bits >= 1.
APInt exponentBias(nbits, 1);
exponentBias = (exponentBias << (srcExponentBits - 1)) - 1;

APInt reducedExponentBias(nbits, 1);
reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1;

APInt reducedMaxExponent = exponentBias + reducedExponentBias;
APInt reducedMinExponent = exponentBias - reducedExponentBias;

// Do we overflow or underflow?
Value xExponent =
b.create<arith::AndIOp>(xAsInt, createConstant(expBitsMask));
Value xOverflows = b.create<arith::CmpIOp>(
arith::CmpIPredicate::ugt, xExponent,
createConstant(reducedMaxExponent << srcMantissaBits));
Value xUnderflows = b.create<arith::CmpIOp>(
arith::CmpIPredicate::ule, xExponent,
createConstant(reducedMinExponent << srcMantissaBits));

// Compute appropriately-signed values of zero and infinity.
Value xSignedZero =
b.create<arith::AndIOp>(xAsInt, createConstant(signBitMask));
Value xSignedInf =
b.create<arith::OrIOp>(xSignedZero, createConstant(expBitsMask));

// Force to zero or infinity if overflow or underflow. (Note that this
// truncates all denormal values to zero, rather than rounding them.)
xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt);
xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt);
}

Value result = b.create<arith::BitcastOp>(floatType, xAsInt);
return b.create<arith::SelectOp>(xIsNan, adaptor.getOperand(), result);
return reducePrecision<arith::BitcastOp>(loc, adaptor.getOperand(),
adaptor.getExponentBits(),
adaptor.getMantissaBits(), builder);
}

template <>
174 changes: 174 additions & 0 deletions xla/mlir_hlo/mhlo/transforms/transformation_helpers.h
@@ -0,0 +1,174 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_
#define XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_

#include <cstdint>
#include <optional>
#include <vector>

#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

namespace mlir::mhlo {

// Creates an integer constant that is either a tensor (if shape is provided) or
// a scalar.
template <typename T>
arith::ConstantOp createConst(ImplicitLocOpBuilder& b,
mlir::IntegerType intType, T value,
std::optional<ArrayRef<int64_t>> shape) {
if (shape.has_value()) {
auto tensorType = mlir::RankedTensorType::get(shape.value(), intType);
return b.create<arith::ConstantOp>(mlir::DenseElementsAttr::get(
tensorType, mlir::APInt(intType.getIntOrFloatBitWidth(), value)));
}
return b.create<arith::ConstantOp>(b.getIntegerAttr(intType, value));
}

// Returns the input value with a reduced precision as specified by the target
// exponent and mantissa bits. This function will preserve the input shape on
// the output - i.e. it works with both scalars and tensors.
//
// The templated bitcast type allows this function to work with different kinds
// of bitcasts, e.g. `arith.bitcast` or `triton.bitcast`.
template <typename BitCastOp>
Value reducePrecision(Location loc, Value input, int destExponentBits,
int destMantissaBits, OpBuilder* builder) {
using llvm::APInt;
mlir::ImplicitLocOpBuilder b(loc, *builder);

// Integer and float types for casting and constant generation.
auto floatType = mlir::cast<FloatType>(getElementTypeOrSelf(input.getType()));
int64_t nbits = floatType.getWidth();
auto intScalarType = mlir::IntegerType::get(loc.getContext(), nbits);

Type intType = intScalarType;
std::optional<std::vector<int64_t>> shape;
if (auto tensorType = llvm::dyn_cast<TensorType>(input.getType())) {
shape = tensorType.getShape().vec();
intType = tensorType.clone(intScalarType);
}

Value xAsInt = b.create<BitCastOp>(intType, input);

// SignificandWidth includes the implicit extra bit.
auto srcMantissaBits = floatType.getFPMantissaWidth() - 1;
int srcExponentBits = nbits - 1 - srcMantissaBits;

// Clear the sign bit; it does not participate in rounding, and we will
// restore it later.
APInt signBitMask(nbits, 1);
signBitMask <<= nbits - 1;

APInt expBitsMask(nbits, 1);
expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits;

auto createConstant = [&](const APInt& v) {
return createConst(b, intScalarType, v.getZExtValue(), shape);
};

Value xAbsBits =
b.create<arith::AndIOp>(xAsInt, createConstant(~signBitMask));
Value xIsNan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, xAbsBits,
createConstant(expBitsMask));

if (destMantissaBits < static_cast<int>(srcMantissaBits)) {
// Last remaining mantissa bit.
APInt lastMantissaBitMask(nbits, 1);
lastMantissaBitMask <<= srcMantissaBits - destMantissaBits;

// Compute rounding bias for round-to-nearest with ties to even. This is
// equal to a base value of 0111... plus one bit if the last remaining
// mantissa bit is 1.
APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1;

Value mantissaDiff = createConst(b, intScalarType,
srcMantissaBits - destMantissaBits, shape);

Value highestMantissaMaskVal = createConstant(lastMantissaBitMask);
Value baseRoundingBiasVal = createConstant(baseRoundingBias);
Value xLastMantissaBit = b.create<arith::ShRUIOp>(
b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff);
Value xRoundingBias =
b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal);

// Add rounding bias, and mask out truncated bits. Note that the case
// where adding the rounding bias overflows into the exponent bits is
// correct; the non-masked mantissa bits will all be zero, and the
// exponent will be incremented by one.
APInt truncationMask = ~(lastMantissaBitMask - 1);
Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias);
xAsInt = b.create<arith::AndIOp>(xRounded, createConstant(truncationMask));
}

if (destExponentBits < srcExponentBits) {
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
// 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
// size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
// exponent (corresponding to 0.0f).
//
// Thus, the f32 exponent corresponding to the highest non-infinite
// exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
// exponent corresponding to the lowest exponent for a bit size of n is
// (2^7-1) - 2^(n-1)-1.
//
// Note that we have already checked that exponent_bits >= 1.
APInt exponentBias(nbits, 1);
exponentBias = (exponentBias << (srcExponentBits - 1)) - 1;

APInt reducedExponentBias(nbits, 1);
reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1;

APInt reducedMaxExponent = exponentBias + reducedExponentBias;
APInt reducedMinExponent = exponentBias - reducedExponentBias;

// Do we overflow or underflow?
Value xExponent =
b.create<arith::AndIOp>(xAsInt, createConstant(expBitsMask));
Value xOverflows = b.create<arith::CmpIOp>(
arith::CmpIPredicate::ugt, xExponent,
createConstant(reducedMaxExponent << srcMantissaBits));
Value xUnderflows = b.create<arith::CmpIOp>(
arith::CmpIPredicate::ule, xExponent,
createConstant(reducedMinExponent << srcMantissaBits));

// Compute appropriately-signed values of zero and infinity.
Value xSignedZero =
b.create<arith::AndIOp>(xAsInt, createConstant(signBitMask));
Value xSignedInf =
b.create<arith::OrIOp>(xSignedZero, createConstant(expBitsMask));

// Force to zero or infinity if overflow or underflow. (Note that this
// truncates all denormal values to zero, rather than rounding them.)
xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt);
xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt);
}

Value result = b.create<BitCastOp>(input.getType(), xAsInt);
return b.create<arith::SelectOp>(xIsNan, input, result);
}
} // namespace mlir::mhlo

#endif // XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_
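
For concreteness, here is a small standalone sketch (not part of the change) of the exponent-clamping constants the helper derives for an f32 input with destExponentBits = 2, the configuration exercised by the new Triton test later in this commit.

#include <cstdint>

int main() {
  const int src_exponent_bits = 8;   // f32
  const int dest_exponent_bits = 2;
  const int64_t exponent_bias = (int64_t{1} << (src_exponent_bits - 1)) - 1;   // 127
  const int64_t reduced_bias = (int64_t{1} << (dest_exponent_bits - 1)) - 1;   // 1
  const int64_t reduced_max_exponent = exponent_bias + reduced_bias;           // 128
  const int64_t reduced_min_exponent = exponent_bias - reduced_bias;           // 126
  // After mantissa rounding, biased exponents above 128 (|x| >= 4) saturate to
  // +/-inf and biased exponents at or below 126 (|x| < 1) flush to +/-0, so
  // only normal magnitudes in [1, 4) survive the exponent clamp.
  return (reduced_max_exponent == 128 && reduced_min_exponent == 126) ? 0 : 1;
}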
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/triton/BUILD
@@ -51,6 +51,7 @@ cc_library(
"//xla/hlo/utils:hlo_query",
"//xla/mlir_hlo",
"//xla/mlir_hlo:map_mhlo_to_scalar_op",
"//xla/mlir_hlo:transformation_helpers",
"//xla/service:algorithm_util",
"//xla/service:dump",
"//xla/service:hlo_module_config",
4 changes: 4 additions & 0 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -106,6 +106,7 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h"
#include "xla/permutation_util.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
@@ -585,6 +586,9 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
Compare(b, {inputs[0], ZerosLike(b, inputs[0])},
mlir::mhlo::ComparisonDirection::NE),
inputs[1], inputs[2]);
case HloOpcode::kReducePrecision:
return mlir::mhlo::reducePrecision<mt::BitcastOp>(
b.getLoc(), inputs[0], hlo.exponent_bits(), hlo.mantissa_bits(), &b);
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported elementwise operation ", hlo.ToString()));
@@ -1265,6 +1265,30 @@ INSTANTIATE_TEST_SUITE_P(IotaEmitterParametrizedTestSuite,
::testing::ValuesIn({S8, S16, S32, S64, BF16, F16, F32,
F64}));

TEST_F(TritonEmitterTest, ReducePrecisionIsLoweredCorrectly) {
const std::string kHloText = R"(
triton_computation {
p = f32[5,7] parameter(0)
ROOT rp = f32[5,7] reduce-precision(p), exponent_bits=2, mantissa_bits=2
}
ENTRY entry_computation {
p = f32[5,7] parameter(0)
ROOT fusion = f32[5,7] fusion(p), kind=kCustom, calls=triton_computation,
backend_config={
"fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{
"output_tile_sizes":["4","4"], "num_warps":"1"}}
}
})";
TF_EXPECT_OK(
CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"(
CHECK: tt.load
)"));

EXPECT_TRUE(
RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0}));
}

} // namespace
} // namespace gpu
} // namespace xla
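
As a side note on the exponent_bits=2, mantissa_bits=2 configuration the test above uses, the following standalone check is illustrative only and not part of the commit (RepresentableInE2M2 is a name invented here): it classifies which f32 inputs such a reduced format keeps bit-exact, namely zero, non-finite values, and normals in [1, 4) needing at most two explicit mantissa bits.

#include <cmath>
#include <cstdio>

// True when reduce-precision with exponent_bits=2, mantissa_bits=2 leaves the
// f32 value x bit-exact: zero and non-finite values pass through, and finite
// nonzero values must be normals in [1, 4) with at most two explicit mantissa
// bits.
bool RepresentableInE2M2(float x) {
  if (x == 0.0f || !std::isfinite(x)) return true;
  int exp = 0;
  float frac = std::frexp(std::fabs(x), &exp);  // |x| = frac * 2^exp, frac in [0.5, 1)
  int unbiased = exp - 1;                       // exponent of the leading 1 bit
  return unbiased >= 0 && unbiased <= 1 && frac * 8.0f == std::floor(frac * 8.0f);
}

int main() {
  const float samples[] = {1.0f, 1.25f, 3.5f, 1.3f, 0.5f, 4.0f};
  for (float s : samples) {
    std::printf("%g -> %s\n", s,
                RepresentableInE2M2(s) ? "kept exactly" : "rounded, flushed, or saturated");
  }
  return 0;
}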
5 changes: 5 additions & 0 deletions xla/service/gpu/fusions/triton/triton_support.cc
@@ -107,6 +107,11 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
HloOpcode::kCeil};
ret.insert(additional_opcodes.begin(), additional_opcodes.end());
}

if (primitive_util::IsFloatingPointType(element_type)) {
ret.insert(HloOpcode::kReducePrecision);
}

return ret;
}

