diff --git a/pom.xml b/pom.xml
index ea80a3cb74..a81aef6925 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
org.springframework.data
spring-data-mongodb-parent
- 1.10.0.BUILD-SNAPSHOT
+ 1.10.0.DATAMONGO-1538-SNAPSHOT
pom
Spring Data MongoDB
diff --git a/spring-data-mongodb-cross-store/pom.xml b/spring-data-mongodb-cross-store/pom.xml
index ae0a5d6c8f..d4b9dad083 100644
--- a/spring-data-mongodb-cross-store/pom.xml
+++ b/spring-data-mongodb-cross-store/pom.xml
@@ -6,7 +6,7 @@
org.springframework.data
spring-data-mongodb-parent
- 1.10.0.BUILD-SNAPSHOT
+ 1.10.0.DATAMONGO-1538-SNAPSHOT
../pom.xml
@@ -48,7 +48,7 @@
org.springframework.data
spring-data-mongodb
- 1.10.0.BUILD-SNAPSHOT
+ 1.10.0.DATAMONGO-1538-SNAPSHOT
diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml
index 2d02722262..09099d9980 100644
--- a/spring-data-mongodb-distribution/pom.xml
+++ b/spring-data-mongodb-distribution/pom.xml
@@ -13,7 +13,7 @@
org.springframework.data
spring-data-mongodb-parent
- 1.10.0.BUILD-SNAPSHOT
+ 1.10.0.DATAMONGO-1538-SNAPSHOT
../pom.xml
diff --git a/spring-data-mongodb-log4j/pom.xml b/spring-data-mongodb-log4j/pom.xml
index ee5e3336db..f78bc0261d 100644
--- a/spring-data-mongodb-log4j/pom.xml
+++ b/spring-data-mongodb-log4j/pom.xml
@@ -5,7 +5,7 @@
org.springframework.data
spring-data-mongodb-parent
- 1.10.0.BUILD-SNAPSHOT
+ 1.10.0.DATAMONGO-1538-SNAPSHOT
../pom.xml
diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml
index 8072d3f665..b720861be6 100644
--- a/spring-data-mongodb/pom.xml
+++ b/spring-data-mongodb/pom.xml
@@ -11,7 +11,7 @@
org.springframework.data
spring-data-mongodb-parent
- 1.10.0.BUILD-SNAPSHOT
+ 1.10.0.DATAMONGO-1538-SNAPSHOT
../pom.xml
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java
index a2a835fb31..2b5e873747 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java
@@ -17,6 +17,7 @@
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
@@ -25,8 +26,8 @@
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.OtherwiseBuilder;
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.ThenBuilder;
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Filter.AsBuilder;
+import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
-import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
@@ -1986,6 +1987,28 @@ public static Map.AsBuilder mapItemsOf(String fieldReference) {
public static Map.AsBuilder mapItemsOf(AggregationExpression expression) {
return Map.itemsOf(expression);
}
+
+ /**
+ * Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a
+ * nested {@link AggregationExpression}.
+ *
+ * @param variables must not be {@literal null}.
+ * @return
+ */
+ public static Let.LetBuilder define(ExpressionVariable... variables) {
+ return Let.define(variables);
+ }
+
+ /**
+ * Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a
+ * nested {@link AggregationExpression}.
+ *
+ * @param variables must not be {@literal null}.
+ * @return
+ */
+ public static Let.LetBuilder define(Collection variables) {
+ return Let.define(variables);
+ }
}
/**
@@ -3881,31 +3904,19 @@ public static AsBuilder filter(List> values) {
*/
@Override
public DBObject toDbObject(final AggregationOperationContext context) {
-
- return toFilter(new ExposedFieldsAggregationOperationContext(ExposedFields.from(as), context) {
-
- @Override
- public FieldReference getReference(Field field) {
-
- FieldReference ref = null;
- try {
- ref = context.getReference(field);
- } catch (Exception e) {
- // just ignore that one.
- }
- return ref != null ? ref : super.getReference(field);
- }
- });
+ return toFilter(ExposedFields.from(as), context);
}
- private DBObject toFilter(AggregationOperationContext context) {
+ private DBObject toFilter(ExposedFields exposedFields, AggregationOperationContext context) {
DBObject filterExpression = new BasicDBObject();
+ InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext(
+ exposedFields, context);
filterExpression.putAll(context.getMappedObject(new BasicDBObject("input", getMappedInput(context))));
filterExpression.put("as", as.getTarget());
- filterExpression.putAll(context.getMappedObject(new BasicDBObject("cond", getMappedCondition(context))));
+ filterExpression.putAll(context.getMappedObject(new BasicDBObject("cond", getMappedCondition(operationContext))));
return new BasicDBObject("$filter", filterExpression);
}
@@ -5995,27 +6006,14 @@ public Map andApply(final AggregationExpression expression) {
@Override
public DBObject toDbObject(final AggregationOperationContext context) {
-
- return toMap(new ExposedFieldsAggregationOperationContext(
- ExposedFields.synthetic(Fields.fields(itemVariableName)), context) {
-
- @Override
- public FieldReference getReference(Field field) {
-
- FieldReference ref = null;
- try {
- ref = context.getReference(field);
- } catch (Exception e) {
- // just ignore that one.
- }
- return ref != null ? ref : super.getReference(field);
- }
- });
+ return toMap(ExposedFields.synthetic(Fields.fields(itemVariableName)), context);
}
- private DBObject toMap(AggregationOperationContext context) {
+ private DBObject toMap(ExposedFields exposedFields, AggregationOperationContext context) {
BasicDBObject map = new BasicDBObject();
+ InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext(
+ exposedFields, context);
BasicDBObject input;
if (sourceArray instanceof Field) {
@@ -6026,7 +6024,8 @@ private DBObject toMap(AggregationOperationContext context) {
map.putAll(context.getMappedObject(input));
map.put("as", itemVariableName);
- map.put("in", functionToApply.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context)));
+ map.put("in",
+ functionToApply.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(operationContext)));
return new BasicDBObject("$map", map);
}
@@ -6696,4 +6695,173 @@ public Cond otherwiseValueOf(AggregationExpression expression) {
}
}
}
+
+ /**
+ * {@link AggregationExpression} for {@code $let} that binds {@link AggregationExpression} to variables for use in the
+ * specified {@code in} expression, and returns the result of the expression.
+ *
+ * @author Christoph Strobl
+ * @since 1.10
+ */
+ class Let implements AggregationExpression {
+
+ private final List vars;
+ private final AggregationExpression expression;
+
+ private Let(List vars, AggregationExpression expression) {
+
+ this.vars = vars;
+ this.expression = expression;
+ }
+
+ /**
+ * Start creating new {@link Let} by defining the variables for {@code $vars}.
+ *
+ * @param variables must not be {@literal null}.
+ * @return
+ */
+ public static LetBuilder define(final Collection variables) {
+
+ Assert.notNull(variables, "Variables must not be null!");
+
+ return new LetBuilder() {
+ @Override
+ public Let andApply(final AggregationExpression expression) {
+
+ Assert.notNull(expression, "Expression must not be null!");
+ return new Let(new ArrayList(variables), expression);
+ }
+ };
+ }
+
+ /**
+ * Start creating new {@link Let} by defining the variables for {@code $vars}.
+ *
+ * @param variables must not be {@literal null}.
+ * @return
+ */
+ public static LetBuilder define(final ExpressionVariable... variables) {
+
+ Assert.notNull(variables, "Variables must not be null!");
+
+ return new LetBuilder() {
+ @Override
+ public Let andApply(final AggregationExpression expression) {
+
+ Assert.notNull(expression, "Expression must not be null!");
+ return new Let(Arrays.asList(variables), expression);
+ }
+ };
+ }
+
+ public interface LetBuilder {
+
+ /**
+ * Define the {@link AggregationExpression} to evaluate.
+ *
+ * @param expression must not be {@literal null}.
+ * @return
+ */
+ Let andApply(AggregationExpression expression);
+ }
+
+ @Override
+ public DBObject toDbObject(final AggregationOperationContext context) {
+ return toLet(ExposedFields.synthetic(Fields.fields(getVariableNames())), context);
+ }
+
+ private String[] getVariableNames() {
+
+ String[] varNames = new String[this.vars.size()];
+ for (int i = 0; i < this.vars.size(); i++) {
+ varNames[i] = this.vars.get(i).variableName;
+ }
+
+ return varNames;
+ }
+
+ private DBObject toLet(ExposedFields exposedFields, AggregationOperationContext context) {
+
+ DBObject letExpression = new BasicDBObject();
+ DBObject mappedVars = new BasicDBObject();
+ InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext(
+ exposedFields, context);
+
+ for (ExpressionVariable var : this.vars) {
+ mappedVars.putAll(getMappedVariable(var, context));
+ }
+
+ letExpression.put("vars", mappedVars);
+ letExpression.put("in", getMappedIn(operationContext));
+
+ return new BasicDBObject("$let", letExpression);
+ }
+
+ private DBObject getMappedVariable(ExpressionVariable var, AggregationOperationContext context) {
+
+ return new BasicDBObject(var.variableName, var.expression instanceof AggregationExpression
+ ? ((AggregationExpression) var.expression).toDbObject(context) : var.expression);
+ }
+
+ private Object getMappedIn(AggregationOperationContext context) {
+ return expression.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context));
+ }
+
+ /**
+ * @author Christoph Strobl
+ */
+ public static class ExpressionVariable {
+
+ private final String variableName;
+ private final Object expression;
+
+ /**
+ * Creates new {@link ExpressionVariable}.
+ *
+ * @param variableName can be {@literal null}.
+ * @param expression can be {@literal null}.
+ */
+ private ExpressionVariable(String variableName, Object expression) {
+
+ this.variableName = variableName;
+ this.expression = expression;
+ }
+
+ /**
+ * Create a new {@link ExpressionVariable} with given name.
+ *
+ * @param variableName must not be {@literal null}.
+ * @return never {@literal null}.
+ */
+ public static ExpressionVariable newVariable(String variableName) {
+
+ Assert.notNull(variableName, "VariableName must not be null!");
+ return new ExpressionVariable(variableName, null);
+ }
+
+ /**
+ * Create a new {@link ExpressionVariable} with current name and given {@literal expression}.
+ *
+ * @param expression must not be {@literal null}.
+ * @return never {@literal null}.
+ */
+ public ExpressionVariable forExpression(AggregationExpression expression) {
+
+ Assert.notNull(expression, "Expression must not be null!");
+ return new ExpressionVariable(variableName, expression);
+ }
+
+ /**
+ * Create a new {@link ExpressionVariable} with current name and given {@literal expressionObject}.
+ *
+ * @param expressionObject must not be {@literal null}.
+ * @return never {@literal null}.
+ */
+ public ExpressionVariable forExpression(DBObject expressionObject) {
+
+ Assert.notNull(expressionObject, "Expression must not be null!");
+ return new ExpressionVariable(variableName, expressionObject);
+ }
+ }
+ }
}
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java
index f605dc628d..219a552ac1 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java
@@ -36,7 +36,7 @@
@Deprecated
public enum AggregationFunctionExpressions {
- SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD;
+ SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD, MULTIPLY;
/**
* Returns an {@link AggregationExpression} build from the current {@link Enum} name and the given parameters.
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java
index c25b567328..2071dc0b6c 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java
@@ -13,17 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.springframework.data.mongodb.core.aggregation;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
-import org.springframework.util.Assert;
/**
* {@link ExposedFieldsAggregationOperationContext} that inherits fields from its parent
* {@link AggregationOperationContext}.
*
* @author Mark Paluch
+ * @since 1.9
*/
class InheritingExposedFieldsAggregationOperationContext extends ExposedFieldsAggregationOperationContext {
@@ -40,7 +39,7 @@ public InheritingExposedFieldsAggregationOperationContext(ExposedFields exposedF
AggregationOperationContext previousContext) {
super(exposedFields, previousContext);
- Assert.notNull(previousContext, "PreviousContext must not be null!");
+
this.previousContext = previousContext;
}
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java
index 428abeb5ee..f1d67b02c4 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java
@@ -17,9 +17,11 @@
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.Collections;
import java.util.List;
+import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond;
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.IfNull;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
@@ -852,7 +854,7 @@ public ProjectionOperationBuilder differenceToArray(String array) {
/**
* Generates a {@code $setIsSubset} expression that takes array of the previously mentioned field and returns
* {@literal true} if it is a subset of the given {@literal array}.
- *
+ *
* @param array must not be {@literal null}.
* @return never {@literal null}.
* @since 1.10
@@ -1195,7 +1197,35 @@ public ProjectionOperationBuilder dateAsFormattedString(String format) {
return this.operation.and(AggregationExpressions.DateToString.dateOf(name).toString(format));
}
- /*
+ /**
+ * Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
+ * result of the expression.
+ *
+ * @param valueExpression The {@link AggregationExpression} bound to {@literal variableName}.
+ * @param variableName The variable name to be used in the {@literal in} {@link AggregationExpression}.
+ * @param in The {@link AggregationExpression} to evaluate.
+ * @return never {@literal null}.
+ * @since 1.10
+ */
+ public ProjectionOperationBuilder let(AggregationExpression valueExpression, String variableName,
+ AggregationExpression in) {
+ return this.operation.and(AggregationExpressions.Let.define(ExpressionVariable.newVariable(variableName).forExpression(valueExpression)).andApply(in));
+ }
+
+ /**
+ * Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
+ * result of the expression.
+ *
+ * @param variables The bound {@link ExpressionVariable}s.
+ * @param in The {@link AggregationExpression} to evaluate.
+ * @return never {@literal null}.
+ * @since 1.10
+ */
+ public ProjectionOperationBuilder let(Collection variables, AggregationExpression in) {
+ return this.operation.and(AggregationExpressions.Let.define(variables).andApply(in));
+ }
+
+ /*
* (non-Javadoc)
* @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
*/
diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
index 100c59e128..d89d1a782a 100644
--- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
+++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
@@ -59,6 +59,8 @@
import org.springframework.data.mongodb.core.Venue;
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond;
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ConditionalOperators;
+import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let;
+import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
import org.springframework.data.mongodb.core.aggregation.AggregationTests.CarDescriptor.Entry;
import org.springframework.data.mongodb.core.index.GeospatialIndex;
import org.springframework.data.mongodb.core.mapping.Document;
@@ -71,6 +73,7 @@
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import com.mongodb.BasicDBObject;
+import com.mongodb.BasicDBObjectBuilder;
import com.mongodb.CommandResult;
import com.mongodb.DBCollection;
import com.mongodb.DBObject;
@@ -140,6 +143,8 @@ private void cleanDb() {
mongoTemplate.dropCollection(MeterData.class);
mongoTemplate.dropCollection(LineItem.class);
mongoTemplate.dropCollection(InventoryItem.class);
+ mongoTemplate.dropCollection(Sales.class);
+ mongoTemplate.dropCollection(Sales2.class);
}
/**
@@ -651,9 +656,9 @@ public void aggregationUsingIfNullProjection() {
mongoTemplate.insert(new LineItem("idonly", null, 0));
TypedAggregation aggregation = newAggregation(LineItem.class, //
-project("id") //
- .and("caption")//
- .applyCondition(ConditionalOperators.ifNull("caption").then("unknown")),
+ project("id") //
+ .and("caption")//
+ .applyCondition(ConditionalOperators.ifNull("caption").then("unknown")),
sort(ASC, "id"));
assertThat(aggregation.toString(), is(notNullValue()));
@@ -1545,6 +1550,36 @@ public void filterShouldBeAppliedCorrectly() {
Sales.builder().id("2").items(Collections.- emptyList()).build()));
}
+ /**
+ * @see DATAMONGO-1538
+ */
+ @Test
+ public void letShouldBeAppliedCorrectly() {
+
+ assumeTrue(mongoVersion.isGreaterThanOrEqualTo(THREE_DOT_TWO));
+
+ Sales2 sales1 = Sales2.builder().id("1").price(10).tax(0.5F).applyDiscount(true).build();
+ Sales2 sales2 = Sales2.builder().id("2").price(10).tax(0.25F).applyDiscount(false).build();
+
+ mongoTemplate.insert(Arrays.asList(sales1, sales2), Sales2.class);
+
+ ExpressionVariable total = ExpressionVariable.newVariable("total")
+ .forExpression(AggregationFunctionExpressions.ADD.of(Fields.field("price"), Fields.field("tax")));
+ ExpressionVariable discounted = ExpressionVariable.newVariable("discounted")
+ .forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D));
+
+ TypedAggregation agg = Aggregation.newAggregation(Sales2.class,
+ Aggregation.project()
+ .and(Let.define(total, discounted).andApply(
+ AggregationFunctionExpressions.MULTIPLY.of(Fields.field("total"), Fields.field("discounted"))))
+ .as("finalTotal"));
+
+ AggregationResults result = mongoTemplate.aggregate(agg, DBObject.class);
+ assertThat(result.getMappedResults(),
+ contains(new BasicDBObjectBuilder().add("_id", "1").add("finalTotal", 9.450000000000001D).get(),
+ new BasicDBObjectBuilder().add("_id", "2").add("finalTotal", 10.25D).get()));
+ }
+
private void createUsersWithReferencedPersons() {
mongoTemplate.dropCollection(User.class);
@@ -1786,6 +1821,9 @@ public InventoryItem(int id, String item, String description, int qty) {
}
}
+ /**
+ * @DATAMONGO-1491
+ */
@lombok.Data
@Builder
static class Sales {
@@ -1794,6 +1832,9 @@ static class Sales {
List
- items;
}
+ /**
+ * @DATAMONGO-1491
+ */
@lombok.Data
@Builder
static class Item {
@@ -1803,4 +1844,17 @@ static class Item {
Integer quantity;
Long price;
}
+
+ /**
+ * @DATAMONGO-1538
+ */
+ @lombok.Data
+ @Builder
+ static class Sales2 {
+
+ String id;
+ Integer price;
+ Float tax;
+ boolean applyDiscount;
+ }
}
diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java
index 73a8b94a06..96d45ef1e9 100644
--- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java
+++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java
@@ -16,8 +16,10 @@
package org.springframework.data.mongodb.core.aggregation;
import static org.hamcrest.Matchers.*;
+import static org.hamcrest.core.Is.is;
import static org.junit.Assert.*;
import static org.springframework.data.mongodb.core.aggregation.Aggregation.*;
+import static org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable.*;
import static org.springframework.data.mongodb.core.aggregation.AggregationFunctionExpressions.*;
import static org.springframework.data.mongodb.core.aggregation.Fields.*;
import static org.springframework.data.mongodb.test.util.IsBsonObject.*;
@@ -28,17 +30,9 @@
import org.junit.Test;
import org.springframework.data.mongodb.core.DBObjectTestUtils;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ArithmeticOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ArrayOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.BooleanOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ComparisonOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ConditionalOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.DateOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.LiteralOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.SetOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.StringOperators;
-import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.VariableOperators;
+import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
import org.springframework.data.mongodb.core.aggregation.ProjectionOperation.ProjectionOperationBuilder;
+import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.*;
import com.mongodb.BasicDBObject;
import com.mongodb.DBObject;
@@ -1712,11 +1706,12 @@ public void shouldRenderMapAggregationExpressionOnExpression() {
@Test
public void shouldRenderIfNullConditionAggregationExpression() {
- DBObject agg = project().and(ConditionalOperators.ifNull(ArrayOperators.arrayOf("array").elementAt(1)).then("a more sophisticated value"))
+ DBObject agg = project().and(
+ ConditionalOperators.ifNull(ArrayOperators.arrayOf("array").elementAt(1)).then("a more sophisticated value"))
.as("result").toDBObject(Aggregation.DEFAULT_CONTEXT);
- assertThat(agg,
- is(JSON.parse("{ $project: { result: { $ifNull: [ { $arrayElemAt: [\"$array\", 1] }, \"a more sophisticated value\" ] } } }")));
+ assertThat(agg, is(JSON.parse(
+ "{ $project: { result: { $ifNull: [ { $arrayElemAt: [\"$array\", 1] }, \"a more sophisticated value\" ] } } }")));
}
/**
@@ -1745,6 +1740,58 @@ public void fieldReplacementIfNullShouldRenderCorrectly() {
assertThat(agg, is(JSON.parse("{ $project: { result: { $ifNull: [ \"$optional\", \"$never-null\" ] } } }")));
}
+ /**
+ * @see DATAMONGO-1538
+ */
+ @Test
+ public void shouldRenderLetExpressionCorrectly() {
+
+ DBObject agg = Aggregation.project()
+ .and(VariableOperators
+ .define(
+ newVariable("total")
+ .forExpression(AggregationFunctionExpressions.ADD.of(Fields.field("price"), Fields.field("tax"))),
+ newVariable("discounted").forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D)))
+ .andApply(AggregationFunctionExpressions.MULTIPLY.of(Fields.field("total"), Fields.field("discounted")))) //
+ .as("finalTotal").toDBObject(Aggregation.DEFAULT_CONTEXT);
+
+ assertThat(agg,
+ is(JSON.parse("{ $project:{ \"finalTotal\" : { \"$let\": {" + //
+ "\"vars\": {" + //
+ "\"total\": { \"$add\": [ \"$price\", \"$tax\" ] }," + //
+ "\"discounted\": { \"$cond\": { \"if\": \"$applyDiscount\", \"then\": 0.9, \"else\": 1.0 } }" + //
+ "}," + //
+ "\"in\": { \"$multiply\": [ \"$$total\", \"$$discounted\" ] }" + //
+ "}}}}")));
+ }
+
+ /**
+ * @see DATAMONGO-1538
+ */
+ @Test
+ public void shouldRenderLetExpressionCorrectlyWhenUsingLetOnProjectionBuilder() {
+
+ ExpressionVariable var1 = newVariable("total")
+ .forExpression(AggregationFunctionExpressions.ADD.of(Fields.field("price"), Fields.field("tax")));
+
+ ExpressionVariable var2 = newVariable("discounted")
+ .forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D));
+
+ DBObject agg = Aggregation.project().and("foo")
+ .let(Arrays.asList(var1, var2),
+ AggregationFunctionExpressions.MULTIPLY.of(Fields.field("total"), Fields.field("discounted")))
+ .as("finalTotal").toDBObject(Aggregation.DEFAULT_CONTEXT);
+
+ assertThat(agg,
+ is(JSON.parse("{ $project:{ \"finalTotal\" : { \"$let\": {" + //
+ "\"vars\": {" + //
+ "\"total\": { \"$add\": [ \"$price\", \"$tax\" ] }," + //
+ "\"discounted\": { \"$cond\": { \"if\": \"$applyDiscount\", \"then\": 0.9, \"else\": 1.0 } }" + //
+ "}," + //
+ "\"in\": { \"$multiply\": [ \"$$total\", \"$$discounted\" ] }" + //
+ "}}}}")));
+ }
+
private static DBObject exctractOperation(String field, DBObject fromProjectClause) {
return (DBObject) fromProjectClause.get(field);
}