Skip to content

Commit 68db0d4

Browse files
christophstroblmp911de
authored andcommitted
DATAMONGO-1538 - Add support for $let to aggregation.
We now support $let in aggregation $project stage. ExpressionVariable total = newExpressionVariable("total").forExpression(ADD.of(field("price"), field("tax"))); ExpressionVariable discounted = newExpressionVariable("discounted").forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D)); newAggregation(Sales.class, project() .and(define(total, discounted) .andApply(MULTIPLY.of(field("total"), field("discounted")))) .as("finalTotal")); Original pull request: #417.
1 parent c9dfeea commit 68db0d4

File tree

5 files changed

+353
-19
lines changed

5 files changed

+353
-19
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.LinkedHashMap;
2223
import java.util.List;
@@ -25,6 +26,7 @@
2526
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.OtherwiseBuilder;
2627
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.ThenBuilder;
2728
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Filter.AsBuilder;
29+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
2830
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
2931
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
3032
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
@@ -1986,6 +1988,28 @@ public static Map.AsBuilder mapItemsOf(String fieldReference) {
19861988
public static Map.AsBuilder mapItemsOf(AggregationExpression expression) {
19871989
return Map.itemsOf(expression);
19881990
}
1991+
1992+
/**
1993+
* Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a
1994+
* nested {@link AggregationExpression}.
1995+
*
1996+
* @param variables must not be {@literal null}.
1997+
* @return
1998+
*/
1999+
public static Let.LetBuilder define(ExpressionVariable... variables) {
2000+
return Let.define(variables);
2001+
}
2002+
2003+
/**
2004+
* Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a
2005+
* nested {@link AggregationExpression}.
2006+
*
2007+
* @param variables must not be {@literal null}.
2008+
* @return
2009+
*/
2010+
public static Let.LetBuilder define(Collection<ExpressionVariable> variables) {
2011+
return Let.define(variables);
2012+
}
19892013
}
19902014

