diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index f784b8a69d0c..40ab71ffec8b 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -304,6 +304,13 @@ class QueryConfig { /// The current spark partition id. static constexpr const char* kSparkPartitionId = "spark.partition_id"; + /// When true, establishing the result type of an arithmetic operation + /// happens according to Hive behavior and SQL ANSI 2011 specification, i.e. + /// rounding the decimal part of the result if an exact representation is not + /// possible. Otherwise, NULL is returned in those cases, as previously + static constexpr const char* kSparkDecimalOperationsAllowPrecisionLoss = + "spark.decimal_operations.allow_precision_loss"; + /// The number of local parallel table writer operators per task. static constexpr const char* kTaskWriterCount = "task_writer_count"; @@ -642,6 +649,10 @@ class QueryConfig { return get(kSparkBloomFilterNumBits, kDefault); } + bool sparkDecimalOperationsAllowPrecisionLoss() const { + return get(kSparkDecimalOperationsAllowPrecisionLoss, true); + } + // Spark kMaxNumBits is 67'108'864, but velox has memory limit sizeClassSizes // 256, so decrease it to not over memory limit. int64_t sparkBloomFilterMaxNumBits() const { diff --git a/velox/docs/configs.rst b/velox/docs/configs.rst index cd782b2d763d..02f05ac35823 100644 --- a/velox/docs/configs.rst +++ b/velox/docs/configs.rst @@ -687,3 +687,8 @@ Spark-specific Configuration * - spark.partition_id - integer - The current task's Spark partition ID. It's set by the query engine (Spark) prior to task execution. + * - spark.decimal_operations.allow_precision_loss + - bool + - When true, establishing the result type of an arithmetic operation happens according to Hive behavior and SQL ANSI 2011 specification, i.e. + rounding the decimal part of the result if an exact representation is not + possible. Otherwise, NULL is returned in those cases, as previously diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index 519ff0d5cf58..70543d304e0b 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -23,14 +23,39 @@ namespace facebook::velox::functions::sparksql { namespace { -std::string getResultScale(std::string precision, std::string scale) { - return fmt::format( - "({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))", - precision, - scale, - scale, - precision, - scale); +std::string getResultScale( + std::string precision, + std::string scale, + bool allowPrecisionLoss) { + return allowPrecisionLoss + ? fmt::format( + "({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))", + precision, + scale, + scale, + precision, + scale) + : fmt::format("({}) <= 38 ? ({}) : 38", scale, scale); +} + +std::pair +getNotAllowPrecisionLossDivideResultScale() { + std::string intDig = "min(38, a_precision - a_scale + b_scale)"; + std::string decDig = "min(38, max(6, a_scale + b_precision + 1))"; + std::string diff = intDig + " + " + decDig + " - 38"; + std::string newDecDig = fmt::format("({}) - ({}) / 2 - 1", decDig, diff); + std::string newIntDig = fmt::format("38 - ({})", newDecDig); + return { + fmt::format( + "({}) > 0 ? ({}) : ({})", + diff, + getResultScale("", newIntDig + " + " + newDecDig, false), + getResultScale("", intDig + " + " + decDig, false)), + fmt::format( + "({}) > 0 ? ({}) : ({})", + diff, + getResultScale("", newDecDig, false), + getResultScale("", decDig, false))}; } // Returns the whole and fraction parts of a decimal value. @@ -416,11 +441,14 @@ class Addition { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { + uint8_t bScale, + bool allowPrecisionLoss) { auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + std::max(aScale, bScale) + 1; auto scale = std::max(aScale, bScale); - return DecimalUtil::adjustPrecisionScale(precision, scale); + return allowPrecisionLoss + ? DecimalUtil::adjustPrecisionScale(precision, scale) + : DecimalUtil::bounded(precision, scale); } }; @@ -464,9 +492,10 @@ class Subtraction { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { + uint8_t bScale, + bool allowPrecisionLoss) { return Addition::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); + aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss); } }; @@ -566,9 +595,12 @@ class Multiply { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { - return DecimalUtil::adjustPrecisionScale( - aPrecision + bPrecision + 1, aScale + bScale); + uint8_t bScale, + bool allowPrecisionLoss) { + return allowPrecisionLoss + ? DecimalUtil::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale) + : DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale); } private: @@ -616,15 +648,27 @@ class Divide { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { - auto scale = std::max(6, aScale + bPrecision + 1); - auto precision = aPrecision - aScale + bScale + scale; - return DecimalUtil::adjustPrecisionScale(precision, scale); + uint8_t bScale, + bool allowPrecisionLoss) { + if (allowPrecisionLoss) { + auto scale = std::max(6, aScale + bPrecision + 1); + auto precision = aPrecision - aScale + bScale + scale; + return DecimalUtil::adjustPrecisionScale(precision, scale); + } else { + auto intDig = std::min(38, aPrecision - aScale + bScale); + auto decDig = std::min(38, std::max(6, aScale + bPrecision + 1)); + auto diff = (intDig + decDig) - 38; + if (diff > 0) { + decDig -= diff / 2 + 1; + intDig = 38 - decDig; + } + return DecimalUtil::bounded(intDig + decDig, decDig); + } } }; std::vector> -decimalAddSubtractSignature() { +decimalAddSubtractSignature(bool allowPrecisionLoss) { return { exec::FunctionSignatureBuilder() .integerVariable("a_precision") @@ -638,15 +682,17 @@ decimalAddSubtractSignature() { "r_scale", getResultScale( "max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1", - "max(a_scale, b_scale)")) + "max(a_scale, b_scale)", + allowPrecisionLoss)) .returnType("DECIMAL(r_precision, r_scale)") .argumentType("DECIMAL(a_precision, a_scale)") .argumentType("DECIMAL(b_precision, b_scale)") - .build()}; + .build(), + }; } -std::vector> -decimalMultiplySignature() { +std::vector> decimalMultiplySignature( + bool allowPrecisionLoss) { return {exec::FunctionSignatureBuilder() .integerVariable("a_precision") .integerVariable("a_scale") @@ -657,45 +703,55 @@ decimalMultiplySignature() { .integerVariable( "r_scale", getResultScale( - "a_precision + b_precision + 1", "a_scale + b_scale")) + "a_precision + b_precision + 1", + "a_scale + b_scale", + allowPrecisionLoss)) .returnType("DECIMAL(r_precision, r_scale)") .argumentType("DECIMAL(a_precision, a_scale)") .argumentType("DECIMAL(b_precision, b_scale)") .build()}; } -std::vector> decimalDivideSignature() { - return { - exec::FunctionSignatureBuilder() - .integerVariable("a_precision") - .integerVariable("a_scale") - .integerVariable("b_precision") - .integerVariable("b_scale") - .integerVariable( - "r_precision", - "min(38, a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1))") - .integerVariable( - "r_scale", - getResultScale( - "a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1)", - "max(6, a_scale + b_precision + 1)")) - .returnType("DECIMAL(r_precision, r_scale)") - .argumentType("DECIMAL(a_precision, a_scale)") - .argumentType("DECIMAL(b_precision, b_scale)") - .build()}; +std::vector> decimalDivideSignature( + bool allowPrecisionLoss) { + auto precisionAndScale = getNotAllowPrecisionLossDivideResultScale(); + std::string resultPrecision = allowPrecisionLoss + ? "min(38, a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1))" + : precisionAndScale.first; + std::string resultScale = allowPrecisionLoss + ? getResultScale( + "a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1)", + "max(6, a_scale + b_precision + 1)", + allowPrecisionLoss) + : precisionAndScale.second; + return {exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable("r_precision", resultPrecision) + .integerVariable("r_scale", resultScale) + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build()}; } template std::shared_ptr createDecimalFunction( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { + const core::QueryConfig& config) { const auto& aType = inputArgs[0].type; const auto& bType = inputArgs[1].type; const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); + aPrecision, + aScale, + bPrecision, + bScale, + config.sparkDecimalOperationsAllowPrecisionLoss()); const uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); const uint8_t bRescale = @@ -782,21 +838,41 @@ std::shared_ptr createDecimalFunction( VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_add, - decimalAddSubtractSignature(), + decimalAddSubtractSignature(true), createDecimalFunction); VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_sub, - decimalAddSubtractSignature(), + decimalAddSubtractSignature(true), createDecimalFunction); VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_mul, - decimalMultiplySignature(), + decimalMultiplySignature(true), createDecimalFunction); VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_div, - decimalDivideSignature(), + decimalDivideSignature(true), + createDecimalFunction); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_add_not_allow_precision_loss, + decimalAddSubtractSignature(false), + createDecimalFunction); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_sub_not_allow_precision_loss, + decimalAddSubtractSignature(false), + createDecimalFunction); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_mul_not_allow_precision_loss, + decimalMultiplySignature(false), + createDecimalFunction); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_div_not_allow_precision_loss, + decimalDivideSignature(false), createDecimalFunction); } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index fbe5da77809e..484820c70b3f 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -211,6 +211,17 @@ class DecimalUtil { } } + /// This method is used when + /// `spark.sql.decimalOperations.allowPrecisionLoss` is set to false. + /// Make sure the precision and scale is in range. + inline static std::pair bounded( + uint8_t rPrecision, + uint8_t rScale) { + return { + std::min(rPrecision, DecimalType::kMaxPrecision), + std::min(rScale, DecimalType::kMaxPrecision)}; + } + private: /// Maintains the max bits that need to be increased for rescaling a value by /// certain scale. The calculation relies on the following formula: diff --git a/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp index 2ecbd4f3da23..5e8107face6d 100644 --- a/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp @@ -22,6 +22,21 @@ using namespace facebook::velox; using namespace facebook::velox::test; using namespace facebook::velox::functions::test; +namespace facebook::velox::functions::sparksql { +void registerNotAllowPrecisionLossFunction() { + VELOX_REGISTER_VECTOR_FUNCTION( + udf_decimal_add_not_allow_precision_loss, "add_not_allow_precision_loss"); + VELOX_REGISTER_VECTOR_FUNCTION( + udf_decimal_sub_not_allow_precision_loss, + "subtract_not_allow_precision_loss"); + VELOX_REGISTER_VECTOR_FUNCTION( + udf_decimal_mul_not_allow_precision_loss, + "multiply_not_allow_precision_loss"); + VELOX_REGISTER_VECTOR_FUNCTION( + udf_decimal_div_not_allow_precision_loss, + "divide_not_allow_precision_loss"); +} +} // namespace facebook::velox::functions::sparksql namespace facebook::velox::functions::sparksql::test { namespace { @@ -29,6 +44,7 @@ class DecimalArithmeticTest : public SparkFunctionBaseTest { public: DecimalArithmeticTest() { options_.parseDecimalAsDouble = false; + registerNotAllowPrecisionLossFunction(); } protected: @@ -76,6 +92,13 @@ class DecimalArithmeticTest : public SparkFunctionBaseTest { } return makeNullableFlatVector(numbers, type); } + + void setDecimalOperationsAllowPrecisionLoss(bool allowPrecisionLoss) { + queryCtx_->testingOverrideConfigUnsafe({ + {core::QueryConfig::kSparkDecimalOperationsAllowPrecisionLoss, + std::to_string(allowPrecisionLoss)}, + }); + } }; TEST_F(DecimalArithmeticTest, add) { @@ -522,5 +545,50 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { {makeConstant( DecimalUtil::kLongDecimalMax, 1, DECIMAL(38, 0))}); } + +TEST_F(DecimalArithmeticTest, notAllowPrecisionLoss) { + setDecimalOperationsAllowPrecisionLoss(false); + + testArithmeticFunction( + "add_not_allow_precision_loss", + {makeFlatVector( + std::vector{11232100, 9998888, 12345678, 2135632}, + DECIMAL(38, 7)), + makeFlatVector(std::vector{1, 2, 3, 4}, DECIMAL(10, 0))}, + makeFlatVector( + std::vector{21232100, 29998888, 42345678, 42135632}, + DECIMAL(38, 7))); + + testArithmeticFunction( + "subtract_not_allow_precision_loss", + {makeFlatVector( + std::vector{11232100, 9998888, 12345678, 2135632}, + DECIMAL(38, 7)), + makeFlatVector(std::vector{1, 2, 3, 4}, DECIMAL(10, 0))}, + makeFlatVector( + std::vector{1232100, -10001112, -17654322, -37864368}, + DECIMAL(38, 7))); + + testDecimalExpr( + makeConstant(60501, 1, DECIMAL(38, 10)), + "multiply_not_allow_precision_loss(c0, c1)", + {makeConstant(201, 1, DECIMAL(20, 5)), + makeConstant(301, 1, DECIMAL(20, 5))}); + + // diff > 0 + testDecimalExpr( + makeConstant( + HugeInt::parse("5" + std::string(18, '0')), 1, DECIMAL(38, 18)), + "divide_not_allow_precision_loss(c0, c1)", + {makeConstant(500, 1, DECIMAL(20, 2)), + makeConstant(1000, 1, DECIMAL(17, 3))}); + // diff < 0 + testDecimalExpr( + makeConstant( + HugeInt::parse("5" + std::string(10, '0')), 1, DECIMAL(31, 10)), + "divide_not_allow_precision_loss(c0, c1)", + {makeConstant(500, 1, DECIMAL(20, 2)), + makeConstant(1000, 1, DECIMAL(7, 3))}); +} } // namespace } // namespace facebook::velox::functions::sparksql::test diff --git a/velox/functions/sparksql/tests/DecimalUtilTest.cpp b/velox/functions/sparksql/tests/DecimalUtilTest.cpp index 833b88605a20..a967998b4c94 100644 --- a/velox/functions/sparksql/tests/DecimalUtilTest.cpp +++ b/velox/functions/sparksql/tests/DecimalUtilTest.cpp @@ -35,6 +35,13 @@ class DecimalUtilTest : public testing::Test { ASSERT_EQ(overflow, expectedOverflow); ASSERT_EQ(r, expectedResult); } + + void testBounded( + uint8_t rPrecision, + uint8_t rScale, + std::pair expected) { + ASSERT_EQ(DecimalUtil::bounded(rPrecision, rScale), expected); + } }; } // namespace @@ -60,4 +67,10 @@ TEST_F(DecimalUtilTest, minLeadingZeros) { 12); ASSERT_EQ(result, 0); } + +TEST_F(DecimalUtilTest, bounded) { + testBounded(10, 3, {10, 3}); + testBounded(40, 3, {38, 3}); + testBounded(44, 42, {38, 38}); +} } // namespace facebook::velox::functions::sparksql::test