Skip to content

Commit

Permalink
Support custom comparison in Multi Map Aggregation (#11133)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #11133

Building on #11021 this adds support
for custom comparison functions provided by custom types in the multi map
aggregation.

Reviewed By: xiaoxmeng

Differential Revision: D63645720

fbshipit-source-id: b94a8c94301f79661b58c17fa6b1a82d18a7db41
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 1, 2024
1 parent 811fbde commit 7ad64b3
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 5 deletions.
118 changes: 113 additions & 5 deletions velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,97 @@ struct ComplexTypeMultiMapAccumulator {
}
};

// A wrapper around MultiMapAccumulator that overrides hash and equal_to
// functions to use the custom comparisons provided by a custom type.
template <TypeKind Kind>
struct CustomComparisonMultiMapAccumulator {
using NativeType = typename TypeTraits<Kind>::NativeType;

struct Hash {
const TypePtr& type;

size_t operator()(const NativeType& value) const {
return static_cast<const CanProvideCustomComparisonType<Kind>*>(
type.get())
->hash(value);
}
};

struct EqualTo {
const TypePtr& type;

bool operator()(const NativeType& left, const NativeType& right) const {
return static_cast<const CanProvideCustomComparisonType<Kind>*>(
type.get())
->compare(left, right) == 0;
}
};

// The underlying MultiMapAccumulator to which all operations are
// delegated.
MultiMapAccumulator<
NativeType,
CustomComparisonMultiMapAccumulator::Hash,
CustomComparisonMultiMapAccumulator::EqualTo>
base;

CustomComparisonMultiMapAccumulator(
const TypePtr& type,
HashStringAllocator* allocator)
: base{
CustomComparisonMultiMapAccumulator::Hash{type},
CustomComparisonMultiMapAccumulator::EqualTo{type},
allocator} {}

size_t size() const {
return base.size();
}

size_t numValues() const {
return base.numValues();
}

// Adds key-value pair.
void insert(
const DecodedVector& decodedKeys,
const DecodedVector& decodedValues,
vector_size_t index,
HashStringAllocator& allocator) {
base.insert(decodedKeys, decodedValues, index, allocator);
}

// Adds a key with a list of values.
void insertMultiple(
const DecodedVector& decodedKeys,
vector_size_t keyIndex,
const DecodedVector& decodedValues,
vector_size_t valueIndex,
vector_size_t numValues,
HashStringAllocator& allocator) {
base.insertMultiple(
decodedKeys, keyIndex, decodedValues, valueIndex, numValues, allocator);
}

ValueList& insertKey(
const DecodedVector& decodedKeys,
vector_size_t index,
HashStringAllocator& allocator) {
return base.insertKey(decodedKeys, index, allocator);
}

void extract(
VectorPtr& mapKeys,
ArrayVector& mapValueArrays,
vector_size_t& keyOffset,
vector_size_t& valueOffset) {
base.extract(mapKeys, mapValueArrays, keyOffset, valueOffset);
}

void free(HashStringAllocator& allocator) {
base.free(allocator);
}
};

