Skip to content

Run single phase aggregation when possible #131485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.startsWith;

public class TimeSeriesIT extends AbstractEsqlIntegTestCase {

Expand Down Expand Up @@ -444,6 +445,19 @@ public void testProfile() {
}
}
assertThat(totalTimeSeries, equalTo(dataProfiles.size() / 3));
{
List<DriverProfile> finalProfiles = profile.drivers().stream().filter(d -> d.description().equals("final")).toList();
assertThat(finalProfiles, hasSize(1));
DriverProfile finalProfile = finalProfiles.getFirst();
assertThat(finalProfile.operators(), hasSize(7));
assertThat(finalProfile.operators().get(0).operator(), startsWith("ExchangeSourceOperator"));
assertThat(finalProfile.operators().get(1).operator(), startsWith("TimeSeriesAggregationOperator"));
assertThat(finalProfile.operators().get(2).operator(), startsWith("ProjectOperator"));
assertThat(finalProfile.operators().get(3).operator(), startsWith("HashAggregationOperator"));
assertThat(finalProfile.operators().get(4).operator(), startsWith("ProjectOperator"));
assertThat(finalProfile.operators().get(5).operator(), startsWith("TopNOperator"));
assertThat(finalProfile.operators().get(6).operator(), startsWith("OutputOperator"));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.SinglePhaseAggregate;
import org.elasticsearch.xpack.esql.plan.physical.FragmentExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor;
Expand All @@ -24,7 +25,8 @@
public class PhysicalPlanOptimizer extends ParameterizedRuleExecutor<PhysicalPlan, PhysicalOptimizerContext> {

private static final List<RuleExecutor.Batch<PhysicalPlan>> RULES = List.of(
new Batch<>("Plan Boundary", Limiter.ONCE, new ProjectAwayColumns())
new Batch<>("Plan Boundary", Limiter.ONCE, new ProjectAwayColumns()),
new Batch<>("Single phase aggregate", Limiter.ONCE, new SinglePhaseAggregate())
);

private final PhysicalVerifier verifier = PhysicalVerifier.INSTANCE;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.optimizer.rules.physical;

import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;

/**
* Collapses two-phase aggregation into a single phase when possible.
* For example, in FROM .. | STATS first | STATS second, the STATS second aggregation
* can be executed in a single phase on the coordinator instead of two phases.
*/
public class SinglePhaseAggregate extends PhysicalOptimizerRules.OptimizerRule<AggregateExec> {
@Override
protected PhysicalPlan rule(AggregateExec plan) {
if (plan instanceof AggregateExec parent
&& parent.getMode() == AggregatorMode.FINAL
&& parent.child() instanceof AggregateExec child
&& child.getMode() == AggregatorMode.INITIAL) {
if (parent.groupings()
.stream()
.noneMatch(group -> group.anyMatch(expr -> expr instanceof GroupingFunction.NonEvaluatableGroupingFunction))) {
return child.withMode(AggregatorMode.SINGLE);
}
}
return plan;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ public final PhysicalOperation groupingPhysicalOperation(
List<Aggregator.Factory> aggregatorFactories = new ArrayList<>();

// append channels to the layout
if (aggregatorMode == AggregatorMode.FINAL) {
layout.append(aggregates);
} else {
if (aggregatorMode.isOutputPartial()) {
layout.append(aggregateMapper.mapNonGrouping(aggregates));
} else {
layout.append(aggregates);
}

// create the agg factories
Expand Down Expand Up @@ -146,14 +146,14 @@ else if (aggregatorMode.isOutputPartial()) {
groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group));
}

if (aggregatorMode == AggregatorMode.FINAL) {
if (aggregatorMode.isOutputPartial()) {
layout.append(aggregateMapper.mapGrouping(aggregates));
} else {
for (var agg : aggregates) {
if (Alias.unwrap(agg) instanceof AggregateFunction) {
layout.append(agg);
}
}
} else {
layout.append(aggregateMapper.mapGrouping(aggregates));
}

// create the agg factories
Expand Down Expand Up @@ -264,7 +264,13 @@ private void aggregatesToFactory(
if (child instanceof AggregateFunction aggregateFunction) {
List<NamedExpression> sourceAttr = new ArrayList<>();

if (mode == AggregatorMode.INITIAL) {
if (mode.isInputPartial()) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(ne);
} else {
sourceAttr = aggregateMapper.mapNonGrouping(ne);
}
} else {
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
Expression field = aggregateFunction.field();
// Only count can now support literals - all the other aggs should be optimized away
Expand Down Expand Up @@ -292,16 +298,6 @@ private void aggregatesToFactory(
}
}
}
// coordinator/exchange phase
else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(ne);
} else {
sourceAttr = aggregateMapper.mapNonGrouping(ne);
}
} else {
throw new EsqlIllegalArgumentException("illegal aggregation mode");
}

AggregatorFunctionSupplier aggSupplier = supplier(aggregateFunction);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
Expand Down Expand Up @@ -215,7 +214,7 @@ public LocalExecutionPlan plan(String description, FoldContext foldCtx, Physical
// workaround for https://github.com/elastic/elasticsearch/issues/99782
localPhysicalPlan = localPhysicalPlan.transformUp(
AggregateExec.class,
a -> a.getMode() == AggregatorMode.FINAL ? new ProjectExec(a.source(), a, Expressions.asAttributes(a.aggregates())) : a
a -> a.getMode().isOutputPartial() ? a : new ProjectExec(a.source(), a, Expressions.asAttributes(a.aggregates()))
);
PhysicalOperation physicalOperation = plan(localPhysicalPlan, context);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@
import static java.util.Arrays.asList;
import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL;
import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL;
import static org.elasticsearch.compute.aggregation.AggregatorMode.SINGLE;
import static org.elasticsearch.core.Tuple.tuple;
import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
import static org.elasticsearch.index.query.QueryBuilders.existsQuery;
Expand Down Expand Up @@ -7814,6 +7815,33 @@ public void testLookupJoinFieldLoadingDropAllFields() throws Exception {
assertLookupJoinFieldNames(query, data, List.of(Set.of(), Set.of("foo", "bar", "baz")));
}

/**
* LimitExec[1000[INTEGER],null]
* \_AggregateExec[[last_name{r}#8],[COUNT(first_name{r}#5,true[BOOLEAN]) AS count(first_name)#11, last_name{r}#8],SINGLE,[last_name
* {r}#8, $$count(first_name)$count{r}#25, $$count(first_name)$seen{r}#26],null]
* \_AggregateExec[[emp_no{f}#12],[VALUES(first_name{f}#13,true[BOOLEAN]) AS first_name#5, VALUES(last_name{f}#16,true[BOOLEAN]) A
* S last_name#8],FINAL,[emp_no{f}#12, $$first_name$values{r}#23, $$last_name$values{r}#24],null]
* \_ExchangeExec[[emp_no{f}#12, $$first_name$values{r}#23, $$last_name$values{r}#24],true]
* \_FragmentExec[filter=null, estimatedRowSize=0, reducer=[], fragment=[
* Aggregate[[emp_no{f}#12],[VALUES(first_name{f}#13,true[BOOLEAN]) AS first_name#5, VALUES(last_name{f}#16,true[BOOLEAN]) A
* S last_name#8]]
* \_EsRelation[test][_meta_field{f}#18, emp_no{f}#12, first_name{f}#13, ..]]]
*/
public void testSingleModeAggregate() {
String q = """
FROM test
| STATS first_name = VALUES(first_name), last_name = VALUES(last_name) BY emp_no
| STATS count(first_name) BY last_name""";
PhysicalPlan plan = physicalPlan(q);
PhysicalPlan optimized = physicalPlanOptimizer.optimize(plan);
LimitExec limit = as(optimized, LimitExec.class);
AggregateExec second = as(limit.child(), AggregateExec.class);
assertThat(second.getMode(), equalTo(SINGLE));
AggregateExec first = as(second.child(), AggregateExec.class);
assertThat(first.getMode(), equalTo(FINAL));
as(first.child(), ExchangeExec.class);
}

private void assertLookupJoinFieldNames(String query, TestDataSource data, List<Set<String>> expectedFieldNames) {
assertLookupJoinFieldNames(query, data, expectedFieldNames, false);
}
Expand Down