diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index 519ff0d5cf58..61599bce10ea 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -14,405 +14,250 @@ * limitations under the License. */ -#include "velox/common/base/CheckedArithmetic.h" -#include "velox/expression/DecodedArgs.h" -#include "velox/expression/VectorFunction.h" +#include "velox/functions/Macros.h" +#include "velox/functions/Registerer.h" #include "velox/functions/sparksql/DecimalUtil.h" -#include "velox/type/DecimalUtil.h" namespace facebook::velox::functions::sparksql { namespace { -std::string getResultScale(std::string precision, std::string scale) { - return fmt::format( - "({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))", - precision, - scale, - scale, - precision, - scale); -} - -// Returns the whole and fraction parts of a decimal value. -template -inline std::pair getWholeAndFraction(T value, uint8_t scale) { - const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[scale]; - const T whole = value / scaleFactor; - return {whole, value - whole * scaleFactor}; -} - -// Increases the scale of input value by 'delta'. Returns the input value if -// delta is not positive. -inline int128_t increaseScale(int128_t in, int16_t delta) { - // No need to consider overflow as 'delta == higher scale - input scale', so - // the scaled value will not exceed the maximum of long decimal. - return delta <= 0 ? in : in * velox::DecimalUtil::kPowersOfTen[delta]; -} - -// Scales up the whole part to result scale, and combine it with fraction part -// to produce a full result for decimal add. Checks whether the result -// overflows. -template -inline T -decimalAddResult(T whole, T fraction, uint8_t resultScale, bool& overflow) { - T scaledWhole = DecimalUtil::multiply( - whole, velox::DecimalUtil::kPowersOfTen[resultScale], overflow); - if (FOLLY_UNLIKELY(overflow)) { - return 0; - } - const auto result = scaledWhole + fraction; - if constexpr (std::is_same_v) { - overflow = (result > velox::DecimalUtil::kShortDecimalMax) || - (result < velox::DecimalUtil::kShortDecimalMin); - } else { - overflow = (result > velox::DecimalUtil::kLongDecimalMax) || - (result < velox::DecimalUtil::kLongDecimalMin); +struct DecimalAddSubtractBase { + protected: + void initializeBase(const std::vector& inputTypes) { + auto [aPrecision, aScale] = getDecimalPrecisionScale(*inputTypes[0]); + auto [bPrecision, bScale] = getDecimalPrecisionScale(*inputTypes[1]); + aScale_ = aScale; + bScale_ = bScale; + auto [rPrecision, rScale] = + computeResultPrecisionScale(aPrecision, aScale_, bPrecision, bScale_); + rPrecision_ = rPrecision; + rScale_ = rScale; + aRescale_ = computeRescaleFactor(aScale_, bScale_); + bRescale_ = computeRescaleFactor(bScale_, aScale_); } - return result; -} -// Reduces the scale of input value by 'delta'. Returns the input value if delta -// is not positive. -template -inline static T reduceScale(T in, int32_t delta) { - if (delta <= 0) { - return in; - } - T result; - bool overflow; - const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[delta]; - if constexpr (std::is_same_v) { - VELOX_DCHECK_LE( - scaleFactor, - std::numeric_limits::max(), - "Scale factor should not exceed the maximum of int64_t."); + // Adds the values 'a' and 'b' and stores the result in 'r'. To align the + // scales of inputs, the value with the smaller scale is rescaled to the + // larger scale. 'aRescale' and 'bRescale' are the rescale factors needed to + // rescale 'a' and 'b'. 'rPrecision' and 'rScale' are the precision and scale + // of the result. + template + bool applyAdd(TResult& r, A a, B b) { + // The overflow flag is set to true if an overflow occurs + // during the addition. + bool overflow = false; + if (rPrecision_ < LongDecimalType::kMaxPrecision) { + const int128_t aRescaled = + a * velox::DecimalUtil::kPowersOfTen[aRescale_]; + const int128_t bRescaled = + b * velox::DecimalUtil::kPowersOfTen[bRescale_]; + r = TResult(aRescaled + bRescaled); + } else { + const uint32_t minLeadingZeros = + sparksql::DecimalUtil::minLeadingZeros( + a, b, aRescale_, bRescale_); + if (minLeadingZeros >= 3) { + // Fast path for no overflow. If both numbers contain at least 3 leading + // zeros, they can be added directly without the risk of overflow. + // The reason is if a number contains at least 2 leading zeros, it is + // ensured that the number fits in the maximum of decimal, because + // '2^126 - 1 < 10^38 - 1'. If both numbers contain at least 3 leading + // zeros, we are guaranteed that the result will have at least 2 leading + // zeros. + int128_t aRescaled = a * velox::DecimalUtil::kPowersOfTen[aRescale_]; + int128_t bRescaled = b * velox::DecimalUtil::kPowersOfTen[bRescale_]; + r = reduceScale( + TResult(aRescaled + bRescaled), + std::max(aScale_, bScale_) - rScale_); + } else { + // The risk of overflow should be considered. Add whole and fraction + // parts separately, and then combine. + r = addLarge(a, b, aScale_, bScale_, rScale_, overflow); + } + } + return !overflow && + velox::DecimalUtil::valueInPrecisionRange(r, rPrecision_); } - DecimalUtil::divideWithRoundUp( - result, in, T(scaleFactor), 0, overflow); - VELOX_DCHECK(!overflow); - return result; -} -// Adds two non-negative values by adding the whole and fraction parts -// separately. -template -inline static TResult addLargeNonNegative( - A a, - B b, - uint8_t aScale, - uint8_t bScale, - uint8_t rScale, - bool& overflow) { - VELOX_DCHECK_GE( - a, 0, "Non-negative value is expected in addLargeNonNegative."); - VELOX_DCHECK_GE( - b, 0, "Non-negative value is expected in addLargeNonNegative."); - - // Separate whole and fraction parts. - const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); - const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); - - // Adjust fractional parts to higher scale. - const auto higherScale = std::max(aScale, bScale); - const auto aFractionScaled = - increaseScale((int128_t)aFraction, higherScale - aScale); - const auto bFractionScaled = - increaseScale((int128_t)bFraction, higherScale - bScale); - - int128_t fraction; - bool carryToLeft = false; - const auto carrier = velox::DecimalUtil::kPowersOfTen[higherScale]; - if (aFractionScaled >= carrier - bFractionScaled) { - fraction = aFractionScaled + bFractionScaled - carrier; - carryToLeft = true; - } else { - fraction = aFractionScaled + bFractionScaled; + private: + // Returns the whole and fraction parts of a decimal value. + template + static std::pair getWholeAndFraction(T value, uint8_t scale) { + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[scale]; + const T whole = value / scaleFactor; + return {whole, value - whole * scaleFactor}; } - // Scale up the whole part and scale down the fraction part to combine them. - fraction = reduceScale(TResult(fraction), higherScale - rScale); - const auto whole = TResult(aWhole) + TResult(bWhole) + TResult(carryToLeft); - return decimalAddResult(whole, TResult(fraction), rScale, overflow); -} - -// Adds two opposite values by adding the whole and fraction parts separately. -template -inline static TResult addLargeOpposite( - A a, - B b, - uint8_t aScale, - uint8_t bScale, - int32_t rScale, - bool& overflow) { - VELOX_DCHECK( - (a < 0 && b > 0) || (a > 0 && b < 0), - "One positve and one negative value are expected in addLargeOpposite."); - - // Separate whole and fraction parts. - const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); - const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); - - // Adjust fractional parts to higher scale. - const auto higherScale = std::max(aScale, bScale); - const auto aFractionScaled = - increaseScale((int128_t)aFraction, higherScale - aScale); - const auto bFractionScaled = - increaseScale((int128_t)bFraction, higherScale - bScale); - - // No need to consider overflow because two inputs are opposite. - int128_t whole = (int128_t)aWhole + (int128_t)bWhole; - int128_t fraction = aFractionScaled + bFractionScaled; - - // If the whole and fractional parts have different signs, adjust them to the - // same sign. - const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[higherScale]; - if (whole < 0 && fraction > 0) { - whole += 1; - fraction -= scaleFactor; - } else if (whole > 0 && fraction < 0) { - whole -= 1; - fraction += scaleFactor; + // Increases the scale of input value by 'delta'. Returns the input value if + // delta is not positive. + static int128_t increaseScale(int128_t in, int16_t delta) { + // No need to consider overflow as 'delta == higher scale - input scale', so + // the scaled value will not exceed the maximum of long decimal. + return delta <= 0 ? in : in * velox::DecimalUtil::kPowersOfTen[delta]; } - // Scale up the whole part and scale down the fraction part to combine them. - fraction = reduceScale(TResult(fraction), higherScale - rScale); - return decimalAddResult(TResult(whole), TResult(fraction), rScale, overflow); -} + // Scales up the whole part to result scale, and combine it with fraction part + // to produce a full result for decimal add. Checks whether the result + // overflows. + template + static T + decimalAddResult(T whole, T fraction, uint8_t resultScale, bool& overflow) { + T scaledWhole = sparksql::DecimalUtil::multiply( + whole, velox::DecimalUtil::kPowersOfTen[resultScale], overflow); + if (FOLLY_UNLIKELY(overflow)) { + return 0; + } + const auto result = scaledWhole + fraction; + if constexpr (std::is_same_v) { + overflow = (result > velox::DecimalUtil::kShortDecimalMax) || + (result < velox::DecimalUtil::kShortDecimalMin); + } else { + overflow = (result > velox::DecimalUtil::kLongDecimalMax) || + (result < velox::DecimalUtil::kLongDecimalMin); + } + return result; + } -template -inline static TResult addLarge( - A a, - B b, - uint8_t aScale, - uint8_t bScale, - int32_t rScale, - bool& overflow) { - if (a >= 0 && b >= 0) { - // Both non-negative. - return addLargeNonNegative( - a, b, aScale, bScale, rScale, overflow); - } else if (a <= 0 && b <= 0) { - // Both non-positive. - return TResult(-addLargeNonNegative( - A(-a), B(-b), aScale, bScale, rScale, overflow)); - } else { - // One positive and the other negative. - return addLargeOpposite( - a, b, aScale, bScale, rScale, overflow); + // Reduces the scale of input value by 'delta'. Returns the input value if + // delta is not positive. + template + static T reduceScale(T in, int32_t delta) { + if (delta <= 0) { + return in; + } + T result; + bool overflow; + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[delta]; + if constexpr (std::is_same_v) { + VELOX_DCHECK_LE( + scaleFactor, + std::numeric_limits::max(), + "Scale factor should not exceed the maximum of int64_t."); + } + DecimalUtil::divideWithRoundUp( + result, in, T(scaleFactor), 0, overflow); + VELOX_DCHECK(!overflow); + return result; } -} -template < - typename R /* Result Type */, - typename A /* Argument1 */, - typename B /* Argument2 */, - typename Operation /* Arithmetic operation */> -class DecimalBaseFunction : public exec::VectorFunction { - public: - DecimalBaseFunction( - uint8_t aRescale, - uint8_t bRescale, - uint8_t aPrecision, + // Adds two non-negative values by adding the whole and fraction parts + // separately. + template + static TResult addLargeNonNegative( + A a, + B b, uint8_t aScale, - uint8_t bPrecision, uint8_t bScale, - uint8_t rPrecision, - uint8_t rScale) - : aRescale_(aRescale), - bRescale_(bRescale), - aPrecision_(aPrecision), - aScale_(aScale), - bPrecision_(bPrecision), - bScale_(bScale), - rPrecision_(rPrecision), - rScale_(rScale) {} - - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& resultType, - exec::EvalCtx& context, - VectorPtr& result) const override { - auto rawResults = prepareResults(rows, resultType, context, result); - if (args[0]->isConstantEncoding() && args[1]->isFlatEncoding()) { - // Fast path for (const, flat). - auto constant = args[0]->asUnchecked>()->valueAt(0); - auto flatValues = args[1]->asUnchecked>(); - auto rawValues = flatValues->mutableRawValues(); - context.applyToSelectedNoThrow(rows, [&](auto row) { - bool overflow = false; - Operation::template apply( - rawResults[row], - constant, - rawValues[row], - aRescale_, - bRescale_, - aPrecision_, - aScale_, - bPrecision_, - bScale_, - rPrecision_, - rScale_, - overflow); - if (overflow || - !velox::DecimalUtil::valueInPrecisionRange( - rawResults[row], rPrecision_)) { - result->setNull(row, true); - } - }); - } else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) { - // Fast path for (flat, const). - auto flatValues = args[0]->asUnchecked>(); - auto constant = args[1]->asUnchecked>()->valueAt(0); - auto rawValues = flatValues->mutableRawValues(); - context.applyToSelectedNoThrow(rows, [&](auto row) { - bool overflow = false; - Operation::template apply( - rawResults[row], - rawValues[row], - constant, - aRescale_, - bRescale_, - aPrecision_, - aScale_, - bPrecision_, - bScale_, - rPrecision_, - rScale_, - overflow); - if (overflow || - !velox::DecimalUtil::valueInPrecisionRange( - rawResults[row], rPrecision_)) { - result->setNull(row, true); - } - }); - } else if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) { - // Fast path for (flat, flat). - auto flatA = args[0]->asUnchecked>(); - auto rawA = flatA->mutableRawValues(); - auto flatB = args[1]->asUnchecked>(); - auto rawB = flatB->mutableRawValues(); - - context.applyToSelectedNoThrow(rows, [&](auto row) { - bool overflow = false; - Operation::template apply( - rawResults[row], - rawA[row], - rawB[row], - aRescale_, - bRescale_, - aPrecision_, - aScale_, - bPrecision_, - bScale_, - rPrecision_, - rScale_, - overflow); - if (overflow || - !velox::DecimalUtil::valueInPrecisionRange( - rawResults[row], rPrecision_)) { - result->setNull(row, true); - } - }); + uint8_t rScale, + bool& overflow) { + VELOX_DCHECK_GE( + a, 0, "Non-negative value is expected in addLargeNonNegative."); + VELOX_DCHECK_GE( + b, 0, "Non-negative value is expected in addLargeNonNegative."); + + // Separate whole and fraction parts. + const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); + const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); + + // Adjust fractional parts to higher scale. + const auto higherScale = std::max(aScale, bScale); + const auto aFractionScaled = + increaseScale((int128_t)aFraction, higherScale - aScale); + const auto bFractionScaled = + increaseScale((int128_t)bFraction, higherScale - bScale); + + int128_t fraction; + bool carryToLeft = false; + const auto carrier = velox::DecimalUtil::kPowersOfTen[higherScale]; + if (aFractionScaled >= carrier - bFractionScaled) { + fraction = aFractionScaled + bFractionScaled - carrier; + carryToLeft = true; } else { - // Fast path if one or more arguments are encoded. - exec::DecodedArgs decodedArgs(rows, args, context); - auto a = decodedArgs.at(0); - auto b = decodedArgs.at(1); - context.applyToSelectedNoThrow(rows, [&](auto row) { - bool overflow = false; - Operation::template apply( - rawResults[row], - a->valueAt(row), - b->valueAt(row), - aRescale_, - bRescale_, - aPrecision_, - aScale_, - bPrecision_, - bScale_, - rPrecision_, - rScale_, - overflow); - if (overflow || - !velox::DecimalUtil::valueInPrecisionRange( - rawResults[row], rPrecision_)) { - result->setNull(row, true); - } - }); + fraction = aFractionScaled + bFractionScaled; } - } - private: - R* prepareResults( - const SelectivityVector& rows, - const TypePtr& resultType, - exec::EvalCtx& context, - VectorPtr& result) const { - context.ensureWritable(rows, resultType, result); - result->clearNulls(rows); - return result->asUnchecked>()->mutableRawValues(); + // Scale up the whole part and scale down the fraction part to combine them. + fraction = reduceScale(TResult(fraction), higherScale - rScale); + const auto whole = TResult(aWhole) + TResult(bWhole) + TResult(carryToLeft); + return decimalAddResult(whole, TResult(fraction), rScale, overflow); } - const uint8_t aRescale_; - const uint8_t bRescale_; - const uint8_t aPrecision_; - const uint8_t aScale_; - const uint8_t bPrecision_; - const uint8_t bScale_; - const uint8_t rPrecision_; - const uint8_t rScale_; -}; - -class Addition { - public: + // Adds two opposite values by adding the whole and fraction parts separately. template - inline static void apply( - TResult& r, + static TResult addLargeOpposite( A a, B b, - uint8_t aRescale, - uint8_t bRescale, - uint8_t /* aPrecision */, uint8_t aScale, - uint8_t /* bPrecision */, uint8_t bScale, - uint8_t rPrecision, - uint8_t rScale, + int32_t rScale, bool& overflow) { - if (rPrecision < LongDecimalType::kMaxPrecision) { - const int128_t aRescaled = a * velox::DecimalUtil::kPowersOfTen[aRescale]; - const int128_t bRescaled = b * velox::DecimalUtil::kPowersOfTen[bRescale]; - r = TResult(aRescaled + bRescaled); - } else { - const uint32_t minLeadingZeros = - DecimalUtil::minLeadingZeros(a, b, aRescale, bRescale); - if (minLeadingZeros >= 3) { - // Fast path for no overflow. If both numbers contain at least 3 leading - // zeros, they can be added directly without the risk of overflow. - // The reason is if a number contains at least 2 leading zeros, it is - // ensured that the number fits in the maximum of decimal, because - // '2^126 - 1 < 10^38 - 1'. If both numbers contain at least 3 leading - // zeros, we are guaranteed that the result will have at least 2 leading - // zeros. - int128_t aRescaled = a * velox::DecimalUtil::kPowersOfTen[aRescale]; - int128_t bRescaled = b * velox::DecimalUtil::kPowersOfTen[bRescale]; - r = reduceScale( - TResult(aRescaled + bRescaled), std::max(aScale, bScale) - rScale); - } else { - // The risk of overflow should be considered. Add whole and fraction - // parts separately, and then combine. - r = addLarge(a, b, aScale, bScale, rScale, overflow); - } + VELOX_DCHECK( + (a < 0 && b > 0) || (a > 0 && b < 0), + "One positve and one negative value are expected in addLargeOpposite."); + + // Separate whole and fraction parts. + const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); + const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); + + // Adjust fractional parts to higher scale. + const auto higherScale = std::max(aScale, bScale); + const auto aFractionScaled = + increaseScale((int128_t)aFraction, higherScale - aScale); + const auto bFractionScaled = + increaseScale((int128_t)bFraction, higherScale - bScale); + + // No need to consider overflow because two inputs are opposite. + int128_t whole = (int128_t)aWhole + (int128_t)bWhole; + int128_t fraction = aFractionScaled + bFractionScaled; + + // If the whole and fractional parts have different signs, adjust them to + // the same sign. + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[higherScale]; + if (whole < 0 && fraction > 0) { + whole += 1; + fraction -= scaleFactor; + } else if (whole > 0 && fraction < 0) { + whole -= 1; + fraction += scaleFactor; } + + // Scale up the whole part and scale down the fraction part to combine them. + fraction = reduceScale(TResult(fraction), higherScale - rScale); + return decimalAddResult( + TResult(whole), TResult(fraction), rScale, overflow); } - inline static uint8_t - computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { - return std::max(0, toScale - fromScale); + // Add whole and fraction parts separately, and then combine. The overflow + // flag will be set to true if an overflow occurs during the addition. + template + static TResult addLarge( + A a, + B b, + uint8_t aScale, + uint8_t bScale, + int32_t rScale, + bool& overflow) { + if (a >= 0 && b >= 0) { + // Both non-negative. + return addLargeNonNegative( + a, b, aScale, bScale, rScale, overflow); + } else if (a <= 0 && b <= 0) { + // Both non-positive. + return TResult(-addLargeNonNegative( + A(-a), B(-b), aScale, bScale, rScale, overflow)); + } else { + // One positive and the other negative. + return addLargeOpposite( + a, b, aScale, bScale, rScale, overflow); + } } - inline static std::pair computeResultPrecisionScale( + // Computes the result precision and scale for decimal add and subtract + // operations following Hive's formulas. + // If result is representable with long decimal, the result + // scale is the maximum of 'aScale' and 'bScale'. If not, reduces result scale + // and caps the result precision at 38. + static std::pair computeResultPrecisionScale( uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, @@ -420,92 +265,92 @@ class Addition { auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + std::max(aScale, bScale) + 1; auto scale = std::max(aScale, bScale); - return DecimalUtil::adjustPrecisionScale(precision, scale); + return sparksql::DecimalUtil::adjustPrecisionScale(precision, scale); + } + + static uint8_t computeRescaleFactor(uint8_t fromScale, uint8_t toScale) { + return std::max(0, toScale - fromScale); } + + uint8_t aScale_; + uint8_t bScale_; + uint8_t aRescale_; + uint8_t bRescale_; + uint8_t rPrecision_; + uint8_t rScale_; }; -class Subtraction { - public: - template - inline static void apply( - TResult& r, - A a, - B b, - uint8_t aRescale, - uint8_t bRescale, - uint8_t aPrecision, - uint8_t aScale, - uint8_t bPrecision, - uint8_t bScale, - uint8_t rPrecision, - uint8_t rScale, - bool& overflow) { - Addition::apply( - r, - a, - B(-b), - aRescale, - bRescale, - aPrecision, - aScale, - bPrecision, - bScale, - rPrecision, - rScale, - overflow); +template +struct DecimalAddFunction : DecimalAddSubtractBase { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + template + void initialize( + const std::vector& inputTypes, + const core::QueryConfig& /*config*/, + A* /*a*/, + B* /*b*/) { + initializeBase(inputTypes); } - inline static uint8_t - computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { - return std::max(0, toScale - fromScale); + template + bool call(R& out, const A& a, const B& b) { + return applyAdd(out, a, b); } +}; - inline static std::pair computeResultPrecisionScale( - uint8_t aPrecision, - uint8_t aScale, - uint8_t bPrecision, - uint8_t bScale) { - return Addition::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); +template +struct DecimalSubtractFunction : DecimalAddSubtractBase { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + template + void initialize( + const std::vector& inputTypes, + const core::QueryConfig& /*config*/, + A* /*a*/, + B* /*b*/) { + initializeBase(inputTypes); + } + + template + bool call(R& out, const A& a, const B& b) { + return applyAdd(out, a, B(-b)); } }; -class Multiply { - public: - // Derive from Arrow. - // https://github.com/apache/arrow/blob/release-12.0.1-rc1/cpp/src/gandiva/precompiled/decimal_ops.cc#L331 +template +struct DecimalMultiplyFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + template + void initialize( + const std::vector& inputTypes, + const core::QueryConfig& /*config*/, + A* /*a*/, + B* /*b*/) { + auto [aPrecision, aScale] = getDecimalPrecisionScale(*inputTypes[0]); + auto [bPrecision, bScale] = getDecimalPrecisionScale(*inputTypes[1]); + auto [rPrecision, rScale] = DecimalUtil::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale); + rPrecision_ = rPrecision; + deltaScale_ = aScale + bScale - rScale; + } + template - inline static void apply( - R& r, - A a, - B b, - uint8_t aRescale, - uint8_t bRescale, - uint8_t aPrecision, - uint8_t aScale, - uint8_t bPrecision, - uint8_t bScale, - uint8_t rPrecision, - uint8_t rScale, - bool& overflow) { - if (rPrecision < 38) { - R result = DecimalUtil::multiply(R(a), R(b), overflow); - VELOX_DCHECK(!overflow); - r = DecimalUtil::multiply( - result, - R(velox::DecimalUtil::kPowersOfTen[aRescale + bRescale]), - overflow); + bool call(R& out, const A& a, const B& b) { + bool overflow = false; + if (rPrecision_ < 38) { + out = DecimalUtil::multiply(R(a), R(b), overflow); VELOX_DCHECK(!overflow); } else if (a == 0 && b == 0) { // Handle this separately to avoid divide-by-zero errors. - r = R(0); + out = R(0); } else { - auto deltaScale = aScale + bScale - rScale; - if (deltaScale == 0) { + if (deltaScale_ == 0) { // No scale down. // Multiply when the out_precision is 38, and there is no trimming of // the scale i.e the intermediate value is the same as the final value. - r = DecimalUtil::multiply(R(a), R(b), overflow); + out = DecimalUtil::multiply(R(a), R(b), overflow); } else { // Scale down. // It's possible that the intermediate value does not fit in 128-bits, @@ -520,10 +365,10 @@ class Multiply { // Needs int256. int256_t reslarge = static_cast(a) * static_cast(b); - reslarge = reduceScaleBy(reslarge, deltaScale); - r = DecimalUtil::convert(reslarge, overflow); + reslarge = reduceScaleBy(reslarge, deltaScale_); + out = DecimalUtil::convert(reslarge, overflow); } else { - if (LIKELY(deltaScale <= 38)) { + if (LIKELY(deltaScale_ <= 38)) { // The largest value that result can have here is (2^64 - 1) * (2^63 // - 1) = 1.70141E+38,which is greater than // DecimalUtil::kLongDecimalMax. @@ -533,9 +378,9 @@ class Multiply { // ((2^64 - 1) * (2^63 - 1)) / 10, which is less than // DecimalUtil::kLongDecimalMax, so there cannot be any overflow. DecimalUtil::divideWithRoundUp( - r, + out, result, - R(velox::DecimalUtil::kPowersOfTen[deltaScale]), + R(velox::DecimalUtil::kPowersOfTen[deltaScale_]), 0, overflow); VELOX_DCHECK(!overflow); @@ -550,29 +395,18 @@ class Multiply { // the right of the rightmost "visible" one. The reason why we have // to handle this case separately is because a scale multiplier with // a deltaScale 39 does not fit into 128 bit. - r = R(0); + out = R(0); } } } } - } - - inline static uint8_t - computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { - return 0; - } - inline static std::pair computeResultPrecisionScale( - uint8_t aPrecision, - uint8_t aScale, - uint8_t bPrecision, - uint8_t bScale) { - return DecimalUtil::adjustPrecisionScale( - aPrecision + bPrecision + 1, aScale + bScale); + return !overflow && + velox::DecimalUtil::valueInPrecisionRange(out, rPrecision_); } private: - inline static int256_t reduceScaleBy(int256_t in, int32_t reduceBy) { + static int256_t reduceScaleBy(int256_t in, int32_t reduceBy) { if (reduceBy == 0) { return in; } @@ -586,33 +420,40 @@ class Multiply { } return result; } + + uint8_t rPrecision_; + // The difference between result scale and the sum of aScale and bScale. + int32_t deltaScale_; }; -class Divide { - public: - template - inline static void apply( - R& r, - A a, - B b, - uint8_t aRescale, - uint8_t /* bRescale */, - uint8_t /* aPrecision */, - uint8_t /* aScale */, - uint8_t /* bPrecision */, - uint8_t /* bScale */, - uint8_t /* rPrecision */, - uint8_t /* rScale */, - bool& overflow) { - DecimalUtil::divideWithRoundUp(r, a, b, aRescale, overflow); +template +struct DecimalDivideFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + template + void initialize( + const std::vector& inputTypes, + const core::QueryConfig& /*config*/, + A* /*a*/, + B* /*b*/) { + auto [aPrecision, aScale] = getDecimalPrecisionScale(*inputTypes[0]); + auto [bPrecision, bScale] = getDecimalPrecisionScale(*inputTypes[1]); + auto [rPrecision, rScale] = + computeResultPrecisionScale(aPrecision, aScale, bPrecision, bScale); + rPrecision_ = rPrecision; + aRescale_ = rScale - aScale + bScale; } - inline static uint8_t - computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale) { - return rScale - fromScale + toScale; + template + bool call(R& out, const A& a, const B& b) { + bool overflow = false; + DecimalUtil::divideWithRoundUp(out, a, b, aRescale_, overflow); + return !overflow && + velox::DecimalUtil::valueInPrecisionRange(out, rPrecision_); } - inline static std::pair computeResultPrecisionScale( + private: + static std::pair computeResultPrecisionScale( uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, @@ -621,182 +462,139 @@ class Divide { auto precision = aPrecision - aScale + bScale + scale; return DecimalUtil::adjustPrecisionScale(precision, scale); } + + uint8_t aRescale_; + uint8_t rPrecision_; }; -std::vector> -decimalAddSubtractSignature() { +template