template <typename T>
struct MultiMapAccumulatorTypeTraits {
using AccumulatorType = MultiMapAccumulator<T>;
Expand Down Expand Up @@ -255,15 +346,15 @@ struct MultiMapAccumulatorTypeTraits<ComplexType> {
using AccumulatorType = ComplexTypeMultiMapAccumulator;
};

template <typename K>
template <
typename K,
typename AccumulatorType =
typename MultiMapAccumulatorTypeTraits<K>::AccumulatorType>
class MultiMapAggAggregate : public exec::Aggregate {
public:
explicit MultiMapAggAggregate(TypePtr resultType)
: exec::Aggregate(std::move(resultType)) {}

using AccumulatorType =
typename MultiMapAccumulatorTypeTraits<K>::AccumulatorType;

bool isFixedSize() const override {
return false;
}
Expand Down Expand Up @@ -496,6 +587,14 @@ class MultiMapAggAggregate : public exec::Aggregate {
DecodedVector decodedValueArrays_;
};

template <TypeKind Kind>
std::unique_ptr<exec::Aggregate> createMultiMapAggAggregateWithCustomCompare(
const TypePtr& resultType) {
return std::make_unique<MultiMapAggAggregate<
typename TypeTraits<Kind>::NativeType,
CustomComparisonMultiMapAccumulator<Kind>>>(resultType);
}

} // namespace

void registerMultiMapAggAggregate(
Expand All @@ -522,7 +621,16 @@ void registerMultiMapAggAggregate(
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
auto typeKind = resultType->childAt(0)->kind();
const auto keyType = resultType->childAt(0);
const auto typeKind = keyType->kind();

if (keyType->providesCustomComparison()) {
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
createMultiMapAggAggregateWithCustomCompare,
typeKind,
resultType);
}

switch (typeKind) {
case TypeKind::BOOLEAN:
return std::make_unique<MultiMapAggAggregate<bool>>(resultType);
Expand Down
129 changes: 129 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <cmath>
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

namespace facebook::velox::aggregate::prestosql {
namespace {
Expand Down Expand Up @@ -346,5 +347,133 @@ TEST_F(MultiMapAggTest, doubleKeyGlobal) {
{expected});
}

TEST_F(MultiMapAggTest, timestampWithTimeZoneGlobal) {
auto data = makeRowVector(
{makeFlatVector<int64_t>(
{pack(0, 0),
pack(1, 0),
pack(2, 0),
pack(0, 1),
pack(1, 1),
pack(1, 2),
pack(2, 2),
pack(3, 3),
pack(1, 1),
pack(3, 0)},
TIMESTAMP_WITH_TIME_ZONE()),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})});

auto expected = makeRowVector({makeMapVector(
{0},
makeFlatVector<int64_t>(
{pack(0, 0), pack(1, 0), pack(2, 0), pack(3, 3)},
TIMESTAMP_WITH_TIME_ZONE()),
makeArrayVector<int32_t>({{1, 4}, {2, 5, 6, 9}, {3, 7}, {8, 10}}))});

testAggregations(
{data},
{},
{"multimap_agg(c0, c1)"},
// Sort the result arrays to ensure deterministic results.
{"transform_values(a0, (k, v) -> array_sort(v))"},
{expected});

// Input keys are complex type (row).
data = makeRowVector(
{makeRowVector({makeFlatVector<int64_t>(
{pack(0, 0),
pack(1, 0),
pack(2, 0),
pack(0, 1),
pack(1, 1),
pack(1, 2),
pack(2, 2),
pack(3, 3),
pack(1, 1),
pack(3, 0)},
TIMESTAMP_WITH_TIME_ZONE())}),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})});

expected = makeRowVector({makeMapVector(
{0},
makeRowVector({makeFlatVector<int64_t>(
{pack(0, 0), pack(1, 0), pack(2, 0), pack(3, 3)},
TIMESTAMP_WITH_TIME_ZONE())}),
makeArrayVector<int32_t>({{1, 4}, {2, 5, 6, 9}, {3, 7}, {8, 10}}))});

testAggregations(
{data},
{},
{"multimap_agg(c0, c1)"},
// Sort the result arrays to ensure deterministic results.
{"transform_values(a0, (k, v) -> array_sort(v))"},
{expected});
}

TEST_F(MultiMapAggTest, timestampWithTimeZoneGroupBy) {
auto data = makeRowVector(
{makeFlatVector<int64_t>(
{pack(0, 0),
pack(1, 0),
pack(1, 0),
pack(0, 1),
pack(1, 1),
pack(1, 2),
pack(1, 2),
pack(0, 3),
pack(1, 1),
pack(0, 0)},
TIMESTAMP_WITH_TIME_ZONE()),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
makeFlatVector<int32_t>({1, 1, 1, 1, 1, 2, 2, 2, 2, 2})});

auto expected = makeRowVector({makeMapVector(
{0, 2},
makeFlatVector<int64_t>(
{pack(0, 0), pack(1, 0), pack(0, 3), pack(1, 2)},
TIMESTAMP_WITH_TIME_ZONE()),
makeArrayVector<int32_t>({{1, 4}, {2, 3, 5}, {8, 10}, {6, 7, 9}}))});

testAggregations(
{data},
{"c2"},
{"multimap_agg(c0, c1)"},
// Sort the result arrays to ensure deterministic results.
{"transform_values(a0, (k, v) -> array_sort(v))"},
{expected});

// Input keys are complex type (row).
data = makeRowVector(
{makeRowVector({makeFlatVector<int64_t>(
{pack(0, 0),
pack(1, 0),
pack(1, 0),
pack(0, 1),
pack(1, 1),
pack(1, 2),
pack(1, 2),
pack(0, 3),
pack(1, 1),
pack(0, 0)},
TIMESTAMP_WITH_TIME_ZONE())}),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
makeFlatVector<int32_t>({1, 1, 1, 1, 1, 2, 2, 2, 2, 2})});

expected = makeRowVector({makeMapVector(
{0, 2},
makeRowVector({makeFlatVector<int64_t>(
{pack(0, 0), pack(1, 0), pack(0, 3), pack(1, 2)},
TIMESTAMP_WITH_TIME_ZONE())}),
makeArrayVector<int32_t>({{1, 4}, {2, 3, 5}, {8, 10}, {6, 7, 9}}))});

testAggregations(
{data},
{"c2"},
{"multimap_agg(c0, c1)"},
// Sort the result arrays to ensure deterministic results.
{"transform_values(a0, (k, v) -> array_sort(v))"},
{expected});
}

} // namespace
} // namespace facebook::velox::aggregate::prestosql

0 comments on commit 7ad64b3

Please sign in to comment.