Skip to content

Commit 8c12384

Browse files
committed
Add support for $median aggregation operator.
Closes #4472
1 parent 9348794 commit 8c12384

File tree

6 files changed

+173
-7
lines changed

6 files changed

+173
-7
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java

+84
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,16 @@ public Percentile percentile(Double... percentages) {
265265
return percentile.percentages(percentages);
266266
}
267267

268+
/**
269+
* Creates new {@link AggregationExpression} that calculates the median of the associated numeric value expression.
270+
*
271+
* @return new instance of {@link Median}.
272+
* @since 4.2
273+
*/
274+
public Median median() {
275+
return usesFieldRef() ? Median.medianOf(fieldReference) : Median.medianOf(expression);
276+
}
277+
268278
private boolean usesFieldRef() {
269279
return fieldReference != null;
270280
}
@@ -1082,4 +1092,78 @@ protected String getMongoMethod() {
10821092
return "$percentile";
10831093
}
10841094
}
1095+
1096+
/**
1097+
* {@link AggregationExpression} for {@code $median}.
1098+
*
1099+
* @author Julia Lee
1100+
* @since 4.2
1101+
*/
1102+
public static class Median extends AbstractAggregationExpression {
1103+
1104+
private Median(Object value) {
1105+
super(value);
1106+
}
1107+
1108+
/**
1109+
* Creates new {@link Median}.
1110+
*
1111+
* @param fieldReference must not be {@literal null}.
1112+
* @return new instance of {@link Median}.
1113+
*/
1114+
public static Median medianOf(String fieldReference) {
1115+
1116+
Assert.notNull(fieldReference, "FieldReference must not be null");
1117+
Map<String, Object> fields = new HashMap<>();
1118+
fields.put("input", Fields.field(fieldReference));
1119+
fields.put("method", "approximate");
1120+
return new Median(fields);
1121+
}
1122+
1123+
/**
1124+
* Creates new {@link Median}.
1125+
*
1126+
* @param expression must not be {@literal null}.
1127+
* @return new instance of {@link Median}.
1128+
*/
1129+
public static Median medianOf(AggregationExpression expression) {
1130+
1131+
Assert.notNull(expression, "Expression must not be null");
1132+
Map<String, Object> fields = new HashMap<>();
1133+
fields.put("input", expression);
1134+
fields.put("method", "approximate");
1135+
return new Median(fields);
1136+
}
1137+
1138+
/**
1139+
* Creates new {@link Median} with all previously added inputs appending the given one. <br />
1140+
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
1141+
*
1142+
* @param fieldReference must not be {@literal null}.
1143+
* @return new instance of {@link Median}.
1144+
*/
1145+
public Median and(String fieldReference) {
1146+
1147+
Assert.notNull(fieldReference, "FieldReference must not be null");
1148+
return new Median(appendTo("input", Fields.field(fieldReference)));
1149+
}
1150+
1151+
/**
1152+
* Creates new {@link Median} with all previously added inputs appending the given one. <br />
1153+
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
1154+
*
1155+
* @param expression must not be {@literal null}.
1156+
* @return new instance of {@link Median}.
1157+
*/
1158+
public Median and(AggregationExpression expression) {
1159+
1160+
Assert.notNull(expression, "Expression must not be null");
1161+
return new Median(appendTo("input", expression));
1162+
}
1163+
1164+
@Override
1165+
protected String getMongoMethod() {
1166+
return "$median";
1167+
}
1168+
}
10851169
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java

+15
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,16 @@
2424
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovariancePop;
2525
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp;
2626
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max;
27+
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Median;
2728
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min;
2829
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Percentile;
2930
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop;
3031
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevSamp;
3132
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Sum;
3233
import org.springframework.data.mongodb.core.aggregation.SetWindowFieldsOperation.WindowUnit;
3334
import org.springframework.data.mongodb.core.aggregation.SetWindowFieldsOperation.WindowUnits;
35+
import org.springframework.data.mongodb.core.aggregation.SetWindowFieldsOperation.WindowUnit;
36+
import org.springframework.data.mongodb.core.aggregation.SetWindowFieldsOperation.WindowUnits;
3437
import org.springframework.lang.Nullable;
3538
import org.springframework.util.Assert;
3639
import org.springframework.util.ObjectUtils;
@@ -948,6 +951,18 @@ public Percentile percentile(Double... percentages) {
948951
return percentile.percentages(percentages);
949952
}
950953

