Skip to content

Commit

Permalink
Support custom comparison in Histogram Aggregation (attempt 2) (#11154)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #11154

This is a second attempt to land the changes in
#11120

The original description:
Building on #11021 this adds support for custom
comparison functions provided by custom types in the Histogram aggregationt.

New context:
I landed this along with #11119 so it
got reverted along with it.  This particular change did not introduce any issues
though.

Reviewed By: xiaoxmeng

Differential Revision: D63795479

fbshipit-source-id: 6c045c8648abe600e5656d3874c47aba21186168
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 3, 2024
1 parent 118759b commit 7ef70fc
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 6 deletions.
100 changes: 94 additions & 6 deletions velox/functions/prestosql/aggregates/HistogramAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,77 @@ struct Accumulator {
}
};

// A wrapper around Accumulator that overrides hash and equal_to functions to
// use the custom comparisons provided by a custom type.
template <TypeKind Kind>
struct CustomComparisonAccumulator {
using NativeType = typename TypeTraits<Kind>::NativeType;

struct Hash {
const TypePtr& type;

size_t operator()(const NativeType& value) const {
return static_cast<const CanProvideCustomComparisonType<Kind>*>(
type.get())
->hash(value);
}
};

struct EqualTo {
const TypePtr& type;

bool operator()(const NativeType& left, const NativeType& right) const {
return static_cast<const CanProvideCustomComparisonType<Kind>*>(
type.get())
->compare(left, right) == 0;
}
};

// The underlying Accumulator to which all operations are delegated.
Accumulator<
NativeType,
CustomComparisonAccumulator::Hash,
CustomComparisonAccumulator::EqualTo>
base;

CustomComparisonAccumulator(
const TypePtr& type,
HashStringAllocator* allocator)
: base{
CustomComparisonAccumulator::Hash{type},
CustomComparisonAccumulator::EqualTo{type},
allocator} {}

size_t size() const {
return base.size();
}

void addValue(
DecodedVector& decoded,
vector_size_t index,
HashStringAllocator* allocator) {
base.addValue(decoded, index, allocator);
}

void addValueWithCount(
NativeType value,
int64_t count,
HashStringAllocator* allocator) {
base.addValueWithCount(value, count, allocator);
}

void extractValues(
FlatVector<NativeType>& keys,
FlatVector<int64_t>& counts,
vector_size_t offset) {
base.extractValues(keys, counts, offset);
}

void free(HashStringAllocator& allocator) {
base.free(allocator);
}
};

struct StringViewAccumulator {
/// A map of unique StringViews pointing to storage managed by 'strings'.
Accumulator<StringView> base;
Expand Down Expand Up @@ -269,14 +340,15 @@ FOLLY_ALWAYS_INLINE void addToFinalAggregation(
}
}

template <typename T>
template <
typename T,
typename AccumulatorType =
typename AccumulatorTypeTraits<T>::AccumulatorType>
class HistogramAggregate : public exec::Aggregate {
public:
explicit HistogramAggregate(TypePtr resultType)
: Aggregate(std::move(resultType)) {}

using AccumulatorType = typename AccumulatorTypeTraits<T>::AccumulatorType;

int32_t accumulatorFixedWidthSize() const override {
return sizeof(AccumulatorType);
}
Expand Down Expand Up @@ -480,6 +552,14 @@ class HistogramAggregate : public exec::Aggregate {
DecodedVector decodedIntermediate_;
};

template <TypeKind Kind>
std::unique_ptr<exec::Aggregate> createHistogramAggregateWithCustomCompare(
const TypePtr& resultType) {
return std::make_unique<HistogramAggregate<
typename TypeTraits<Kind>::NativeType,
CustomComparisonAccumulator<Kind>>>(resultType);
}

} // namespace

void registerHistogramAggregate(
Expand Down Expand Up @@ -508,9 +588,17 @@ void registerHistogramAggregate(
VELOX_CHECK_EQ(
argTypes.size(), 1, "{}: unexpected number of arguments", name);

auto inputType = argTypes[0];
switch (exec::isRawInput(step) ? inputType->kind()
: inputType->childAt(0)->kind()) {
auto inputType =
exec::isRawInput(step) ? argTypes[0] : argTypes[0]->childAt(0);

if (inputType->providesCustomComparison()) {
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
createHistogramAggregateWithCustomCompare,
inputType->kind(),
resultType);
}

switch (inputType->kind()) {
case TypeKind::BOOLEAN:
return std::make_unique<HistogramAggregate<bool>>(resultType);
case TypeKind::TINYINT:
Expand Down
46 changes: 46 additions & 0 deletions velox/functions/prestosql/aggregates/tests/HistogramTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/exec/RowContainer.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

using namespace facebook::velox::exec;
using namespace facebook::velox::exec::test;
Expand Down Expand Up @@ -207,6 +208,32 @@ TEST_F(HistogramTest, groupByString) {
testGlobalHistogramWithDuck(data);
}

TEST_F(HistogramTest, groupByTimestampWithTimezones) {
auto vector = makeFlatVector<int64_t>(
{pack(0, 0),
pack(1, 0),
pack(2, 0),
pack(0, 1),
pack(1, 1),
pack(1, 2),
pack(2, 2),
pack(3, 3),
pack(1, 1),
pack(3, 0)},
TIMESTAMP_WITH_TIME_ZONE());

auto keys = makeFlatVector<int16_t>(10, [](auto row) { return row % 2; });

auto expected = makeRowVector(
{makeFlatVector<int16_t>({0, 1}),
makeMapVector<int64_t, int64_t>(
{{{pack(0, 0), 1}, {pack(1, 1), 2}, {pack(2, 0), 2}},
{{pack(0, 1), 1}, {pack(1, 0), 2}, {pack(3, 3), 2}}},
MAP(TIMESTAMP_WITH_TIME_ZONE(), BIGINT()))});

testHistogram("histogram(c1)", {"c0"}, keys, vector, expected);
}

TEST_F(HistogramTest, globalInteger) {
vector_size_t num = 29;
auto vector = makeFlatVector<int32_t>(
Expand Down Expand Up @@ -319,6 +346,25 @@ TEST_F(HistogramTest, globalNaNs) {
testHistogram("histogram(c1)", {}, vector, vector, expected);
}

TEST_F(HistogramTest, globalTimestampWithTimezones) {
auto vector = makeFlatVector<int64_t>(
{pack(0, 0),
pack(1, 0),
pack(2, 0),
pack(0, 1),
pack(1, 1),
pack(1, 2),
pack(2, 2),
pack(3, 3)},
TIMESTAMP_WITH_TIME_ZONE());

auto expected = makeRowVector({makeMapVector<int64_t, int64_t>(
{{{pack(0, 0), 2}, {pack(1, 0), 3}, {pack(2, 0), 2}, {pack(3, 3), 1}}},
MAP(TIMESTAMP_WITH_TIME_ZONE(), BIGINT()))});

testHistogram("histogram(c1)", {}, vector, vector, expected);
}

TEST_F(HistogramTest, arrays) {
auto input = makeRowVector({
makeFlatVector<int64_t>({0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}),
Expand Down

0 comments on commit 7ef70fc

Please sign in to comment.