Skip to content

Commit

Permalink
[facebookincubator#10383 ] Support decimal operation not precision lo…
Browse files Browse the repository at this point in the history
…ss mode (10383)

facebookincubator#10383
  • Loading branch information
jinchengchenghh authored and zhztheplayer committed Jul 25, 2024
1 parent 4e889e5 commit dad351f
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 50 deletions.
11 changes: 11 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ class QueryConfig {
/// The current spark partition id.
static constexpr const char* kSparkPartitionId = "spark.partition_id";

/// When true, establishing the result type of an arithmetic operation
/// happens according to Hive behavior and SQL ANSI 2011 specification, i.e.
/// rounding the decimal part of the result if an exact representation is not
/// possible. Otherwise, NULL is returned in those cases, as previously
static constexpr const char* kSparkDecimalOperationsAllowPrecisionLoss =
"spark.decimal_operations.allow_precision_loss";

/// The number of local parallel table writer operators per task.
static constexpr const char* kTaskWriterCount = "task_writer_count";

Expand Down Expand Up @@ -642,6 +649,10 @@ class QueryConfig {
return get<int64_t>(kSparkBloomFilterNumBits, kDefault);
}

bool sparkDecimalOperationsAllowPrecisionLoss() const {
return get<bool>(kSparkDecimalOperationsAllowPrecisionLoss, true);
}

// Spark kMaxNumBits is 67'108'864, but velox has memory limit sizeClassSizes
// 256, so decrease it to not over memory limit.
int64_t sparkBloomFilterMaxNumBits() const {
Expand Down
5 changes: 5 additions & 0 deletions velox/docs/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,8 @@ Spark-specific Configuration
* - spark.partition_id
- integer
- The current task's Spark partition ID. It's set by the query engine (Spark) prior to task execution.
* - spark.decimal_operations.allow_precision_loss
- bool
- When true, establishing the result type of an arithmetic operation happens according to Hive behavior and SQL ANSI 2011 specification, i.e.
rounding the decimal part of the result if an exact representation is not
possible. Otherwise, NULL is returned in those cases, as previously
176 changes: 126 additions & 50 deletions velox/functions/sparksql/DecimalArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,39 @@
namespace facebook::velox::functions::sparksql {
namespace {

std::string getResultScale(std::string precision, std::string scale) {
return fmt::format(
"({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))",
precision,
scale,
scale,
precision,
scale);
std::string getResultScale(
std::string precision,
std::string scale,
bool allowPrecisionLoss) {
return allowPrecisionLoss
? fmt::format(
"({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))",
precision,
scale,
scale,
precision,
scale)
: fmt::format("({}) <= 38 ? ({}) : 38", scale, scale);
}

std::pair<std::string, std::string>
getNotAllowPrecisionLossDivideResultScale() {
std::string intDig = "min(38, a_precision - a_scale + b_scale)";
std::string decDig = "min(38, max(6, a_scale + b_precision + 1))";
std::string diff = intDig + " + " + decDig + " - 38";
std::string newDecDig = fmt::format("({}) - ({}) / 2 - 1", decDig, diff);
std::string newIntDig = fmt::format("38 - ({})", newDecDig);
return {
fmt::format(
"({}) > 0 ? ({}) : ({})",
diff,
getResultScale("", newIntDig + " + " + newDecDig, false),
getResultScale("", intDig + " + " + decDig, false)),
fmt::format(
"({}) > 0 ? ({}) : ({})",
diff,
getResultScale("", newDecDig, false),
getResultScale("", decDig, false))};
}

// Returns the whole and fraction parts of a decimal value.
Expand Down Expand Up @@ -416,11 +441,14 @@ class Addition {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
uint8_t bScale,
bool allowPrecisionLoss) {
auto precision = std::max(aPrecision - aScale, bPrecision - bScale) +
std::max(aScale, bScale) + 1;
auto scale = std::max(aScale, bScale);
return DecimalUtil::adjustPrecisionScale(precision, scale);
return allowPrecisionLoss
? DecimalUtil::adjustPrecisionScale(precision, scale)
: DecimalUtil::bounded(precision, scale);
}
};

Expand Down Expand Up @@ -464,9 +492,10 @@ class Subtraction {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
uint8_t bScale,
bool allowPrecisionLoss) {
return Addition::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
}
};

Expand Down Expand Up @@ -566,9 +595,12 @@ class Multiply {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
return DecimalUtil::adjustPrecisionScale(
aPrecision + bPrecision + 1, aScale + bScale);
uint8_t bScale,
bool allowPrecisionLoss) {
return allowPrecisionLoss
? DecimalUtil::adjustPrecisionScale(
aPrecision + bPrecision + 1, aScale + bScale)
: DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale);
}

private:
Expand Down Expand Up @@ -616,15 +648,27 @@ class Divide {
uint8_t aPrecision,
uint8_t aScale,
uint8_t bPrecision,
uint8_t bScale) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return DecimalUtil::adjustPrecisionScale(precision, scale);
uint8_t bScale,
bool allowPrecisionLoss) {
if (allowPrecisionLoss) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return DecimalUtil::adjustPrecisionScale(precision, scale);
} else {
auto intDig = std::min(38, aPrecision - aScale + bScale);
auto decDig = std::min(38, std::max(6, aScale + bPrecision + 1));
auto diff = (intDig + decDig) - 38;
if (diff > 0) {
decDig -= diff / 2 + 1;
intDig = 38 - decDig;
}
return DecimalUtil::bounded(intDig + decDig, decDig);
}
}
};

