Release aggregations earlier during reduce #124520

Merged
----------------------------------------
@@ -15,9 +15,12 @@
import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchShardTarget;
@@ -31,6 +34,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
@@ -174,14 +178,10 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
this.mergeResult = null;
final int resultSize = buffer.size() + (mergeResult == null ? 0 : 1);
final List<TopDocs> topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null;
final List<DelayableWriteable<InternalAggregations>> aggsList = hasAggs ? new ArrayList<>(resultSize) : null;
if (mergeResult != null) {
if (topDocsList != null) {
topDocsList.add(mergeResult.reducedTopDocs);
}
if (aggsList != null) {
aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs));
}
}
for (QuerySearchResult result : buffer) {
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
@@ -190,34 +190,39 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
if (aggsList != null) {
aggsList.add(result.getAggs());
}
}
SearchPhaseController.ReducedQueryPhase reducePhase;
long breakerSize = circuitBreakerBytes;
final InternalAggregations aggs;
try {
if (aggsList != null) {
if (hasAggs) {
// Add an estimate of the final reduce size
breakerSize = addEstimateAndMaybeBreak(estimateRamBytesUsedForReduce(breakerSize));
aggs = aggregate(
buffer.iterator(),
mergeResult,
resultSize,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
);
} else {
aggs = null;
}
reducePhase = SearchPhaseController.reducedQueryPhase(
results.asList(),
aggsList,
aggs,
topDocsList == null ? Collections.emptyList() : topDocsList,
topDocsStats,
numReducePhases,
false,
aggReduceContextBuilder,
queryPhaseRankCoordinatorContext,
performFinalReduce
queryPhaseRankCoordinatorContext
);
buffer = null;
} finally {
releaseAggs(buffer);
}
if (hasAggs
// reduced aggregations can be null if all shards failed
&& reducePhase.aggregations() != null) {
&& aggs != null) {

// Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result
long finalSize = DelayableWriteable.getSerializedSize(reducePhase.aggregations()) - breakerSize;
@@ -249,17 +254,7 @@ private MergeResult partialReduce(
toConsume.sort(RESULT_COMPARATOR);

final TopDocs newTopDocs;
final InternalAggregations newAggs;
final List<DelayableWriteable<InternalAggregations>> aggsList;
final int resultSetSize = toConsume.size() + (lastMerge != null ? 1 : 0);
if (hasAggs) {
aggsList = new ArrayList<>(resultSetSize);
if (lastMerge != null) {
aggsList.add(DelayableWriteable.referencing(lastMerge.reducedAggs));
}
} else {
aggsList = null;
}
List<TopDocs> topDocsList;
if (hasTopDocs) {
topDocsList = new ArrayList<>(resultSetSize);
@@ -269,14 +264,12 @@
} else {
topDocsList = null;
}
final InternalAggregations newAggs;
try {
for (QuerySearchResult result : toConsume) {
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
SearchShardTarget target = result.getSearchShardTarget();
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
if (aggsList != null) {
aggsList.add(result.getAggs());
}
if (topDocsList != null) {
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
setShardIndex(topDocs.topDocs, result.getShardIndex());
@@ -285,9 +278,10 @@
}
// we have to merge here in the same way we collect on a shard
newTopDocs = topDocsList == null ? null : mergeTopDocs(topDocsList, topNSize, 0);
newAggs = aggsList == null
? null
: InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction());
newAggs = hasAggs
? aggregate(toConsume.iterator(), lastMerge, resultSetSize, aggReduceContextBuilder.forPartialReduction())
: null;
toConsume = null;
} finally {
releaseAggs(toConsume);
}
@@ -302,6 +296,45 @@ private MergeResult partialReduce(
return new MergeResult(processedShards, newTopDocs, newAggs, newAggs != null ? DelayableWriteable.getSerializedSize(newAggs) : 0);
}

private static InternalAggregations aggregate(
Iterator<QuerySearchResult> toConsume,
MergeResult lastMerge,
int resultSetSize,
AggregationReduceContext reduceContext
) {
interface ReleasableIterator extends Iterator<InternalAggregations>, Releasable {}
try (var aggsIter = new ReleasableIterator() {

private Releasable toRelease;

@Override
public void close() {
Releasables.close(toRelease);
}

@Override
public boolean hasNext() {
return toConsume.hasNext();
}

@Override
public InternalAggregations next() {
var res = toConsume.next().consumeAggs();
Releasables.close(toRelease);
toRelease = res;
return res.expand();
}
}) {
return InternalAggregations.topLevelReduce(
lastMerge == null ? aggsIter : Iterators.concat(Iterators.single(lastMerge.reducedAggs), aggsIter),
resultSetSize,
reduceContext
);
} finally {
toConsume.forEachRemaining(QuerySearchResult::releaseAggs);
}
}
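
The anonymous `ReleasableIterator` above releases each shard's serialized aggregations as soon as the next element is consumed, so only one `DelayableWriteable` stays live during the reduce, and the `finally` block releases anything left unconsumed if the reduce throws. A standalone sketch of that release-on-advance pattern, using generic, illustrative names that are not part of this change:

```java
import java.util.Iterator;
import java.util.function.Function;

// Illustrative release-on-advance adapter: expands each element while
// holding at most one underlying resource open at a time.
final class ReleaseOnAdvance<R extends AutoCloseable, T> implements Iterator<T>, AutoCloseable {
    private final Iterator<R> source;
    private final Function<R, T> expand;
    private R previous; // resource backing the most recently returned element

    ReleaseOnAdvance(Iterator<R> source, Function<R, T> expand) {
        this.source = source;
        this.expand = expand;
    }

    @Override
    public boolean hasNext() {
        return source.hasNext();
    }

    @Override
    public T next() {
        R res = source.next(); // take ownership of the next resource
        close();               // release the one behind the previous element
        previous = res;
        return expand.apply(res);
    }

    @Override
    public void close() {
        if (previous != null) {
            try {
                previous.close();
            } catch (Exception e) {
                throw new IllegalStateException(e);
            } finally {
                previous = null;
            }
        }
    }
}
```
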

public int getNumReducePhases() {
return numReducePhases;
}
@@ -517,8 +550,10 @@ public void onFailure(Exception exc) {
}

private static void releaseAggs(List<QuerySearchResult> toConsume) {
for (QuerySearchResult result : toConsume) {
result.releaseAggs();
if (toConsume != null) {
for (QuerySearchResult result : toConsume) {
result.releaseAggs();
}
}
}

----------------------------------------
@@ -20,7 +20,6 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.TotalHits.Relation;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.Maps;
@@ -401,22 +400,20 @@ private static SearchHits getHits(
/**
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
* @param bufferedAggs a list of pre-collected aggregations.
* @param reducedAggs already reduced aggregations
* @param bufferedTopDocs a list of pre-collected top docs.
* @param numReducePhases the number of non-final reduce phases applied to the query results.
* @see QuerySearchResult#getAggs()
* @see QuerySearchResult#consumeProfileResult()
*/
static ReducedQueryPhase reducedQueryPhase(
Collection<? extends SearchPhaseResult> queryResults,
@Nullable List<DelayableWriteable<InternalAggregations>> bufferedAggs,
@Nullable InternalAggregations reducedAggs,
List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats,
int numReducePhases,
boolean isScrollRequest,
AggregationReduceContext.Builder aggReduceContextBuilder,
QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext,
boolean performFinalReduce
QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext
) {
assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
numReducePhases++; // increment for this phase
@@ -520,12 +517,7 @@ static ReducedQueryPhase reducedQueryPhase(
topDocsStats.timedOut,
topDocsStats.terminatedEarly,
reducedSuggest,
bufferedAggs == null
? null
: InternalAggregations.topLevelReduceDelayable(
bufferedAggs,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
),
reducedAggs,
profileShardResults.isEmpty() ? null : new SearchProfileResultsBuilder(profileShardResults),
sortedTopDocs,
sortValueFormats,
----------------------------------------
@@ -20,7 +20,6 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.internal.InternalScrollSearchRequest;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchContextId;
@@ -313,17 +312,6 @@ protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode
* @param queryResults a list of non-null query shard results
*/
protected static SearchPhaseController.ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
AggregationReduceContext.Builder aggReduceContextBuilder = new AggregationReduceContext.Builder() {
@Override
public AggregationReduceContext forPartialReduction() {
throw new UnsupportedOperationException("Scroll requests don't have aggs");
}

@Override
public AggregationReduceContext forFinalReduction() {
throw new UnsupportedOperationException("Scroll requests don't have aggs");
}
};
final SearchPhaseController.TopDocsStats topDocsStats = new SearchPhaseController.TopDocsStats(
SearchContext.TRACK_TOTAL_HITS_ACCURATE
);
@@ -339,16 +327,6 @@ public AggregationReduceContext forFinalReduction() {
topDocs.add(td.topDocs);
}
}
return SearchPhaseController.reducedQueryPhase(
queryResults,
null,
topDocs,
topDocsStats,
0,
true,
aggReduceContextBuilder,
null,
true
);
return SearchPhaseController.reducedQueryPhase(queryResults, null, topDocs, topDocsStats, 0, true, null);
}
}
----------------------------------------
@@ -8,7 +8,6 @@
*/
package org.elasticsearch.search.aggregations;

import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@@ -23,7 +22,6 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@@ -180,44 +178,22 @@ public SortValue sortValue(AggregationPath.PathElement head, Iterator<Aggregatio
}

/**
* Equivalent to {@link #topLevelReduce(List, AggregationReduceContext)} but it takes a list of
* {@link DelayableWriteable}. The object will be expanded once via {@link DelayableWriteable#expand()}
* but it is the responsibility of the caller to release those releasables.
* Equivalent to {@link #topLevelReduce(List, AggregationReduceContext)} but it takes an iterator and a count.
*/
public static InternalAggregations topLevelReduceDelayable(
List<DelayableWriteable<InternalAggregations>> delayableAggregations,
AggregationReduceContext context
) {
final List<InternalAggregations> aggregations = new AbstractList<>() {
@Override
public InternalAggregations get(int index) {
return delayableAggregations.get(index).expand();
}

@Override
public int size() {
return delayableAggregations.size();
}
};
return topLevelReduce(aggregations, context);
public static InternalAggregations topLevelReduce(Iterator<InternalAggregations> aggs, int count, AggregationReduceContext context) {
Member Author commented:
This method is single-use for now, but I left it here since it could also be used wherever the existing list-consuming version is used today, saving some more indirection and maybe some heap.
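
A minimal sketch of that follow-up (the helper and names below are hypothetical, not part of this change): a list-based caller could delegate to the new iterator overload, since the list already knows its size.

```java
import java.util.List;

import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregations;

class IteratorReduceSketch {
    // Hypothetical adapter: routes a pre-collected list through the new
    // Iterator + count overload instead of the List overload. Returns null
    // for an empty list, matching the overload's count == 0 behavior.
    static InternalAggregations topLevelReduceViaIterator(List<InternalAggregations> aggsList, AggregationReduceContext context) {
        return InternalAggregations.topLevelReduce(aggsList.iterator(), aggsList.size(), context);
    }
}
```
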

if (count == 0) {
return null;
}
return maybeExecuteFinalReduce(context, count == 1 ? reduce(aggs.next(), context) : reduce(aggs, count, context));
}

/**
* Begin the reduction process. This should be the entry point for the "first" reduction, e.g. called by
* SearchPhaseController or anywhere else that wants to initiate a reduction. It _should not_ be called
* as an intermediate reduction step (e.g. in the middle of an aggregation tree).
*
* This method first reduces the aggregations, and if it is the final reduce, then reduce the pipeline
* aggregations (both embedded parent/sibling as well as top-level sibling pipelines)
*/
public static InternalAggregations topLevelReduce(List<InternalAggregations> aggregationsList, AggregationReduceContext context) {
InternalAggregations reduced = reduce(aggregationsList, context);
private static InternalAggregations maybeExecuteFinalReduce(AggregationReduceContext context, InternalAggregations reduced) {
if (reduced == null) {
return null;
}
if (context.isFinalReduce()) {
List<InternalAggregation> reducedInternalAggs = reduced.getInternalAggregations();
reducedInternalAggs = reducedInternalAggs.stream()
List<InternalAggregation> reducedInternalAggs = reduced.getInternalAggregations()
.stream()
.map(agg -> agg.reducePipelines(agg, context, context.pipelineTreeRoot().subTree(agg.getName())))
.collect(Collectors.toCollection(ArrayList::new));

@@ -231,6 +207,18 @@ public static InternalAggregations topLevelReduce(List<InternalAggregations> agg
return reduced;
}

/**
* Begin the reduction process. This should be the entry point for the "first" reduction, e.g. called by
* SearchPhaseController or anywhere else that wants to initiate a reduction. It _should not_ be called
* as an intermediate reduction step (e.g. in the middle of an aggregation tree).
*
* This method first reduces the aggregations, and if it is the final reduce, then reduce the pipeline
* aggregations (both embedded parent/sibling as well as top-level sibling pipelines)
*/
public static InternalAggregations topLevelReduce(List<InternalAggregations> aggregationsList, AggregationReduceContext context) {
return maybeExecuteFinalReduce(context, reduce(aggregationsList, context));
}

/**
* Reduces the given list of aggregations as well as the top-level pipeline aggregators extracted from the first
* {@link InternalAggregations} object found in the list.
@@ -254,6 +242,16 @@ public static InternalAggregations reduce(List<InternalAggregations> aggregation
}
}

private static InternalAggregations reduce(Iterator<InternalAggregations> aggsIterator, int count, AggregationReduceContext context) {
// general case
var first = aggsIterator.next();
try (AggregatorsReducer reducer = new AggregatorsReducer(first, context, count)) {
reducer.accept(first);
aggsIterator.forEachRemaining(reducer::accept);
return reducer.get();
}
}

public static InternalAggregations reduce(InternalAggregations aggregations, AggregationReduceContext context) {
final List<InternalAggregation> internalAggregations = aggregations.asList();
int size = internalAggregations.size();