diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java
index ed311bc2d3..1b03937d08 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java
@@ -38,6 +38,7 @@
* @author Gustavo de Geus
* @author Christoph Strobl
* @author Mark Paluch
+ * @author Sergey Shcherbakov
* @since 1.3
* @see MongoDB Aggregation Framework: $group
*/
@@ -155,6 +156,17 @@ public GroupOperationBuilder sum(String reference) {
return sum(reference, null);
}
+ /**
+ * Generates an {@link GroupOperationBuilder} for an {@code $sum}-expression for the given
+ * {@link AggregationExpression}.
+ *
+ * @param expr
+ * @return
+ */
+ public GroupOperationBuilder sum(AggregationExpression expr) {
+ return newBuilder(GroupOps.SUM, null, expr);
+ }
+
private GroupOperationBuilder sum(@Nullable String reference, @Nullable Object value) {
return newBuilder(GroupOps.SUM, reference, value);
}
diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
index ff11059a59..d438b5ddee 100644
--- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
+++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
@@ -84,6 +84,7 @@
* @author Mark Paluch
* @author Nikolay Bogdanov
* @author Maninder Singh
+ * @author Sergey Shcherbakov
*/
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration("classpath:infrastructure.xml")
@@ -799,6 +800,49 @@ public void shouldAllowGroupingUsingConditionalExpressions() {
assertThat(((Number) good.get("score")).longValue(), is(equalTo(9000L)));
}
+ @Test // DATAMONGO-1784
+ public void shouldAllowSumUsingConditionalExpressions() {
+
+ mongoTemplate.dropCollection(CarPerson.class);
+
+ CarPerson person1 = new CarPerson("first1", "last1", new CarDescriptor.Entry("MAKE1", "MODEL1", 2000),
+ new CarDescriptor.Entry("MAKE1", "MODEL2", 2001));
+
+ CarPerson person2 = new CarPerson("first2", "last2", new CarDescriptor.Entry("MAKE3", "MODEL4", 2014));
+ CarPerson person3 = new CarPerson("first3", "last3", new CarDescriptor.Entry("MAKE2", "MODEL5", 2015));
+
+ mongoTemplate.save(person1);
+ mongoTemplate.save(person2);
+ mongoTemplate.save(person3);
+
+ TypedAggregation agg = Aggregation.newAggregation(CarPerson.class,
+ unwind("descriptors.carDescriptor.entries"), //
+ project() //
+ .and(ConditionalOperators //
+ .when(Criteria.where("descriptors.carDescriptor.entries.make").is("MAKE1")).then("good")
+ .otherwise("meh"))
+ .as("make") //
+ .and("descriptors.carDescriptor.entries.model").as("model") //
+ .and("descriptors.carDescriptor.entries.year").as("year"), //
+ group("make").sum(ConditionalOperators //
+ .when(Criteria.where("year").gte(2012)) //
+ .then(1) //
+ .otherwise(9000)).as("score"),
+ sort(ASC, "make"));
+
+ AggregationResults result = mongoTemplate.aggregate(agg, Document.class);
+
+ assertThat(result.getMappedResults(), hasSize(2));
+
+ Document meh = result.getMappedResults().get(0);
+ assertThat((String) meh.get("_id"), is(equalTo("meh")));
+ assertThat(((Number) meh.get("score")).longValue(), is(equalTo(2L)));
+
+ Document good = result.getMappedResults().get(1);
+ assertThat((String) good.get("_id"), is(equalTo("good")));
+ assertThat(((Number) good.get("score")).longValue(), is(equalTo(18000L)));
+ }
+
/**
* @see Return