std::vector<std::shared_ptr<exec::FunctionSignature>>
decimalAddSubtractSignature() {
decimalAddSubtractSignature(bool allowPrecisionLoss) {
return {
exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
Expand All @@ -638,15 +682,17 @@ decimalAddSubtractSignature() {
"r_scale",
getResultScale(
"max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1",
"max(a_scale, b_scale)"))
"max(a_scale, b_scale)",
allowPrecisionLoss))
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build()};
.build(),
};
}

std::vector<std::shared_ptr<exec::FunctionSignature>>
decimalMultiplySignature() {
std::vector<std::shared_ptr<exec::FunctionSignature>> decimalMultiplySignature(
bool allowPrecisionLoss) {
return {exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
Expand All @@ -657,45 +703,55 @@ decimalMultiplySignature() {
.integerVariable(
"r_scale",
getResultScale(
"a_precision + b_precision + 1", "a_scale + b_scale"))
"a_precision + b_precision + 1",
"a_scale + b_scale",
allowPrecisionLoss))
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build()};
}

std::vector<std::shared_ptr<exec::FunctionSignature>> decimalDivideSignature() {
return {
exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("b_precision")
.integerVariable("b_scale")
.integerVariable(
"r_precision",
"min(38, a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1))")
.integerVariable(
"r_scale",
getResultScale(
"a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1)",
"max(6, a_scale + b_precision + 1)"))
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build()};
std::vector<std::shared_ptr<exec::FunctionSignature>> decimalDivideSignature(
bool allowPrecisionLoss) {
auto precisionAndScale = getNotAllowPrecisionLossDivideResultScale();
std::string resultPrecision = allowPrecisionLoss
? "min(38, a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1))"
: precisionAndScale.first;
std::string resultScale = allowPrecisionLoss
? getResultScale(
"a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1)",
"max(6, a_scale + b_precision + 1)",
allowPrecisionLoss)
: precisionAndScale.second;
return {exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("b_precision")
.integerVariable("b_scale")
.integerVariable("r_precision", resultPrecision)
.integerVariable("r_scale", resultScale)
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build()};
}

template <typename Operation>
std::shared_ptr<exec::VectorFunction> createDecimalFunction(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
const core::QueryConfig& config) {
const auto& aType = inputArgs[0].type;
const auto& bType = inputArgs[1].type;
const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
aPrecision,
aScale,
bPrecision,
bScale,
config.sparkDecimalOperationsAllowPrecisionLoss());
const uint8_t aRescale =
Operation::computeRescaleFactor(aScale, bScale, rScale);
const uint8_t bRescale =
Expand Down Expand Up @@ -782,21 +838,41 @@ std::shared_ptr<exec::VectorFunction> createDecimalFunction(

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_add,
decimalAddSubtractSignature(),
decimalAddSubtractSignature(true),
createDecimalFunction<Addition>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_sub,
decimalAddSubtractSignature(),
decimalAddSubtractSignature(true),
createDecimalFunction<Subtraction>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_mul,
decimalMultiplySignature(),
decimalMultiplySignature(true),
createDecimalFunction<Multiply>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_div,
decimalDivideSignature(),
decimalDivideSignature(true),
createDecimalFunction<Divide>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_add_not_allow_precision_loss,
decimalAddSubtractSignature(false),
createDecimalFunction<Addition>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_sub_not_allow_precision_loss,
decimalAddSubtractSignature(false),
createDecimalFunction<Subtraction>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_mul_not_allow_precision_loss,
decimalMultiplySignature(false),
createDecimalFunction<Multiply>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_div_not_allow_precision_loss,
decimalDivideSignature(false),
createDecimalFunction<Divide>);
} // namespace facebook::velox::functions::sparksql
11 changes: 11 additions & 0 deletions velox/functions/sparksql/DecimalUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ class DecimalUtil {
}
}

/// This method is used when
/// `spark.sql.decimalOperations.allowPrecisionLoss` is set to false.
/// Make sure the precision and scale is in range.
inline static std::pair<uint8_t, uint8_t> bounded(
uint8_t rPrecision,
uint8_t rScale) {
return {
std::min(rPrecision, DecimalType<TypeKind::HUGEINT>::kMaxPrecision),
std::min(rScale, DecimalType<TypeKind::HUGEINT>::kMaxPrecision)};
}

private:
/// Maintains the max bits that need to be increased for rescaling a value by
/// certain scale. The calculation relies on the following formula:
Expand Down
Loading

0 comments on commit dad351f

Please sign in to comment.