954+
/**
955+
* Creates new {@link AggregationExpression} that calculates the requested percentile(s) of the
956+
* numeric value.
957+
*
958+
* @return new instance of {@link Median}.
959+
* @since 4.2
960+
*/
961+
public Median median() {
962+
return usesFieldRef() ? AccumulatorOperators.Median.medianOf(fieldReference)
963+
: AccumulatorOperators.Median.medianOf(expression);
964+
}
965+
951966
private boolean usesFieldRef() {
952967
return fieldReference != null;
953968
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java

+20
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,26 @@ void rendersPercentileWithExpression() {
132132
.isEqualTo(Document.parse("{ $percentile: { input: [\"$scoreOne\", {\"$sum\": \"$scoreTwo\"}], method: \"approximate\", p: [0.1, 0.2] } }"));
133133
}
134134

135+
@Test // GH-4472
136+
void rendersMedianWithFieldReference() {
137+
138+
assertThat(valueOf("score").median().toDocument(Aggregation.DEFAULT_CONTEXT))
139+
.isEqualTo(Document.parse("{ $median: { input: \"$score\", method: \"approximate\" } }"));
140+
141+
assertThat(valueOf("score").median().and("scoreTwo").toDocument(Aggregation.DEFAULT_CONTEXT))
142+
.isEqualTo(Document.parse("{ $median: { input: [\"$score\", \"$scoreTwo\"], method: \"approximate\" } }"));
143+
}
144+
145+
@Test // GH-4472
146+
void rendersMedianWithExpression() {
147+
148+
assertThat(valueOf(Sum.sumOf("score")).median().toDocument(Aggregation.DEFAULT_CONTEXT))
149+
.isEqualTo(Document.parse("{ $median: { input: {\"$sum\": \"$score\"}, method: \"approximate\" } }"));
150+
151+
assertThat(valueOf("scoreOne").median().and(Sum.sumOf("scoreTwo")).toDocument(Aggregation.DEFAULT_CONTEXT))
152+
.isEqualTo(Document.parse("{ $median: { input: [\"$scoreOne\", {\"$sum\": \"$scoreTwo\"}], method: \"approximate\" } }"));
153+
}
154+
135155
static class Jedi {
136156

137157
String name;

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

+37-6
Original file line numberDiff line numberDiff line change
@@ -1897,19 +1897,44 @@ void facetShouldCreateFacets() {
18971897
@EnableIfMongoServerVersion(isGreaterThanEqual = "7.0")
18981898
void percentileShouldBeAppliedCorrectly() {
18991899

1900-
mongoTemplate.insert(new DATAMONGO788(15, 16));
1901-
mongoTemplate.insert(new DATAMONGO788(17, 18));
1900+
DATAMONGO788 objectToSave = new DATAMONGO788(62, 81, 80);
1901+
DATAMONGO788 objectToSave2 = new DATAMONGO788(60, 83, 79);
1902+
1903+
mongoTemplate.insert(objectToSave);
1904+
mongoTemplate.insert(objectToSave2);
19021905

19031906
Aggregation agg = Aggregation.newAggregation(
1904-
project().and(ArithmeticOperators.valueOf("x").percentile(0.9).and("y"))
1905-
.as("ninetiethPercentile"));
1907+
project().and(ArithmeticOperators.valueOf("x").percentile(0.9, 0.4).and("y").and("xField"))
1908+
.as("percentileValues"));
19061909

19071910
AggregationResults<Document> result = mongoTemplate.aggregate(agg, DATAMONGO788.class, Document.class);
19081911

19091912
// MongoDB server returns $percentile as an array of doubles
19101913
List<Document> rawResults = (List<Document>) result.getRawResults().get("results");
1911-
assertThat((List<Object>) rawResults.get(0).get("ninetiethPercentile")).containsExactly(16.0);
1912-
assertThat((List<Object>) rawResults.get(1).get("ninetiethPercentile")).containsExactly(18.0);
1914+
assertThat((List<Object>) rawResults.get(0).get("percentileValues")).containsExactly(81.0, 80.0);
1915+
assertThat((List<Object>) rawResults.get(1).get("percentileValues")).containsExactly(83.0, 79.0);
1916+
}
1917+
1918+
@Test // GH-4472
1919+
@EnableIfMongoServerVersion(isGreaterThanEqual = "7.0")
1920+
void medianShouldBeAppliedCorrectly() {
1921+
1922+
DATAMONGO788 objectToSave = new DATAMONGO788(62, 81, 80);
1923+
DATAMONGO788 objectToSave2 = new DATAMONGO788(60, 83, 79);
1924+
1925+
mongoTemplate.insert(objectToSave);
1926+
mongoTemplate.insert(objectToSave2);
1927+
1928+
Aggregation agg = Aggregation.newAggregation(
1929+
project().and(ArithmeticOperators.valueOf("x").median().and("y").and("xField"))
1930+
.as("medianValue"));
1931+
1932+
AggregationResults<Document> result = mongoTemplate.aggregate(agg, DATAMONGO788.class, Document.class);
1933+
1934+
// MongoDB server returns $median a Double
1935+
List<Document> rawResults = (List<Document>) result.getRawResults().get("results");
1936+
assertThat(rawResults.get(0).get("medianValue")).isEqualTo(80.0);
1937+
assertThat(rawResults.get(1).get("medianValue")).isEqualTo(79.0);
19131938
}
19141939

19151940
@Test // DATAMONGO-1986
@@ -2152,6 +2177,12 @@ public DATAMONGO788() {}
21522177
this.y = y;
21532178
this.yField = y;
21542179
}
2180+
2181+
public DATAMONGO788(int x, int y, int xField) {
2182+
this.x = x;
2183+
this.y = y;
2184+
this.xField = xField;
2185+
}
21552186
}
21562187

