Skip to content

Commit

Permalink
refactor the decimal register
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed Sep 25, 2024
1 parent 593e606 commit 61f976b
Showing 1 changed file with 59 additions and 54 deletions.
113 changes: 59 additions & 54 deletions velox/functions/sparksql/DecimalArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,60 +498,41 @@ struct DecimalDivideFunction {
uint8_t rPrecision_;
};

template <template <class, bool> typename T, bool allowPrecisionLoss>
using ParameterBinder = TempWrapper<T<exec::VectorExec, allowPrecisionLoss>>;

template <typename Func, typename TReturn, typename... TArgs>
bool registerFunction(
const std::vector<std::string>& aliases = {},
const std::vector<exec::SignatureVariable>& constraints = {},
bool overwrite = true) {
using funcClass = typename Func::template udf<exec::VectorExec>;
using holderClass = core::UDFHolder<
funcClass,
exec::VectorExec,
TReturn,
ConstantChecker<TArgs...>,
typename UnwrapConstantType<TArgs>::type...>;
return exec::registerSimpleFunction<holderClass>(
aliases, constraints, overwrite);
}

template <template <class, bool> typename Func, bool allowPrecisionLoss>
template <template <class> typename Func>
void registerDecimalBinary(
const std::string& name,
std::vector<exec::SignatureVariable> constraints) {
// (long, long) -> long
registerFunction<
ParameterBinder<Func, allowPrecisionLoss>,
Func,
LongDecimal<P3, S3>,
LongDecimal<P1, S1>,
LongDecimal<P2, S2>>({name}, constraints);

// (short, short) -> short
registerFunction<
ParameterBinder<Func, allowPrecisionLoss>,
Func,
ShortDecimal<P3, S3>,
ShortDecimal<P1, S1>,
ShortDecimal<P2, S2>>({name}, constraints);

// (short, short) -> long
registerFunction<
ParameterBinder<Func, allowPrecisionLoss>,
Func,
LongDecimal<P3, S3>,
ShortDecimal<P1, S1>,
ShortDecimal<P2, S2>>({name}, constraints);

// (short, long) -> long
registerFunction<
ParameterBinder<Func, allowPrecisionLoss>,
Func,
LongDecimal<P3, S3>,
ShortDecimal<P1, S1>,
LongDecimal<P2, S2>>({name}, constraints);

// (long, short) -> long
registerFunction<
ParameterBinder<Func, allowPrecisionLoss>,
Func,
LongDecimal<P3, S3>,
LongDecimal<P1, S1>,
ShortDecimal<P2, S2>>({name}, constraints);
Expand Down Expand Up @@ -585,8 +566,7 @@ std::vector<exec::SignatureVariable> makeConstraints(
S3::name(), finalScale, exec::ParameterType::kIntegerParameter)};
}

