ESQL: Fail in AggregateFunction when LogicPlan is not an Aggregate #124446

Merged
6 changes: 6 additions & 0 deletions docs/changelog/124446.yaml
@@ -0,0 +1,6 @@
+pr: 124446
+summary: "ESQL: Fail in `AggregateFunction` when `LogicPlan` is not an `Aggregate`"
+area: ES|QL
+type: bug
+issues:
+ - 124311
@@ -19,8 +19,9 @@
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
+import org.elasticsearch.xpack.esql.plan.logical.Dedup;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
-import org.elasticsearch.xpack.esql.plan.logical.OrderBy;

 import java.io.IOException;
 import java.util.List;
@@ -139,14 +140,12 @@ public boolean equals(Object obj) {
     @Override
     public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
         return (p, failures) -> {
-            if (p instanceof OrderBy order) {
-                order.order().forEach(o -> {
-                    o.forEachDown(Function.class, f -> {
-                        if (f instanceof AggregateFunction) {
-                            failures.add(fail(f, "Aggregate functions are not allowed in SORT [{}]", f.functionName()));
-                        }
-                    });
-                });
+            // `dedup` for now is not exposed as a command,
+            // so allowing aggregate functions for dedup explicitly is just an internal implementation detail
+            if ((p instanceof Aggregate) == false && (p instanceof Dedup) == false) {
+                p.expressions().forEach(x -> x.forEachDown(AggregateFunction.class, af -> {
+                    failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText()));
+                }));
             }
         };
     }
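The hunk above replaces a SORT-only check with one that applies to every plan node that is not an `Aggregate` or `Dedup`. A minimal, self-contained sketch of that idea follows. `Expr`, `Plan`, and `verify` are hypothetical stand-ins invented for illustration, not the real ES|QL classes; only the failure message text and the Aggregate/Dedup exemption are taken from the diff.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical, simplified model of plan nodes and expressions (not the ES|QL API).
final class AggCheckSketch {

    /** An expression: its source text, whether it is an aggregate function, and its children. */
    record Expr(String text, boolean aggregate, List<Expr> children) {
        static Expr agg(String text, Expr... children) {
            return new Expr(text, true, List.of(children));
        }

        static Expr fn(String text, Expr... children) {
            return new Expr(text, false, List.of(children));
        }

        /** Pre-order walk over this expression and all nested ones. */
        void forEachDown(Consumer<Expr> action) {
            action.accept(this);
            children.forEach(c -> c.forEachDown(action));
        }
    }

    /** A plan node: the kind of command plus the expressions it carries. */
    record Plan(String kind, List<Expr> expressions) {}

    /** Reports every aggregate function found in a node that is neither Aggregate nor Dedup. */
    static List<String> verify(Plan plan) {
        List<String> failures = new ArrayList<>();
        boolean allowed = plan.kind().equals("Aggregate") || plan.kind().equals("Dedup");
        if (allowed == false) {
            plan.expressions().forEach(x -> x.forEachDown(e -> {
                if (e.aggregate()) {
                    failures.add("aggregate function [" + e.text() + "] not allowed outside STATS command");
                }
            }));
        }
        return failures;
    }

    public static void main(String[] args) {
        Expr count = Expr.agg("count(*)");
        // SORT to_string(count(*)): the aggregate is nested, but the walk still finds it.
        Plan sort = new Plan("OrderBy", List.of(Expr.fn("to_string(count(*))", count)));
        // STATS c = count(*): aggregates are allowed here, so nothing is reported.
        Plan stats = new Plan("Aggregate", List.of(count));
        System.out.println(verify(sort));
        System.out.println(verify(stats));
    }
}
```

Because the check keys on the node type rather than on a specific command, per-command checks such as the one previously kept for SORT become redundant under this approach.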
@@ -24,7 +24,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
-import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.plan.GeneratingPlan;

@@ -178,10 +177,6 @@ public void postAnalysisVerification(Failures failures) {
                     )
                 );
             }
-            // check no aggregate functions are used
-            field.forEachDown(AggregateFunction.class, af -> {
-                failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText()));
-            });
         });
     }
 }
Contributor

Could you add another test with multiple aggs in the same WHERE/EVAL, or both in the same query with an agg each?
You can test it like here:

    public void testNotFoundFieldInNestedFunction() {
        assertEquals("""
            1:30: Unknown column [missing]
            line 1:43: Unknown column [not_found]
            line 1:23: Unknown column [avg]""", error("from test | stats c = avg by missing + 1, not_found"));
    }

I'm asking because the RRF test reports many repeated errors (each "subplan" of the RRF has the same source). That shouldn't happen here at all, but it's worth a double-check.

Contributor

I think we should add test cases where an aggregate function ends up in the ROW command. The following currently gives 500:

row language_code = count(2)

While supremely paranoid, we could also throw in test cases for dissect and grok, like row x = 1 | dissect avg(1) "foo" and similarly with grok. (This doesn't give 500 on current main.)

Contributor Author

Thank you for reviewing! I’ve added the tests in c7d867c.

@@ -1014,6 +1014,17 @@ public void testNotFoundFieldInNestedFunction() {
             line 1:23: Unknown column [avg]""", error("from test | stats c = avg by missing + 1, not_found"));
     }

+    public void testMultipleAggsOutsideStats() {
+        assertEquals(
+            """
+                1:71: aggregate function [avg(salary)] not allowed outside STATS command
+                line 1:96: aggregate function [median(emp_no)] not allowed outside STATS command
+                line 1:22: aggregate function [sum(salary)] not allowed outside STATS command
+                line 1:39: aggregate function [avg(languages)] not allowed outside STATS command""",
+            error("from test | eval s = sum(salary), l = avg(languages) | where salary > avg(salary) and emp_no > median(emp_no)")
+        );
+    }
+
     public void testSpatialSort() {
         String prefix = "ROW wkt = [\"POINT(42.9711 -14.7553)\", \"POINT(75.8093 22.7277)\"] | MV_EXPAND wkt ";
         assertEquals("1:130: cannot sort on geo_point", error(prefix + "| EVAL shape = TO_GEOPOINT(wkt) | limit 5 | sort shape"));
@@ -2107,10 +2118,53 @@ public void testChangePoint_valueNumeric() {
     }

     public void testSortByAggregate() {
-        assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT count(*)"));
-        assertEquals("1:28: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT to_string(count(*))"));
-        assertEquals("1:22: Aggregate functions are not allowed in SORT [MAX]", error("ROW a = 1 | SORT 1 + max(a)"));
-        assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("FROM test | SORT count(*)"));
+        assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 | SORT count(*)"));
+        assertEquals(
+            "1:28: aggregate function [count(*)] not allowed outside STATS command",
+            error("ROW a = 1 | SORT to_string(count(*))")
+        );
+        assertEquals("1:22: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | SORT 1 + max(a)"));
+        assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("FROM test | SORT count(*)"));
+    }
+
+    public void testFilterByAggregate() {
+        assertEquals("1:19: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 | WHERE count(*) > 0"));
+        assertEquals(
+            "1:29: aggregate function [count(*)] not allowed outside STATS command",
+            error("ROW a = 1 | WHERE to_string(count(*)) IS NOT NULL")
+        );
+        assertEquals("1:23: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | WHERE 1 + max(a) > 0"));
+        assertEquals(
+            "1:24: aggregate function [min(languages)] not allowed outside STATS command",
+            error("FROM employees | WHERE min(languages) > 2")
+        );
+    }
+
+    public void testDissectByAggregate() {
+        assertEquals(
+            "1:21: aggregate function [min(first_name)] not allowed outside STATS command",
+            error("from test | dissect min(first_name) \"%{foo}\"")
+        );
+        assertEquals(
+            "1:21: aggregate function [avg(salary)] not allowed outside STATS command",
+            error("from test | dissect avg(salary) \"%{foo}\"")
+        );
+    }
+
+    public void testGrokByAggregate() {
+        assertEquals(
+            "1:18: aggregate function [max(last_name)] not allowed outside STATS command",
+            error("from test | grok max(last_name) \"%{WORD:foo}\"")
+        );
+        assertEquals(
+            "1:18: aggregate function [sum(salary)] not allowed outside STATS command",
+            error("from test | grok sum(salary) \"%{WORD:foo}\"")
+        );
+    }
+
+    public void testAggregateInRow() {
+        assertEquals("1:13: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 + count(*)"));
+        assertEquals("1:9: aggregate function [avg(2)] not allowed outside STATS command", error("ROW a = avg(2)"));
     }

     public void testLookupJoinDataTypeMismatch() {
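
The `line:column` prefixes asserted in the tests above appear to be the 1-based position of the offending function's source text within the query. The sketch below reproduces a few of them, assuming single-line queries. `position` is a hypothetical helper written for this illustration; it is not part of the test class or of ES|QL.

```java
// Hypothetical helper: derives the "line:column" prefix expected in the assertions.
final class ErrorPositionSketch {

    /** Returns "1:<column>" for the first occurrence of fragment, with a 1-based column. */
    static String position(String query, String fragment) {
        int index = query.indexOf(fragment);
        if (index < 0) {
            throw new IllegalArgumentException("fragment [" + fragment + "] not found in query");
        }
        // indexOf is 0-based; the reported column is 1-based. Single-line queries are assumed.
        return "1:" + (index + 1);
    }

    public static void main(String[] args) {
        System.out.println(position("ROW a = 1 | SORT count(*)", "count(*)"));      // 1:18
        System.out.println(position("ROW a = 1 | WHERE count(*) > 0", "count(*)")); // 1:19
        System.out.println(position("ROW a = avg(2)", "avg(2)"));                   // 1:9
    }
}
```

This is handy when adding a new case: compute the column from the query string first, then check it against what the verifier reports.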