21572188
// DATAMONGO-806

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperatorsUnitTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import org.junit.jupiter.api.Test;
2626

2727
/**
28-
* Unit tests for {@link Round}.
28+
* Unit tests for {@link ArithmeticOperators}.
2929
*
3030
* @author Christoph Strobl
3131
* @author Mark Paluch

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java

+16
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,22 @@ void shouldRenderPercentileWithMultipleArgsAggregationExpression() {
22612261
assertThat(agg).isEqualTo(Document.parse("{ $project: { scorePercentiles: { $percentile: { input: [\"$scoreOne\", \"$scoreTwo\"], method: \"approximate\", p: [0.4] } }} } }"));
22622262
}
22632263

2264+
@Test // GH-4472
2265+
void shouldRenderMedianAggregationExpressions() {
2266+
2267+
Document singleArgAgg = project()
2268+
.and(ArithmeticOperators.valueOf("score").median()).as("medianValue")
2269+
.toDocument(Aggregation.DEFAULT_CONTEXT);
2270+
2271+
assertThat(singleArgAgg).isEqualTo(Document.parse("{ $project: { medianValue: { $median: { input: \"$score\", method: \"approximate\" } }} } }"));
2272+
2273+
Document multipleArgsAgg = project()
2274+
.and(ArithmeticOperators.valueOf("score").median().and("scoreTwo")).as("medianValue")
2275+
.toDocument(Aggregation.DEFAULT_CONTEXT);
2276+
2277+
assertThat(multipleArgsAgg).isEqualTo(Document.parse("{ $project: { medianValue: { $median: { input: [\"$score\", \"$scoreTwo\"], method: \"approximate\" } }} } }"));
2278+
}
2279+
22642280
private static Document extractOperation(String field, Document fromProjectClause) {
22652281
return (Document) fromProjectClause.get(field);
22662282
}

0 commit comments

Comments
 (0)