template <template <class, bool> typename Func>
void registerDecimalAddSubtract(const std::string& name) {
std::pair<std::string, std::string> getAddSubtractResultPrecisionScale() {
std::string rPrecision = fmt::format(
"max({a_precision} - {a_scale}, {b_precision} - {b_scale}) + max({a_scale}, {b_scale}) + 1",
fmt::arg("a_precision", P1::name()),
Expand All @@ -597,22 +577,46 @@ void registerDecimalAddSubtract(const std::string& name) {
"max({a_scale}, {b_scale})",
fmt::arg("a_scale", S1::name()),
fmt::arg("b_scale", S2::name()));
registerDecimalBinary<Func, true>(
name, makeConstraints(rPrecision, rScale, true));
registerDecimalBinary<Func, false>(
name + kDenyPrecisionLoss, makeConstraints(rPrecision, rScale, false));
return {rPrecision, rScale};
}

} // namespace

template <typename TExec>
using AddFunctionAllowPrecisionLoss = DecimalAddFunction<TExec, true>;

template <typename TExec>
using AddFunctionDenyPrecisionLoss = DecimalAddFunction<TExec, false>;

void registerDecimalAdd(const std::string& prefix) {
registerDecimalAddSubtract<DecimalAddFunction>(prefix + "add");
auto [rPrecision, rScale] = getAddSubtractResultPrecisionScale();
registerDecimalBinary<AddFunctionAllowPrecisionLoss>(
prefix + "add", makeConstraints(rPrecision, rScale, true));
registerDecimalBinary<AddFunctionDenyPrecisionLoss>(
prefix + "add" + kDenyPrecisionLoss,
makeConstraints(rPrecision, rScale, false));
}

template <typename TExec>
using SubtractFunctionAllowPrecisionLoss = DecimalSubtractFunction<TExec, true>;

template <typename TExec>
using SubtractFunctionDenyPrecisionLoss = DecimalSubtractFunction<TExec, false>;

void registerDecimalSubtract(const std::string& prefix) {
registerDecimalAddSubtract<DecimalSubtractFunction>(prefix + "subtract");
auto [rPrecision, rScale] = getAddSubtractResultPrecisionScale();
registerDecimalBinary<SubtractFunctionAllowPrecisionLoss>(
prefix + "subtract", makeConstraints(rPrecision, rScale, true));
registerDecimalBinary<SubtractFunctionDenyPrecisionLoss>(
prefix + "subtract" + kDenyPrecisionLoss,
makeConstraints(rPrecision, rScale, false));
}

template <typename TExec>
using MultiplyFunctionAllowPrecisionLoss = DecimalMultiplyFunction<TExec, true>;

template <typename TExec>
using MultiplyFunctionDenyPrecisionLoss = DecimalMultiplyFunction<TExec, false>;

void registerDecimalMultiply(const std::string& prefix) {
std::string rPrecision = fmt::format(
"{a_precision} + {b_precision} + 1",
Expand All @@ -622,15 +626,14 @@ void registerDecimalMultiply(const std::string& prefix) {
"{a_scale} + {b_scale}",
fmt::arg("a_scale", S1::name()),
fmt::arg("b_scale", S2::name()));
registerDecimalBinary<DecimalMultiplyFunction, true>(
registerDecimalBinary<MultiplyFunctionAllowPrecisionLoss>(
prefix + "multiply", makeConstraints(rPrecision, rScale, true));
registerDecimalBinary<DecimalMultiplyFunction, false>(
registerDecimalBinary<MultiplyFunctionDenyPrecisionLoss>(
prefix + "multiply" + kDenyPrecisionLoss,
makeConstraints(rPrecision, rScale, false));
}

std::vector<exec::SignatureVariable>
getDivideConstraintsNotAllowPrecisionLoss() {
std::vector<exec::SignatureVariable> getDivideConstraintsDenyPrecisionLoss() {
std::string wholeDigits = fmt::format(
"min(38, {a_precision} - {a_scale} + {b_scale})",
fmt::arg("a_precision", P1::name()),
Expand Down Expand Up @@ -677,36 +680,38 @@ std::vector<exec::SignatureVariable> getDivideConstraintsAllowPrecisionLoss() {
return makeConstraints(rPrecision, rScale, true);
}

template <bool allowPrecisionLoss>
void registerDecimalDivide(const std::string& prefix) {
std::vector<exec::SignatureVariable> constraints;
std::string functionName = prefix + "divide";
if constexpr (allowPrecisionLoss) {
constraints = getDivideConstraintsAllowPrecisionLoss();
} else {
constraints = getDivideConstraintsNotAllowPrecisionLoss();
functionName += kDenyPrecisionLoss;
}
registerDecimalBinary<DecimalDivideFunction, allowPrecisionLoss>(
functionName, constraints);
template <typename TExec>
using DivideFunctionAllowPrecisionLoss = DecimalDivideFunction<TExec, true>;

template <typename TExec>
using DivideFunctionDenyPrecisionLoss = DecimalDivideFunction<TExec, false>;

template <template <class> typename Func>
void registerDecimalDivide(
const std::string& functionName,
std::vector<exec::SignatureVariable> constraints) {
registerDecimalBinary<Func>(functionName, constraints);

// (short, long) -> short
registerFunction<
ParameterBinder<DecimalDivideFunction, allowPrecisionLoss>,
Func,
ShortDecimal<P3, S3>,
ShortDecimal<P1, S1>,
LongDecimal<P2, S2>>({functionName}, constraints);

// (long, short) -> short
registerFunction<
ParameterBinder<DecimalDivideFunction, allowPrecisionLoss>,
Func,
ShortDecimal<P3, S3>,
LongDecimal<P1, S1>,
ShortDecimal<P2, S2>>({functionName}, constraints);
}

void registerDecimalDivide(const std::string& prefix) {
registerDecimalDivide<true>(prefix);
registerDecimalDivide<false>(prefix);
registerDecimalDivide<DivideFunctionAllowPrecisionLoss>(
prefix + "divide", getDivideConstraintsAllowPrecisionLoss());
registerDecimalDivide<DivideFunctionDenyPrecisionLoss>(
prefix + "divide" + kDenyPrecisionLoss,
getDivideConstraintsDenyPrecisionLoss());
}
} // namespace facebook::velox::functions::sparksql

0 comments on commit 61f976b

Please sign in to comment.