diff --git a/pom.xml b/pom.xml index ea80a3cb74..a81aef6925 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 1.10.0.BUILD-SNAPSHOT + 1.10.0.DATAMONGO-1538-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-cross-store/pom.xml b/spring-data-mongodb-cross-store/pom.xml index ae0a5d6c8f..d4b9dad083 100644 --- a/spring-data-mongodb-cross-store/pom.xml +++ b/spring-data-mongodb-cross-store/pom.xml @@ -6,7 +6,7 @@ org.springframework.data spring-data-mongodb-parent - 1.10.0.BUILD-SNAPSHOT + 1.10.0.DATAMONGO-1538-SNAPSHOT ../pom.xml @@ -48,7 +48,7 @@ org.springframework.data spring-data-mongodb - 1.10.0.BUILD-SNAPSHOT + 1.10.0.DATAMONGO-1538-SNAPSHOT diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index 2d02722262..09099d9980 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 1.10.0.BUILD-SNAPSHOT + 1.10.0.DATAMONGO-1538-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb-log4j/pom.xml b/spring-data-mongodb-log4j/pom.xml index ee5e3336db..f78bc0261d 100644 --- a/spring-data-mongodb-log4j/pom.xml +++ b/spring-data-mongodb-log4j/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 1.10.0.BUILD-SNAPSHOT + 1.10.0.DATAMONGO-1538-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 8072d3f665..b720861be6 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -11,7 +11,7 @@ org.springframework.data spring-data-mongodb-parent - 1.10.0.BUILD-SNAPSHOT + 1.10.0.DATAMONGO-1538-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java index a2a835fb31..2b5e873747 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java @@ -17,6 +17,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; @@ -25,8 +26,8 @@ import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.OtherwiseBuilder; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.ThenBuilder; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Filter.AsBuilder; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; import org.springframework.data.mongodb.core.query.CriteriaDefinition; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -1986,6 +1987,28 @@ public static Map.AsBuilder mapItemsOf(String fieldReference) { public static Map.AsBuilder mapItemsOf(AggregationExpression expression) { return Map.itemsOf(expression); } + + /** + * Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a + * nested {@link AggregationExpression}. + * + * @param variables must not be {@literal null}. + * @return + */ + public static Let.LetBuilder define(ExpressionVariable... variables) { + return Let.define(variables); + } + + /** + * Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a + * nested {@link AggregationExpression}. + * + * @param variables must not be {@literal null}. + * @return + */ + public static Let.LetBuilder define(Collection variables) { + return Let.define(variables); + } } /** @@ -3881,31 +3904,19 @@ public static AsBuilder filter(List values) { */ @Override public DBObject toDbObject(final AggregationOperationContext context) { - - return toFilter(new ExposedFieldsAggregationOperationContext(ExposedFields.from(as), context) { - - @Override - public FieldReference getReference(Field field) { - - FieldReference ref = null; - try { - ref = context.getReference(field); - } catch (Exception e) { - // just ignore that one. - } - return ref != null ? ref : super.getReference(field); - } - }); + return toFilter(ExposedFields.from(as), context); } - private DBObject toFilter(AggregationOperationContext context) { + private DBObject toFilter(ExposedFields exposedFields, AggregationOperationContext context) { DBObject filterExpression = new BasicDBObject(); + InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( + exposedFields, context); filterExpression.putAll(context.getMappedObject(new BasicDBObject("input", getMappedInput(context)))); filterExpression.put("as", as.getTarget()); - filterExpression.putAll(context.getMappedObject(new BasicDBObject("cond", getMappedCondition(context)))); + filterExpression.putAll(context.getMappedObject(new BasicDBObject("cond", getMappedCondition(operationContext)))); return new BasicDBObject("$filter", filterExpression); } @@ -5995,27 +6006,14 @@ public Map andApply(final AggregationExpression expression) { @Override public DBObject toDbObject(final AggregationOperationContext context) { - - return toMap(new ExposedFieldsAggregationOperationContext( - ExposedFields.synthetic(Fields.fields(itemVariableName)), context) { - - @Override - public FieldReference getReference(Field field) { - - FieldReference ref = null; - try { - ref = context.getReference(field); - } catch (Exception e) { - // just ignore that one. - } - return ref != null ? ref : super.getReference(field); - } - }); + return toMap(ExposedFields.synthetic(Fields.fields(itemVariableName)), context); } - private DBObject toMap(AggregationOperationContext context) { + private DBObject toMap(ExposedFields exposedFields, AggregationOperationContext context) { BasicDBObject map = new BasicDBObject(); + InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( + exposedFields, context); BasicDBObject input; if (sourceArray instanceof Field) { @@ -6026,7 +6024,8 @@ private DBObject toMap(AggregationOperationContext context) { map.putAll(context.getMappedObject(input)); map.put("as", itemVariableName); - map.put("in", functionToApply.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context))); + map.put("in", + functionToApply.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(operationContext))); return new BasicDBObject("$map", map); } @@ -6696,4 +6695,173 @@ public Cond otherwiseValueOf(AggregationExpression expression) { } } } + + /** + * {@link AggregationExpression} for {@code $let} that binds {@link AggregationExpression} to variables for use in the + * specified {@code in} expression, and returns the result of the expression. + * + * @author Christoph Strobl + * @since 1.10 + */ + class Let implements AggregationExpression { + + private final List vars; + private final AggregationExpression expression; + + private Let(List vars, AggregationExpression expression) { + + this.vars = vars; + this.expression = expression; + } + + /** + * Start creating new {@link Let} by defining the variables for {@code $vars}. + * + * @param variables must not be {@literal null}. + * @return + */ + public static LetBuilder define(final Collection variables) { + + Assert.notNull(variables, "Variables must not be null!"); + + return new LetBuilder() { + @Override + public Let andApply(final AggregationExpression expression) { + + Assert.notNull(expression, "Expression must not be null!"); + return new Let(new ArrayList(variables), expression); + } + }; + } + + /** + * Start creating new {@link Let} by defining the variables for {@code $vars}. + * + * @param variables must not be {@literal null}. + * @return + */ + public static LetBuilder define(final ExpressionVariable... variables) { + + Assert.notNull(variables, "Variables must not be null!"); + + return new LetBuilder() { + @Override + public Let andApply(final AggregationExpression expression) { + + Assert.notNull(expression, "Expression must not be null!"); + return new Let(Arrays.asList(variables), expression); + } + }; + } + + public interface LetBuilder { + + /** + * Define the {@link AggregationExpression} to evaluate. + * + * @param expression must not be {@literal null}. + * @return + */ + Let andApply(AggregationExpression expression); + } + + @Override + public DBObject toDbObject(final AggregationOperationContext context) { + return toLet(ExposedFields.synthetic(Fields.fields(getVariableNames())), context); + } + + private String[] getVariableNames() { + + String[] varNames = new String[this.vars.size()]; + for (int i = 0; i < this.vars.size(); i++) { + varNames[i] = this.vars.get(i).variableName; + } + + return varNames; + } + + private DBObject toLet(ExposedFields exposedFields, AggregationOperationContext context) { + + DBObject letExpression = new BasicDBObject(); + DBObject mappedVars = new BasicDBObject(); + InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( + exposedFields, context); + + for (ExpressionVariable var : this.vars) { + mappedVars.putAll(getMappedVariable(var, context)); + } + + letExpression.put("vars", mappedVars); + letExpression.put("in", getMappedIn(operationContext)); + + return new BasicDBObject("$let", letExpression); + } + + private DBObject getMappedVariable(ExpressionVariable var, AggregationOperationContext context) { + + return new BasicDBObject(var.variableName, var.expression instanceof AggregationExpression + ? ((AggregationExpression) var.expression).toDbObject(context) : var.expression); + } + + private Object getMappedIn(AggregationOperationContext context) { + return expression.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context)); + } + + /** + * @author Christoph Strobl + */ + public static class ExpressionVariable { + + private final String variableName; + private final Object expression; + + /** + * Creates new {@link ExpressionVariable}. + * + * @param variableName can be {@literal null}. + * @param expression can be {@literal null}. + */ + private ExpressionVariable(String variableName, Object expression) { + + this.variableName = variableName; + this.expression = expression; + } + + /** + * Create a new {@link ExpressionVariable} with given name. + * + * @param variableName must not be {@literal null}. + * @return never {@literal null}. + */ + public static ExpressionVariable newVariable(String variableName) { + + Assert.notNull(variableName, "VariableName must not be null!"); + return new ExpressionVariable(variableName, null); + } + + /** + * Create a new {@link ExpressionVariable} with current name and given {@literal expression}. + * + * @param expression must not be {@literal null}. + * @return never {@literal null}. + */ + public ExpressionVariable forExpression(AggregationExpression expression) { + + Assert.notNull(expression, "Expression must not be null!"); + return new ExpressionVariable(variableName, expression); + } + + /** + * Create a new {@link ExpressionVariable} with current name and given {@literal expressionObject}. + * + * @param expressionObject must not be {@literal null}. + * @return never {@literal null}. + */ + public ExpressionVariable forExpression(DBObject expressionObject) { + + Assert.notNull(expressionObject, "Expression must not be null!"); + return new ExpressionVariable(variableName, expressionObject); + } + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java index f605dc628d..219a552ac1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java @@ -36,7 +36,7 @@ @Deprecated public enum AggregationFunctionExpressions { - SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD; + SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD, MULTIPLY; /** * Returns an {@link AggregationExpression} build from the current {@link Enum} name and the given parameters. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java index c25b567328..2071dc0b6c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java @@ -13,17 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.data.mongodb.core.aggregation; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; -import org.springframework.util.Assert; /** * {@link ExposedFieldsAggregationOperationContext} that inherits fields from its parent * {@link AggregationOperationContext}. * * @author Mark Paluch + * @since 1.9 */ class InheritingExposedFieldsAggregationOperationContext extends ExposedFieldsAggregationOperationContext { @@ -40,7 +39,7 @@ public InheritingExposedFieldsAggregationOperationContext(ExposedFields exposedF AggregationOperationContext previousContext) { super(exposedFields, previousContext); - Assert.notNull(previousContext, "PreviousContext must not be null!"); + this.previousContext = previousContext; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java index 428abeb5ee..f1d67b02c4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java @@ -17,9 +17,11 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.IfNull; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; @@ -852,7 +854,7 @@ public ProjectionOperationBuilder differenceToArray(String array) { /** * Generates a {@code $setIsSubset} expression that takes array of the previously mentioned field and returns * {@literal true} if it is a subset of the given {@literal array}. - * + * * @param array must not be {@literal null}. * @return never {@literal null}. * @since 1.10 @@ -1195,7 +1197,35 @@ public ProjectionOperationBuilder dateAsFormattedString(String format) { return this.operation.and(AggregationExpressions.DateToString.dateOf(name).toString(format)); } - /* + /** + * Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the + * result of the expression. + * + * @param valueExpression The {@link AggregationExpression} bound to {@literal variableName}. + * @param variableName The variable name to be used in the {@literal in} {@link AggregationExpression}. + * @param in The {@link AggregationExpression} to evaluate. + * @return never {@literal null}. + * @since 1.10 + */ + public ProjectionOperationBuilder let(AggregationExpression valueExpression, String variableName, + AggregationExpression in) { + return this.operation.and(AggregationExpressions.Let.define(ExpressionVariable.newVariable(variableName).forExpression(valueExpression)).andApply(in)); + } + + /** + * Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the + * result of the expression. + * + * @param variables The bound {@link ExpressionVariable}s. + * @param in The {@link AggregationExpression} to evaluate. + * @return never {@literal null}. + * @since 1.10 + */ + public ProjectionOperationBuilder let(Collection variables, AggregationExpression in) { + return this.operation.and(AggregationExpressions.Let.define(variables).andApply(in)); + } + + /* * (non-Javadoc) * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext) */ diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 100c59e128..d89d1a782a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -59,6 +59,8 @@ import org.springframework.data.mongodb.core.Venue; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ConditionalOperators; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable; import org.springframework.data.mongodb.core.aggregation.AggregationTests.CarDescriptor.Entry; import org.springframework.data.mongodb.core.index.GeospatialIndex; import org.springframework.data.mongodb.core.mapping.Document; @@ -71,6 +73,7 @@ import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import com.mongodb.BasicDBObject; +import com.mongodb.BasicDBObjectBuilder; import com.mongodb.CommandResult; import com.mongodb.DBCollection; import com.mongodb.DBObject; @@ -140,6 +143,8 @@ private void cleanDb() { mongoTemplate.dropCollection(MeterData.class); mongoTemplate.dropCollection(LineItem.class); mongoTemplate.dropCollection(InventoryItem.class); + mongoTemplate.dropCollection(Sales.class); + mongoTemplate.dropCollection(Sales2.class); } /** @@ -651,9 +656,9 @@ public void aggregationUsingIfNullProjection() { mongoTemplate.insert(new LineItem("idonly", null, 0)); TypedAggregation aggregation = newAggregation(LineItem.class, // -project("id") // - .and("caption")// - .applyCondition(ConditionalOperators.ifNull("caption").then("unknown")), + project("id") // + .and("caption")// + .applyCondition(ConditionalOperators.ifNull("caption").then("unknown")), sort(ASC, "id")); assertThat(aggregation.toString(), is(notNullValue())); @@ -1545,6 +1550,36 @@ public void filterShouldBeAppliedCorrectly() { Sales.builder().id("2").items(Collections. emptyList()).build())); } + /** + * @see DATAMONGO-1538 + */ + @Test + public void letShouldBeAppliedCorrectly() { + + assumeTrue(mongoVersion.isGreaterThanOrEqualTo(THREE_DOT_TWO)); + + Sales2 sales1 = Sales2.builder().id("1").price(10).tax(0.5F).applyDiscount(true).build(); + Sales2 sales2 = Sales2.builder().id("2").price(10).tax(0.25F).applyDiscount(false).build(); + + mongoTemplate.insert(Arrays.asList(sales1, sales2), Sales2.class); + + ExpressionVariable total = ExpressionVariable.newVariable("total") + .forExpression(AggregationFunctionExpressions.ADD.of(Fields.field("price"), Fields.field("tax"))); + ExpressionVariable discounted = ExpressionVariable.newVariable("discounted") + .forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D)); + + TypedAggregation agg = Aggregation.newAggregation(Sales2.class, + Aggregation.project() + .and(Let.define(total, discounted).andApply( + AggregationFunctionExpressions.MULTIPLY.of(Fields.field("total"), Fields.field("discounted")))) + .as("finalTotal")); + + AggregationResults result = mongoTemplate.aggregate(agg, DBObject.class); + assertThat(result.getMappedResults(), + contains(new BasicDBObjectBuilder().add("_id", "1").add("finalTotal", 9.450000000000001D).get(), + new BasicDBObjectBuilder().add("_id", "2").add("finalTotal", 10.25D).get())); + } + private void createUsersWithReferencedPersons() { mongoTemplate.dropCollection(User.class); @@ -1786,6 +1821,9 @@ public InventoryItem(int id, String item, String description, int qty) { } } + /** + * @DATAMONGO-1491 + */ @lombok.Data @Builder static class Sales { @@ -1794,6 +1832,9 @@ static class Sales { List items; } + /** + * @DATAMONGO-1491 + */ @lombok.Data @Builder static class Item { @@ -1803,4 +1844,17 @@ static class Item { Integer quantity; Long price; } + + /** + * @DATAMONGO-1538 + */ + @lombok.Data + @Builder + static class Sales2 { + + String id; + Integer price; + Float tax; + boolean applyDiscount; + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java index 73a8b94a06..96d45ef1e9 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java @@ -16,8 +16,10 @@ package org.springframework.data.mongodb.core.aggregation; import static org.hamcrest.Matchers.*; +import static org.hamcrest.core.Is.is; import static org.junit.Assert.*; import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import static org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable.*; import static org.springframework.data.mongodb.core.aggregation.AggregationFunctionExpressions.*; import static org.springframework.data.mongodb.core.aggregation.Fields.*; import static org.springframework.data.mongodb.test.util.IsBsonObject.*; @@ -28,17 +30,9 @@ import org.junit.Test; import org.springframework.data.mongodb.core.DBObjectTestUtils; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ArithmeticOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ArrayOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.BooleanOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ComparisonOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ConditionalOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.DateOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.LiteralOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.SetOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.StringOperators; -import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.VariableOperators; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable; import org.springframework.data.mongodb.core.aggregation.ProjectionOperation.ProjectionOperationBuilder; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.*; import com.mongodb.BasicDBObject; import com.mongodb.DBObject; @@ -1712,11 +1706,12 @@ public void shouldRenderMapAggregationExpressionOnExpression() { @Test public void shouldRenderIfNullConditionAggregationExpression() { - DBObject agg = project().and(ConditionalOperators.ifNull(ArrayOperators.arrayOf("array").elementAt(1)).then("a more sophisticated value")) + DBObject agg = project().and( + ConditionalOperators.ifNull(ArrayOperators.arrayOf("array").elementAt(1)).then("a more sophisticated value")) .as("result").toDBObject(Aggregation.DEFAULT_CONTEXT); - assertThat(agg, - is(JSON.parse("{ $project: { result: { $ifNull: [ { $arrayElemAt: [\"$array\", 1] }, \"a more sophisticated value\" ] } } }"))); + assertThat(agg, is(JSON.parse( + "{ $project: { result: { $ifNull: [ { $arrayElemAt: [\"$array\", 1] }, \"a more sophisticated value\" ] } } }"))); } /** @@ -1745,6 +1740,58 @@ public void fieldReplacementIfNullShouldRenderCorrectly() { assertThat(agg, is(JSON.parse("{ $project: { result: { $ifNull: [ \"$optional\", \"$never-null\" ] } } }"))); } + /** + * @see DATAMONGO-1538 + */ + @Test + public void shouldRenderLetExpressionCorrectly() { + + DBObject agg = Aggregation.project() + .and(VariableOperators + .define( + newVariable("total") + .forExpression(AggregationFunctionExpressions.ADD.of(Fields.field("price"), Fields.field("tax"))), + newVariable("discounted").forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D))) + .andApply(AggregationFunctionExpressions.MULTIPLY.of(Fields.field("total"), Fields.field("discounted")))) // + .as("finalTotal").toDBObject(Aggregation.DEFAULT_CONTEXT); + + assertThat(agg, + is(JSON.parse("{ $project:{ \"finalTotal\" : { \"$let\": {" + // + "\"vars\": {" + // + "\"total\": { \"$add\": [ \"$price\", \"$tax\" ] }," + // + "\"discounted\": { \"$cond\": { \"if\": \"$applyDiscount\", \"then\": 0.9, \"else\": 1.0 } }" + // + "}," + // + "\"in\": { \"$multiply\": [ \"$$total\", \"$$discounted\" ] }" + // + "}}}}"))); + } + + /** + * @see DATAMONGO-1538 + */ + @Test + public void shouldRenderLetExpressionCorrectlyWhenUsingLetOnProjectionBuilder() { + + ExpressionVariable var1 = newVariable("total") + .forExpression(AggregationFunctionExpressions.ADD.of(Fields.field("price"), Fields.field("tax"))); + + ExpressionVariable var2 = newVariable("discounted") + .forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D)); + + DBObject agg = Aggregation.project().and("foo") + .let(Arrays.asList(var1, var2), + AggregationFunctionExpressions.MULTIPLY.of(Fields.field("total"), Fields.field("discounted"))) + .as("finalTotal").toDBObject(Aggregation.DEFAULT_CONTEXT); + + assertThat(agg, + is(JSON.parse("{ $project:{ \"finalTotal\" : { \"$let\": {" + // + "\"vars\": {" + // + "\"total\": { \"$add\": [ \"$price\", \"$tax\" ] }," + // + "\"discounted\": { \"$cond\": { \"if\": \"$applyDiscount\", \"then\": 0.9, \"else\": 1.0 } }" + // + "}," + // + "\"in\": { \"$multiply\": [ \"$$total\", \"$$discounted\" ] }" + // + "}}}}"))); + } + private static DBObject exctractOperation(String field, DBObject fromProjectClause) { return (DBObject) fromProjectClause.get(field); }