19912015
/**
@@ -6696,4 +6720,185 @@ public Cond otherwiseValueOf(AggregationExpression expression) {
66966720
}
66976721
}
66986722
}
6723+
6724+
/**
6725+
* {@link AggregationExpression} for {@code $let} that binds {@link AggregationExpression} to variables for use in the
6726+
* specified {@code in} expression, and returns the result of the expression.
6727+
*
6728+
* @author Christoph Strobl
6729+
* @since 1.10
6730+
*/
6731+
class Let implements AggregationExpression {
6732+
6733+
private final List<ExpressionVariable> vars;
6734+
private final AggregationExpression expression;
6735+
6736+
private Let(List<ExpressionVariable> vars, AggregationExpression expression) {
6737+
6738+
this.vars = vars;
6739+
this.expression = expression;
6740+
}
6741+
6742+
/**
6743+
* Start creating new {@link Let} by defining the variables for {@code $vars}.
6744+
*
6745+
* @param variables must not be {@literal null}.
6746+
* @return
6747+
*/
6748+
public static LetBuilder define(final Collection<ExpressionVariable> variables) {
6749+
6750+
Assert.notNull(variables, "Variables must not be null!");
6751+
6752+
return new LetBuilder() {
6753+
@Override
6754+
public Let andApply(final AggregationExpression expression) {
6755+
6756+
Assert.notNull(expression, "Expression must not be null!");
6757+
return new Let(new ArrayList<ExpressionVariable>(variables), expression);
6758+
}
6759+
};
6760+
}
6761+
6762+
/**
6763+
* Start creating new {@link Let} by defining the variables for {@code $vars}.
6764+
*
6765+
* @param variables must not be {@literal null}.
6766+
* @return
6767+
*/
6768+
public static LetBuilder define(final ExpressionVariable... variables) {
6769+
6770+
Assert.notNull(variables, "Variables must not be null!");
6771+
6772+
return new LetBuilder() {
6773+
@Override
6774+
public Let andApply(final AggregationExpression expression) {
6775+
6776+
Assert.notNull(expression, "Expression must not be null!");
6777+
return new Let(Arrays.asList(variables), expression);
6778+
}
6779+
};
6780+
}
6781+
6782+
public interface LetBuilder {
6783+
6784+
/**
6785+
* Define the {@link AggregationExpression} to evaluate.
6786+
*
6787+
* @param expression must not be {@literal null}.
6788+
* @return
6789+
*/
6790+
Let andApply(AggregationExpression expression);
6791+
}
6792+
6793+
@Override
6794+
public DBObject toDbObject(final AggregationOperationContext context) {
6795+
6796+
return toLet(new ExposedFieldsAggregationOperationContext(
6797+
ExposedFields.synthetic(Fields.fields(getVariableNames())), context) {
6798+
6799+
@Override
6800+
public FieldReference getReference(Field field) {
6801+
6802+
FieldReference ref = null;
6803+
try {
6804+
ref = context.getReference(field);
6805+
} catch (Exception e) {
6806+
// just ignore that one.
6807+
}
6808+
return ref != null ? ref : super.getReference(field);
6809+
}
6810+
});
6811+
}
6812+
6813+
private String[] getVariableNames() {
6814+
6815+
String[] varNames = new String[this.vars.size()];
6816+
for (int i = 0; i < this.vars.size(); i++) {
6817+
varNames[i] = this.vars.get(i).variableName;
6818+
}
6819+
return varNames;
6820+
}
6821+
6822+
private DBObject toLet(AggregationOperationContext context) {
6823+
6824+
DBObject letExpression = new BasicDBObject();
6825+
6826+
DBObject mappedVars = new BasicDBObject();
6827+
for (ExpressionVariable var : this.vars) {
6828+
mappedVars.putAll(getMappedVariable(var, context));
6829+
}
6830+
6831+
letExpression.put("vars", mappedVars);
6832+
letExpression.put("in", getMappedIn(context));
6833+
6834+
return new BasicDBObject("$let", letExpression);
6835+
}
6836+
6837+
private DBObject getMappedVariable(ExpressionVariable var, AggregationOperationContext context) {
6838+
6839+
return new BasicDBObject(var.variableName, var.expression instanceof AggregationExpression
6840+
? ((AggregationExpression) var.expression).toDbObject(context) : var.expression);
6841+
}
6842+
6843+
private Object getMappedIn(AggregationOperationContext context) {
6844+
return expression.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context));
6845+
}
6846+
6847+
/**
6848+
* @author Christoph Strobl
6849+
*/
6850+
public static class ExpressionVariable {
6851+
6852+
private final String variableName;
6853+
private final Object expression;
6854+
6855+
/**
6856+
* Creates new {@link ExpressionVariable}.
6857+
*
6858+
* @param variableName can be {@literal null}.
6859+
* @param expression can be {@literal null}.
6860+
*/
6861+
private ExpressionVariable(String variableName, Object expression) {
6862+
6863+
this.variableName = variableName;
6864+
this.expression = expression;
6865+
}
6866+
6867+
/**
6868+
* Create a new {@link ExpressionVariable} with given name.
6869+
*
6870+
* @param variableName must not be {@literal null}.
6871+
* @return never {@literal null}.
6872+
*/
6873+
public static ExpressionVariable newVariable(String variableName) {
6874+
6875+
Assert.notNull(variableName, "VariableName must not be null!");
6876+
return new ExpressionVariable(variableName, null);
6877+
}
6878+
6879+
/**
6880+
* Create a new {@link ExpressionVariable} with current name and given {@literal expression}.
6881+
*
6882+
* @param expression must not be {@literal null}.
6883+
* @return never {@literal null}.
6884+
*/
6885+
public ExpressionVariable forExpression(AggregationExpression expression) {
6886+
6887+
Assert.notNull(expression, "Expression must not be null!");
6888+
return new ExpressionVariable(variableName, expression);
6889+
}
6890+
6891+
/**
6892+
* Create a new {@link ExpressionVariable} with current name and given {@literal expressionObject}.
6893+
*
6894+
* @param expressionObject must not be {@literal null}.
6895+
* @return never {@literal null}.
6896+
*/
6897+
public ExpressionVariable forExpression(DBObject expressionObject) {
6898+
6899+
Assert.notNull(expressionObject, "Expression must not be null!");
6900+
return new ExpressionVariable(variableName, expressionObject);
6901+
}
6902+
}
6903+
}
66996904
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
@Deprecated
3737
public enum AggregationFunctionExpressions {
3838

39-
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD;
39+
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD, MULTIPLY;
4040

4141
/**
4242
* Returns an {@link AggregationExpression} build from the current {@link Enum} name and the given parameters.

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.List;
2223

24+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
2325
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond;
2426
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.IfNull;
2527
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
@@ -852,7 +854,7 @@ public ProjectionOperationBuilder differenceToArray(String array) {
852854
/**
853855
* Generates a {@code $setIsSubset} expression that takes array of the previously mentioned field and returns
854856
* {@literal true} if it is a subset of the given {@literal array}.
855-
*
857+
*
856858
* @param array must not be {@literal null}.
857859
* @return never {@literal null}.
858860
* @since 1.10
@@ -1195,7 +1197,35 @@ public ProjectionOperationBuilder dateAsFormattedString(String format) {
11951197
return this.operation.and(AggregationExpressions.DateToString.dateOf(name).toString(format));
11961198
}
11971199

1198-
/*
1200+
/**
1201+
* Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
1202+
* result of the expression.
1203+
*
1204+
* @param valueExpression The {@link AggregationExpression} bound to {@literal variableName}.
1205+
* @param variableName The variable name to be used in the {@literal in} {@link AggregationExpression}.
1206+
* @param in The {@link AggregationExpression} to evaluate.
1207+
* @return never {@literal null}.
1208+
* @since 1.10
1209+
*/
1210+
public ProjectionOperationBuilder let(AggregationExpression valueExpression, String variableName,
1211+
AggregationExpression in) {
1212+
return this.operation.and(AggregationExpressions.Let.define(ExpressionVariable.newVariable(variableName).forExpression(valueExpression)).andApply(in));
1213+
}
1214+
1215+
/**
1216+
* Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
1217+
* result of the expression.
1218+
*
1219+
* @param variables The bound {@link ExpressionVariable}s.
1220+
* @param in The {@link AggregationExpression} to evaluate.
1221+
* @return never {@literal null}.
1222+
* @since 1.10
1223+
*/
1224+
public ProjectionOperationBuilder let(Collection<ExpressionVariable> variables, AggregationExpression in) {
1225+
return this.operation.and(AggregationExpressions.Let.define(variables).andApply(in));
1226+
}
1227+
1228+
/*
11991229
* (non-Javadoc)
12001230
* @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
12011231
*/

0 commit comments

Comments
 (0)