Skip to content

Commit

Permalink
HSEARCH-5133 Refactor LuceneMetric*Aggregation classes
Browse files Browse the repository at this point in the history
Extracting one subclass for each operation name
  • Loading branch information
fax4ever committed Aug 26, 2024
1 parent 532dc54 commit 08371a9
Show file tree
Hide file tree
Showing 15 changed files with 309 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import java.util.Set;

import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CompensatedSum;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CompensatedSumCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CountCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.DoubleAggregationFunctionCollector;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.MaxCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.MinCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey;
import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext;
Expand All @@ -24,68 +20,46 @@
import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter;
import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder;
import org.hibernate.search.engine.search.common.ValueModel;
import org.hibernate.search.util.common.AssertionFailure;

/**
* @param <F> The type of field values.
* @param <K> The type of returned value. It can be {@code F}, {@link Double}
* or a different type if value converters are used.
*/
public class LuceneMetricCompensatedSumAggregation<F, E extends Number, K> extends AbstractLuceneNestableAggregation<K> {
public abstract class AbstractLuceneMetricCompensatedSumAggregation<F, E extends Number, K>
extends AbstractLuceneNestableAggregation<K> {

private final Set<String> indexNames;
private final String absoluteFieldPath;
private final AbstractLuceneNumericFieldCodec<F, E> codec;
private final LuceneNumericDomain<E> numericDomain;
private final ProjectionConverter<F, ? extends K> fromFieldValueConverter;
private final String operation;

private CollectorKey<?, Long> collectorKey;
private CollectorKey<DoubleAggregationFunctionCollector<CompensatedSum>, Double> compensatedSumCollectorKey;
protected CollectorKey<?, Long> collectorKey;
protected CollectorKey<DoubleAggregationFunctionCollector<CompensatedSum>, Double> compensatedSumCollectorKey;

LuceneMetricCompensatedSumAggregation(Builder<F, E, K> builder) {
AbstractLuceneMetricCompensatedSumAggregation(Builder<F, E, K> builder) {
super( builder );
this.indexNames = builder.scope.hibernateSearchIndexNames();
this.absoluteFieldPath = builder.field.absolutePath();
this.codec = builder.codec;
this.numericDomain = codec.getDomain();
this.fromFieldValueConverter = builder.fromFieldValueConverter;
this.operation = builder.operation;
}

@Override
public Extractor<K> request(AggregationRequestContext context) {
JoiningLongMultiValuesSource source = JoiningLongMultiValuesSource.fromField(
absoluteFieldPath, createNestedDocsProvider( context )
);
if ( "sum".equals( operation ) ) {
CompensatedSumCollectorFactory collectorFactory = new CompensatedSumCollectorFactory( source,
numericDomain::sortedDocValueToDouble );
compensatedSumCollectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
else if ( "min".equals( operation ) ) {
MinCollectorFactory collectorFactory = new MinCollectorFactory( source );
collectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
else if ( "max".equals( operation ) ) {
MaxCollectorFactory collectorFactory = new MaxCollectorFactory( source );
collectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
else if ( "avg".equals( operation ) ) {
CompensatedSumCollectorFactory sumCollectorFactory = new CompensatedSumCollectorFactory( source,
numericDomain::sortedDocValueToDouble );
compensatedSumCollectorKey = sumCollectorFactory.getCollectorKey();
context.requireCollector( sumCollectorFactory );

CountCollectorFactory countCollectorFactory = new CountCollectorFactory( source );
collectorKey = countCollectorFactory.getCollectorKey();
context.requireCollector( countCollectorFactory );
}
fillCollectors( source, context, numericDomain );
return new LuceneNumericMetricFieldAggregationExtraction();
}

abstract void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context,
LuceneNumericDomain<E> numericDomain);

@Override
public Set<String> indexNames() {
return indexNames;
Expand All @@ -95,28 +69,14 @@ private class LuceneNumericMetricFieldAggregationExtraction implements Extractor

@Override
public K extract(AggregationExtractContext context) {
E extracted;

if ( "sum".equals( operation ) ) {
Double sum = context.getFacets( compensatedSumCollectorKey );
extracted = numericDomain.doubleToTerm( sum );
}
else if ( "avg".equals( operation ) ) {
Double sum = context.getFacets( compensatedSumCollectorKey );
Long counts = context.getFacets( collectorKey );
double avg = ( sum / counts );
extracted = numericDomain.doubleToTerm( avg );
}
else {
Long result = context.getFacets( collectorKey );
extracted = numericDomain.sortedDocValueToTerm( result );
}

E extracted = extractEncoded( context, numericDomain );
F decode = codec.decode( extracted );
return fromFieldValueConverter.fromDocumentValue( decode, context.fromDocumentValueConvertContext() );
}
}

abstract E extractEncoded(AggregationExtractContext context, LuceneNumericDomain<E> numericDomain);

public static class Factory<F>
extends AbstractLuceneCodecAwareSearchQueryElementFactory<FieldMetricAggregationBuilder.TypeSelector,
F,
Expand Down Expand Up @@ -168,7 +128,7 @@ private TypeSelector(AbstractLuceneNumericFieldCodec<F, ?> codec,
}
}

private static class Builder<F, E extends Number, K> extends AbstractBuilder<K>
protected static class Builder<F, E extends Number, K> extends AbstractBuilder<K>
implements FieldMetricAggregationBuilder<K> {

private final AbstractLuceneNumericFieldCodec<F, E> codec;
Expand All @@ -186,8 +146,16 @@ public Builder(AbstractLuceneNumericFieldCodec<F, E> codec, LuceneSearchIndexSco
}

@Override
public LuceneMetricCompensatedSumAggregation<F, E, K> build() {
return new LuceneMetricCompensatedSumAggregation<>( this );
public AbstractLuceneMetricCompensatedSumAggregation<F, E, K> build() {
if ( "sum".equals( operation ) ) {
return new LuceneSumCompensatedSumAggregation<>( this );
}
else if ( "avg".equals( operation ) ) {
return new LuceneAvgCompensatedSumAggregation<>( this );
}
else {
throw new AssertionFailure( "Aggregation operation not supported: " + operation );
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.AggregationFunctionCollector;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.Count;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CountCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.MaxCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.MinCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.SumCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey;
import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext;
Expand All @@ -25,67 +21,47 @@
import org.hibernate.search.engine.cfg.spi.NumberUtils;
import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder;
import org.hibernate.search.engine.search.common.ValueModel;
import org.hibernate.search.util.common.AssertionFailure;

/**
* @param <F> The type of field values.
* @param <K> The type of returned value. It can be {@code F}, {@link Double}
* or a different type if value converters are used.
*/
public class LuceneMetricNumericFieldAggregation<F, E extends Number, K> extends AbstractLuceneNestableAggregation<K> {
public abstract class AbstractLuceneMetricNumericFieldAggregation<F, E extends Number, K>
extends AbstractLuceneNestableAggregation<K> {

private final Set<String> indexNames;
private final String absoluteFieldPath;
private final AbstractLuceneNumericFieldCodec<F, E> codec;
private final LuceneNumericDomain<E> numericDomain;
private final ProjectionConverter<F, ? extends K> fromFieldValueConverter;
private final String operation;

private CollectorKey<?, Long> collectorKey;
protected CollectorKey<?, Long> collectorKey;

// Supplementary collector used by the avg function
private CollectorKey<AggregationFunctionCollector<Count>, Long> countCollectorKey;
protected CollectorKey<AggregationFunctionCollector<Count>, Long> countCollectorKey;

LuceneMetricNumericFieldAggregation(Builder<F, E, K> builder) {
AbstractLuceneMetricNumericFieldAggregation(Builder<F, E, K> builder) {
super( builder );
this.indexNames = builder.scope.hibernateSearchIndexNames();
this.absoluteFieldPath = builder.field.absolutePath();
this.codec = builder.codec;
this.numericDomain = codec.getDomain();
this.fromFieldValueConverter = builder.fromFieldValueConverter;
this.operation = builder.operation;
}

@Override
public Extractor<K> request(AggregationRequestContext context) {
JoiningLongMultiValuesSource source = JoiningLongMultiValuesSource.fromField(
absoluteFieldPath, createNestedDocsProvider( context )
);
if ( "sum".equals( operation ) ) {
SumCollectorFactory collectorFactory = new SumCollectorFactory( source );
collectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
else if ( "min".equals( operation ) ) {
MinCollectorFactory collectorFactory = new MinCollectorFactory( source );
collectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
else if ( "max".equals( operation ) ) {
MaxCollectorFactory collectorFactory = new MaxCollectorFactory( source );
collectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
else if ( "avg".equals( operation ) ) {
SumCollectorFactory sumCollectorFactory = new SumCollectorFactory( source );
CountCollectorFactory countCollectorFactory = new CountCollectorFactory( source );
collectorKey = sumCollectorFactory.getCollectorKey();
countCollectorKey = countCollectorFactory.getCollectorKey();
context.requireCollector( sumCollectorFactory );
context.requireCollector( countCollectorFactory );
}
fillCollectors( source, context );
return new LuceneNumericMetricFieldAggregationExtraction();
}

abstract void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context);

@Override
public Set<String> indexNames() {
return indexNames;
Expand Down Expand Up @@ -169,7 +145,7 @@ private TypeSelector(AbstractLuceneNumericFieldCodec<F, ?> codec,
}
}

private static class Builder<F, E extends Number, K> extends AbstractBuilder<K>
protected static class Builder<F, E extends Number, K> extends AbstractBuilder<K>
implements FieldMetricAggregationBuilder<K> {

private final AbstractLuceneNumericFieldCodec<F, E> codec;
Expand All @@ -187,8 +163,22 @@ public Builder(AbstractLuceneNumericFieldCodec<F, E> codec, LuceneSearchIndexSco
}

@Override
public LuceneMetricNumericFieldAggregation<F, E, K> build() {
return new LuceneMetricNumericFieldAggregation<>( this );
public AbstractLuceneMetricNumericFieldAggregation<F, E, K> build() {
if ( "sum".equals( operation ) ) {
return new LuceneSumNumericFieldAggregation<>( this );
}
else if ( "min".equals( operation ) ) {
return new LuceneMinNumericFieldAggregation<>( this );
}
else if ( "max".equals( operation ) ) {
return new LuceneMaxNumericFieldAggregation<>( this );
}
else if ( "avg".equals( operation ) ) {
return new LuceneAvgNumericFieldAggregation<>( this );
}
else {
throw new AssertionFailure( "Aggregation operation not supported: " + operation );
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

import java.util.Set;

import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CountCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CountDistinctCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey;
import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext;
Expand All @@ -17,16 +15,17 @@
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext;
import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec;
import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder;
import org.hibernate.search.util.common.AssertionFailure;

public class LuceneMetricNumericLongAggregation extends AbstractLuceneNestableAggregation<Long> {
public abstract class AbstractLuceneMetricNumericLongAggregation extends AbstractLuceneNestableAggregation<Long> {

private final Set<String> indexNames;
private final String absoluteFieldPath;
private final String operation;

private CollectorKey<?, Long> collectorKey;
protected CollectorKey<?, Long> collectorKey;

LuceneMetricNumericLongAggregation(Builder builder) {
AbstractLuceneMetricNumericLongAggregation(Builder builder) {
super( builder );
this.indexNames = builder.scope.hibernateSearchIndexNames();
this.absoluteFieldPath = builder.field.absolutePath();
Expand All @@ -38,19 +37,13 @@ public Extractor<Long> request(AggregationRequestContext context) {
JoiningLongMultiValuesSource source = JoiningLongMultiValuesSource.fromField(
absoluteFieldPath, createNestedDocsProvider( context )
);
if ( "cardinality".equals( operation ) ) {
CountDistinctCollectorFactory collectorFactory = new CountDistinctCollectorFactory( source );
collectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
else if ( "value_count".equals( operation ) ) {
CountCollectorFactory collectorFactory = new CountCollectorFactory( source );
collectorKey = collectorFactory.getCollectorKey();
context.requireCollector( collectorFactory );
}
fillCollectors( source, context );

return new LuceneNumericMetricLongAggregationExtraction();
}

abstract void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context);

@Override
public Set<String> indexNames() {
return indexNames;
Expand Down Expand Up @@ -82,7 +75,7 @@ public FieldMetricAggregationBuilder<Long> create(LuceneSearchIndexScope<?> scop
}
}

private static class Builder extends AbstractBuilder<Long> implements FieldMetricAggregationBuilder<Long> {
protected static class Builder extends AbstractBuilder<Long> implements FieldMetricAggregationBuilder<Long> {
private final String operation;

public Builder(LuceneSearchIndexScope<?> scope,
Expand All @@ -93,8 +86,16 @@ public Builder(LuceneSearchIndexScope<?> scope,
}

@Override
public LuceneMetricNumericLongAggregation build() {
return new LuceneMetricNumericLongAggregation( this );
public AbstractLuceneMetricNumericLongAggregation build() {
if ( "value_count".equals( operation ) ) {
return new LuceneCountNumericLongAggregation( this );
}
else if ( "cardinality".equals( operation ) ) {
return new LuceneCountDistinctNumericLongAggregation( this );
}
else {
throw new AssertionFailure( "Aggregation operation not supported: " + operation );
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.search.backend.lucene.types.aggregation.impl;

import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CompensatedSumCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.CountCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationRequestContext;
import org.hibernate.search.backend.lucene.types.lowlevel.impl.LuceneNumericDomain;

public class LuceneAvgCompensatedSumAggregation<F, E extends Number, K>
extends AbstractLuceneMetricCompensatedSumAggregation<F, E, K> {

LuceneAvgCompensatedSumAggregation(Builder<F, E, K> builder) {
super( builder );
}

@Override
void fillCollectors(JoiningLongMultiValuesSource source, AggregationRequestContext context,
LuceneNumericDomain<E> numericDomain) {
CompensatedSumCollectorFactory sumCollectorFactory = new CompensatedSumCollectorFactory( source,
numericDomain::sortedDocValueToDouble );
compensatedSumCollectorKey = sumCollectorFactory.getCollectorKey();
context.requireCollector( sumCollectorFactory );

CountCollectorFactory countCollectorFactory = new CountCollectorFactory( source );
collectorKey = countCollectorFactory.getCollectorKey();
context.requireCollector( countCollectorFactory );
}

@Override
E extractEncoded(AggregationExtractContext context, LuceneNumericDomain<E> numericDomain) {
Double sum = context.getFacets( compensatedSumCollectorKey );
Long counts = context.getFacets( collectorKey );
double avg = ( sum / counts );
return numericDomain.doubleToTerm( avg );
}
}
Loading

0 comments on commit 08371a9

Please sign in to comment.