From b6604148cca24784c2df8382d42834d96a98cdc5 Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Tue, 24 Sep 2024 17:15:54 -0700 Subject: [PATCH] Support custom comparison in ContainerRowSerde (#11023) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11023 Building on https://github.com/facebookincubator/velox/pull/11021 this adds support for custom comparison functions provided by custom types in ContainerRowSerde. Compare was already handled by updating SimpleVector, so this just updates the hash function and adds tests. Reviewed By: kgpai Differential Revision: D62902814 --- velox/exec/ContainerRowSerde.cpp | 370 ++++++++++++++++----- velox/exec/tests/CMakeLists.txt | 1 + velox/exec/tests/ContainerRowSerdeTest.cpp | 150 +++++++-- 3 files changed, 409 insertions(+), 112 deletions(-) diff --git a/velox/exec/ContainerRowSerde.cpp b/velox/exec/ContainerRowSerde.cpp index 70d698089ec5..415aad140093 100644 --- a/velox/exec/ContainerRowSerde.cpp +++ b/velox/exec/ContainerRowSerde.cpp @@ -350,13 +350,21 @@ void deserializeSwitch( } // Comparison of serialization and vector. +template std::optional compareSwitch( ByteInputStream& stream, const BaseVector& vector, vector_size_t index, CompareFlags flags); -template +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t< + Kind != TypeKind::VARCHAR && Kind != TypeKind::VARBINARY && + Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && + Kind != TypeKind::ROW, + int32_t> = 0> std::optional compare( ByteInputStream& left, const BaseVector& right, @@ -365,10 +373,16 @@ std::optional compare( using T = typename TypeTraits::NativeType; auto rightValue = right.asUnchecked>()->valueAt(index); auto leftValue = left.read(); - auto result = right.typeUsesCustomComparison() - ? SimpleVector::template comparePrimitiveAscWithCustomComparison( - right.type().get(), leftValue, rightValue) - : SimpleVector::comparePrimitiveAsc(leftValue, rightValue); + + int result; + if constexpr (typeProvidesCustomComparison) { + result = + SimpleVector::template comparePrimitiveAscWithCustomComparison( + right.type().get(), leftValue, rightValue); + } else { + result = SimpleVector::comparePrimitiveAsc(leftValue, rightValue); + } + return flags.ascending ? result : result * -1; } @@ -398,8 +412,11 @@ int compareStringAsc( return leftSize - rightView.size(); } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, const BaseVector& right, vector_size_t index, @@ -408,8 +425,11 @@ std::optional compare( return flags.ascending ? result : result * -1; } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, const BaseVector& right, vector_size_t index, @@ -418,8 +438,11 @@ std::optional compare( return flags.ascending ? result : result * -1; } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, const BaseVector& right, vector_size_t index, @@ -444,7 +467,13 @@ std::optional compare( return result; } - auto result = compareSwitch(left, *child, wrappedIndex, flags); + std::optional result; + if (child->typeUsesCustomComparison()) { + result = compareSwitch(left, *child, wrappedIndex, flags); + } else { + result = compareSwitch(left, *child, wrappedIndex, flags); + } + if (result.has_value() && result.value() == 0) { continue; } @@ -453,6 +482,7 @@ std::optional compare( return 0; } +template std::optional compareArrays( ByteInputStream& left, const BaseVector& elements, @@ -479,7 +509,8 @@ std::optional compareArrays( } auto elementIndex = elements.wrappedIndex(offset + i); - auto result = compareSwitch(left, *wrappedElements, elementIndex, flags); + auto result = compareSwitch( + left, *wrappedElements, elementIndex, flags); if (result.has_value() && result.value() == 0) { continue; } @@ -488,6 +519,7 @@ std::optional compareArrays( return flags.ascending ? (leftSize - rightSize) : (rightSize - leftSize); } +template std::optional compareArrayIndices( ByteInputStream& left, const BaseVector& elements, @@ -514,7 +546,8 @@ std::optional compareArrayIndices( } auto elementIndex = elements.wrappedIndex(rightIndices[i]); - auto result = compareSwitch(left, *wrappedElements, elementIndex, flags); + auto result = compareSwitch( + left, *wrappedElements, elementIndex, flags); if (result.has_value() && result.value() == 0) { continue; } @@ -523,8 +556,11 @@ std::optional compareArrayIndices( return flags.ascending ? (leftSize - rightSize) : (rightSize - leftSize); } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, const BaseVector& right, vector_size_t index, @@ -532,16 +568,28 @@ std::optional compare( auto array = right.wrappedVector()->asUnchecked(); VELOX_CHECK_EQ(array->encoding(), VectorEncoding::Simple::ARRAY); auto wrappedIndex = right.wrappedIndex(index); - return compareArrays( - left, - *array->elements(), - array->offsetAt(wrappedIndex), - array->sizeAt(wrappedIndex), - flags); + if (array->type()->childAt(0)->providesCustomComparison()) { + return compareArrays( + left, + *array->elements(), + array->offsetAt(wrappedIndex), + array->sizeAt(wrappedIndex), + flags); + } else { + return compareArrays( + left, + *array->elements(), + array->offsetAt(wrappedIndex), + array->sizeAt(wrappedIndex), + flags); + } } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, const BaseVector& right, vector_size_t index, @@ -552,20 +600,42 @@ std::optional compare( auto size = map->sizeAt(wrappedIndex); std::vector indices(size); auto rightIndices = map->sortedKeyIndices(wrappedIndex); - auto result = compareArrayIndices(left, *map->mapKeys(), rightIndices, flags); + std::optional result; + + if (map->type()->childAt(0)->providesCustomComparison()) { + result = + compareArrayIndices(left, *map->mapKeys(), rightIndices, flags); + } else { + result = + compareArrayIndices(left, *map->mapKeys(), rightIndices, flags); + } + if (result.has_value() && result.value() == 0) { - return compareArrayIndices(left, *map->mapValues(), rightIndices, flags); + if (map->type()->childAt(1)->providesCustomComparison()) { + return compareArrayIndices( + left, *map->mapValues(), rightIndices, flags); + } else { + return compareArrayIndices( + left, *map->mapValues(), rightIndices, flags); + } } return result; } +template std::optional compareSwitch( ByteInputStream& stream, const BaseVector& vector, vector_size_t index, CompareFlags flags) { - return VELOX_DYNAMIC_TYPE_DISPATCH( - compare, vector.typeKind(), stream, vector, index, flags); + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + compare, + typeProvidesCustomComparison, + vector.typeKind(), + stream, + vector, + index, + flags); } // Returns a view over a serialized string with the string as a @@ -585,13 +655,21 @@ StringView readStringView(ByteInputStream& stream, std::string& storage) { } // Comparison of two serializations. +template std::optional compareSwitch( ByteInputStream& left, ByteInputStream& right, const Type* type, CompareFlags flags); -template +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t< + Kind != TypeKind::VARCHAR && Kind != TypeKind::VARBINARY && + Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && + Kind != TypeKind::ROW, + int32_t> = 0> std::optional compare( ByteInputStream& left, ByteInputStream& right, @@ -600,15 +678,24 @@ std::optional compare( using T = typename TypeTraits::NativeType; T leftValue = left.read(); T rightValue = right.read(); - auto result = type->providesCustomComparison() - ? SimpleVector::template comparePrimitiveAscWithCustomComparison( - type, leftValue, rightValue) - : SimpleVector::comparePrimitiveAsc(leftValue, rightValue); + + int result; + if constexpr (typeProvidesCustomComparison) { + result = + SimpleVector::template comparePrimitiveAscWithCustomComparison( + type, leftValue, rightValue); + } else { + result = SimpleVector::comparePrimitiveAsc(leftValue, rightValue); + } + return flags.ascending ? result : result * -1; } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, ByteInputStream& right, const Type* /*type*/, @@ -621,8 +708,11 @@ std::optional compare( : rightValue.compare(leftValue); } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, ByteInputStream& right, const Type* /*type*/, @@ -635,6 +725,7 @@ std::optional compare( : rightValue.compare(leftValue); } +template std::optional compareArrays( ByteInputStream& left, ByteInputStream& right, @@ -659,7 +750,8 @@ std::optional compareArrays( return result; } - auto result = compareSwitch(left, right, elementType, flags); + auto result = compareSwitch( + left, right, elementType, flags); if (result.has_value() && result.value() == 0) { continue; } @@ -668,8 +760,11 @@ std::optional compareArrays( return flags.ascending ? (leftSize - rightSize) : (rightSize - leftSize); } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, ByteInputStream& right, const Type* type, @@ -689,7 +784,14 @@ std::optional compare( return result; } - auto result = compareSwitch(left, right, rowType.childAt(i).get(), flags); + std::optional result; + const auto& childType = rowType.childAt(i); + if (childType->providesCustomComparison()) { + result = compareSwitch(left, right, childType.get(), flags); + } else { + result = compareSwitch(left, right, childType.get(), flags); + } + if (result.has_value() && result.value() == 0) { continue; } @@ -698,66 +800,115 @@ std::optional compare( return 0; } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, ByteInputStream& right, const Type* type, CompareFlags flags) { - return compareArrays(left, right, type->childAt(0).get(), flags); + const auto& elementType = type->childAt(0); + + if (elementType->providesCustomComparison()) { + return compareArrays(left, right, elementType.get(), flags); + } else { + return compareArrays(left, right, elementType.get(), flags); + } } -template <> -std::optional compare( +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +std::optional compare( ByteInputStream& left, ByteInputStream& right, const Type* type, CompareFlags flags) { - auto result = compareArrays(left, right, type->childAt(0).get(), flags); + std::optional result; + const auto& keyType = type->childAt(0); + const auto& valueType = type->childAt(1); + + if (keyType->providesCustomComparison()) { + result = compareArrays(left, right, keyType.get(), flags); + } else { + result = compareArrays(left, right, keyType.get(), flags); + } + if (result.has_value() && result.value() == 0) { - return compareArrays(left, right, type->childAt(1).get(), flags); + if (valueType->providesCustomComparison()) { + return compareArrays(left, right, valueType.get(), flags); + } else { + return compareArrays(left, right, valueType.get(), flags); + } } return result; } +template std::optional compareSwitch( ByteInputStream& left, ByteInputStream& right, const Type* type, CompareFlags flags) { - return VELOX_DYNAMIC_TYPE_DISPATCH( - compare, type->kind(), left, right, type, flags); + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + compare, + typeProvidesCustomComparison, + type->kind(), + left, + right, + type, + flags); } // Hash functions. +template uint64_t hashSwitch(ByteInputStream& stream, const Type* type); -template -uint64_t hashOne(ByteInputStream& stream, const Type* /*type*/) { +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t< + Kind != TypeKind::VARBINARY && Kind != TypeKind::VARCHAR && + Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && + Kind != TypeKind::ROW, + int32_t> = 0> +uint64_t hashOne(ByteInputStream& stream, const Type* type) { using T = typename TypeTraits::NativeType; - if constexpr (std::is_floating_point_v) { - return util::floating_point::NaNAwareHash()(stream.read()); + + T value = stream.read(); + + if constexpr (typeProvidesCustomComparison) { + return static_cast*>(type)->hash( + value); + } else if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareHash()(value); } else { - return folly::hasher()(stream.read()); + return folly::hasher()(value); } } -template <> -uint64_t hashOne( - ByteInputStream& stream, - const Type* /*type*/) { +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +uint64_t hashOne(ByteInputStream& stream, const Type* /*type*/) { std::string storage; return folly::hasher()(readStringView(stream, storage)); } -template <> -uint64_t hashOne( - ByteInputStream& stream, - const Type* /*type*/) { +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +uint64_t hashOne(ByteInputStream& stream, const Type* /*type*/) { std::string storage; return folly::hasher()(readStringView(stream, storage)); } +template uint64_t hashArray(ByteInputStream& in, uint64_t hash, const Type* elementType) { auto size = in.read(); @@ -767,15 +918,18 @@ hashArray(ByteInputStream& in, uint64_t hash, const Type* elementType) { if (bits::isBitSet(nulls.data(), i)) { value = BaseVector::kNullHash; } else { - value = hashSwitch(in, elementType); + value = hashSwitch(in, elementType); } hash = bits::commutativeHashMix(hash, value); } return hash; } -template <> -uint64_t hashOne(ByteInputStream& in, const Type* type) { +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +uint64_t hashOne(ByteInputStream& in, const Type* type) { auto size = type->size(); auto nulls = readNulls(in, size); uint64_t hash = BaseVector::kNullHash; @@ -784,28 +938,58 @@ uint64_t hashOne(ByteInputStream& in, const Type* type) { if (bits::isBitSet(nulls.data(), i)) { value = BaseVector::kNullHash; } else { - value = hashSwitch(in, type->childAt(i).get()); + const auto& childType = type->childAt(i); + if (childType->providesCustomComparison()) { + value = hashSwitch(in, childType.get()); + } else { + value = hashSwitch(in, childType.get()); + } } hash = i == 0 ? value : bits::hashMix(hash, value); } return hash; } -template <> -uint64_t hashOne(ByteInputStream& in, const Type* type) { - return hashArray(in, BaseVector::kNullHash, type->childAt(0).get()); +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +uint64_t hashOne(ByteInputStream& in, const Type* type) { + const auto& elementType = type->childAt(0); + + if (elementType->providesCustomComparison()) { + return hashArray(in, BaseVector::kNullHash, elementType.get()); + } else { + return hashArray(in, BaseVector::kNullHash, elementType.get()); + } } -template <> -uint64_t hashOne(ByteInputStream& in, const Type* type) { - return hashArray( - in, - hashArray(in, BaseVector::kNullHash, type->childAt(0).get()), - type->childAt(1).get()); +template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> +uint64_t hashOne(ByteInputStream& in, const Type* type) { + const auto& keyType = type->childAt(0); + const auto& valueType = type->childAt(1); + + uint64_t hash; + if (keyType->providesCustomComparison()) { + hash = hashArray(in, BaseVector::kNullHash, keyType.get()); + } else { + hash = hashArray(in, BaseVector::kNullHash, keyType.get()); + } + + if (valueType->providesCustomComparison()) { + return hashArray(in, hash, valueType.get()); + } else { + return hashArray(in, hash, valueType.get()); + } } +template uint64_t hashSwitch(ByteInputStream& in, const Type* type) { - return VELOX_DYNAMIC_TYPE_DISPATCH(hashOne, type->kind(), in, type); + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + hashOne, typeProvidesCustomComparison, type->kind(), in, type); } } // namespace @@ -838,7 +1022,13 @@ int32_t ContainerRowSerde::compare( VELOX_DCHECK( !right.isNullAt(index), "Null top-level values are not supported"); VELOX_DCHECK(flags.nullAsValue(), "not supported null handling mode"); - return compareSwitch(left, *right.base(), right.index(index), flags).value(); + if (right.base()->typeUsesCustomComparison()) { + return compareSwitch(left, *right.base(), right.index(index), flags) + .value(); + } else { + return compareSwitch(left, *right.base(), right.index(index), flags) + .value(); + } } // static @@ -849,7 +1039,11 @@ int32_t ContainerRowSerde::compare( CompareFlags flags) { VELOX_DCHECK(flags.nullAsValue(), "not supported null handling mode"); - return compareSwitch(left, right, type, flags).value(); + if (type->providesCustomComparison()) { + return compareSwitch(left, right, type, flags).value(); + } else { + return compareSwitch(left, right, type, flags).value(); + } } std::optional ContainerRowSerde::compareWithNulls( @@ -859,7 +1053,11 @@ std::optional ContainerRowSerde::compareWithNulls( CompareFlags flags) { VELOX_DCHECK( !right.isNullAt(index), "Null top-level values are not supported"); - return compareSwitch(left, *right.base(), right.index(index), flags); + if (right.base()->typeUsesCustomComparison()) { + return compareSwitch(left, *right.base(), right.index(index), flags); + } else { + return compareSwitch(left, *right.base(), right.index(index), flags); + } } std::optional ContainerRowSerde::compareWithNulls( @@ -867,12 +1065,20 @@ std::optional ContainerRowSerde::compareWithNulls( ByteInputStream& right, const Type* type, CompareFlags flags) { - return compareSwitch(left, right, type, flags); + if (type->providesCustomComparison()) { + return compareSwitch(left, right, type, flags); + } else { + return compareSwitch(left, right, type, flags); + } } // static uint64_t ContainerRowSerde::hash(ByteInputStream& in, const Type* type) { - return hashSwitch(in, type); + if (type->providesCustomComparison()) { + return hashSwitch(in, type); + } else { + return hashSwitch(in, type); + } } } // namespace facebook::velox::exec diff --git a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index 32a46e398432..8e0e404f67e6 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -136,6 +136,7 @@ target_link_libraries( velox_serialization velox_test_util velox_type + velox_type_test_lib velox_vector velox_vector_fuzzer velox_writer_fuzzer diff --git a/velox/exec/tests/ContainerRowSerdeTest.cpp b/velox/exec/tests/ContainerRowSerdeTest.cpp index 8663d75287b7..91e82cd790f5 100644 --- a/velox/exec/tests/ContainerRowSerdeTest.cpp +++ b/velox/exec/tests/ContainerRowSerdeTest.cpp @@ -20,6 +20,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/HashStringAllocator.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -165,20 +166,54 @@ class ContainerRowSerdeTest : public testing::Test, } } - void testCompare(const VectorPtr& vector) { - auto positions = serializeWithPositions(vector); + void testCompare(const VectorPtr& actual) { + testCompare(actual, actual); + } + + void testCompare(const VectorPtr& actual, const VectorPtr& expected) { + auto positionsActual = serializeWithPositions(actual); + auto positionsExpected = serializeWithPositions(expected); CompareFlags compareFlags = CompareFlags::equality(CompareFlags::NullHandlingMode::kNullAsValue); - DecodedVector decodedVector(*vector); + DecodedVector decodedVector(*expected); - for (auto i = 0; i < positions.size(); ++i) { - auto stream = HashStringAllocator::prepareRead(positions.at(i).header); + for (auto i = 0; i < positionsActual.size(); ++i) { + // Test comparing reading from a ByteInputStream and a DecodedVector. + auto actualStream = + HashStringAllocator::prepareRead(positionsActual.at(i).header); ASSERT_EQ( 0, - ContainerRowSerde::compare(*stream, decodedVector, i, compareFlags)) - << "at " << i << ": " << vector->toString(i); + ContainerRowSerde::compare( + *actualStream, decodedVector, i, compareFlags)) + << "at " << i << ": " << actual->toString(i) << " " + << expected->toString(i); + + // Test comparing reading from two ByteInputStreams. + actualStream = + HashStringAllocator::prepareRead(positionsActual.at(i).header); + auto expectedStream = + HashStringAllocator::prepareRead(positionsExpected.at(i).header); + ASSERT_EQ( + 0, + ContainerRowSerde::compare( + *actualStream, + *expectedStream, + actual->type().get(), + compareFlags)) + << "at " << i << ": " << actual->toString(i) << " " + << expected->toString(i); + + // Test comparing hashes. + actualStream = + HashStringAllocator::prepareRead(positionsActual.at(i).header); + + ASSERT_EQ( + expected->hashValueAt(i), + ContainerRowSerde::hash(*actualStream, actual->type().get())) + << "at " << i << ": " << actual->toString(i) << " " + << expected->toString(i); } } @@ -575,34 +610,89 @@ TEST_F(ContainerRowSerdeTest, fuzzCompare) { TEST_F(ContainerRowSerdeTest, nans) { // Verify that the NaNs with different representations are considered equal // and have the same hash value. - auto vector = makeNullableFlatVector( - {std::nan("1"), - std::nan("2"), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::signaling_NaN()}); - // Compare with the same NaN value - auto expected = makeConstant(std::nan("1"), 4, vector->type()); - - auto positions = serializeWithPositions(vector); + testCompare( + makeFlatVector( + {std::nan("1"), + std::nan("2"), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::signaling_NaN()}), + // Compare with the same NaN value. + makeFlatVector( + {std::nan("1"), std::nan("1"), std::nan("1"), std::nan("1")})); +} - CompareFlags compareFlags = - CompareFlags::equality(CompareFlags::NullHandlingMode::kNullAsValue); +TEST_F(ContainerRowSerdeTest, customComparison) { + testCompare( + makeFlatVector( + {0, 1, 256, 257, 512, 513}, + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()), + // Compare with the same values based on the custom comparison. + makeFlatVector( + {0, 1, 0, 1, 0, 1}, test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())); +} - DecodedVector decodedVector(*expected); +TEST_F(ContainerRowSerdeTest, arrayOfCustomComparison) { + testCompare( + makeNullableArrayVector( + {{0, 1, 2}, + {256, 257, 258}, + {512, 513, 514}, + {3, 4, 5}, + {259, 260, 261}, + {515, 516, 517}, + {std::nullopt}}, + ARRAY(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())), + // Compare with the same values based on the custom comparison. + makeNullableArrayVector( + {{0, 1, 2}, + {0, 1, 2}, + {0, 1, 2}, + {3, 4, 5}, + {3, 4, 5}, + {3, 4, 5}, + {std::nullopt}}, + ARRAY(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()))); +} - for (auto i = 0; i < positions.size(); ++i) { - auto stream = HashStringAllocator::prepareRead(positions.at(i).header); - ASSERT_EQ( - 0, ContainerRowSerde::compare(*stream, decodedVector, i, compareFlags)) - << "at " << i << ": " << vector->toString(i); +TEST_F(ContainerRowSerdeTest, mapOfCustomComparison) { + testCompare( + makeNullableMapVector( + {{{{0, 10}, {1, 11}, {2, 12}}}, + {{{256, 266}, {257, 267}, {258, 268}}}, + {{{512, 522}, {513, 523}, {514, 524}}}, + {{{3, 103}, {4, 104}, {5, 105}}}, + {{{259, 359}, {260, 360}, {261, 361}}}, + {{{515, 615}, {516, 616}, {517, 617}}}, + {{{0, std::nullopt}}}}, + MAP(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())), + // Compare with the same values based on the custom comparison. + makeNullableMapVector( + {{{{0, 10}, {1, 11}, {2, 12}}}, + {{{0, 10}, {1, 11}, {2, 12}}}, + {{{0, 10}, {1, 11}, {2, 12}}}, + {{{3, 103}, {4, 104}, {5, 105}}}, + {{{3, 103}, {4, 104}, {5, 105}}}, + {{{3, 103}, {4, 104}, {5, 105}}}, + {{{0, std::nullopt}}}}, + MAP(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()))); +} - stream = HashStringAllocator::prepareRead(positions.at(i).header); - ASSERT_EQ( - expected->hashValueAt(i), - ContainerRowSerde::hash(*stream, vector->type().get())) - << "at " << i << ": " << vector->toString(i); - } +TEST_F(ContainerRowSerdeTest, rowOfCustomComparison) { + testCompare( + makeRowVector( + {"a"}, + {makeNullableFlatVector( + {std::nullopt, 0, 1, 256, 257, 512, 513}, + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())}), + // Compare with the same values based on the custom comparison. + makeRowVector( + {"a"}, + {makeNullableFlatVector( + {std::nullopt, 0, 1, 0, 1, 0, 1}, + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})); } } // namespace