17
17
18
18
import java .util .ArrayList ;
19
19
import java .util .Arrays ;
20
+ import java .util .HashSet ;
20
21
import java .util .List ;
22
+ import java .util .Set ;
21
23
24
+ import org .bson .Document ;
22
25
import org .springframework .data .mongodb .core .aggregation .ExposedFields .ExposedField ;
23
26
import org .springframework .data .mongodb .core .aggregation .FieldsExposingAggregationOperation .InheritsFieldsAggregationOperation ;
24
27
import org .springframework .data .mongodb .core .query .CriteriaDefinition ;
25
28
import org .springframework .util .Assert ;
26
-
27
- import org .bson .Document ;
29
+ import org .springframework .util .ClassUtils ;
28
30
29
31
/**
30
- * Encapsulates the aggregation framework {@code $graphLookup}-operation.
31
- * <p>
32
+ * Encapsulates the aggregation framework {@code $graphLookup}-operation. <br />
32
33
* Performs a recursive search on a collection, with options for restricting the search by recursion depth and query
33
- * filter.
34
- * <p>
34
+ * filter. <br />
35
35
* We recommend to use the static factory method {@link Aggregation#graphLookup(String)} instead of creating instances
36
36
* of this class directly.
37
37
*
38
- * @see http://docs.mongodb.org/manual/reference/aggregation/graphLookup/
38
+ * @see <a href=
39
+ * "http://docs.mongodb.org/manual/reference/aggregation/graphLookup/">http://docs.mongodb.org/manual/reference/aggregation/graphLookup/</a>
39
40
* @author Mark Paluch
41
+ * @author Christoph Strobl
40
42
* @since 1.10
41
43
*/
42
44
public class GraphLookupOperation implements InheritsFieldsAggregationOperation {
43
45
46
+ private static final Set <Class <?>> ALLOWED_START_TYPES = new HashSet <Class <?>>(
47
+ Arrays .<Class <?>> asList (AggregationExpression .class , String .class , Field .class , Document .class ));
48
+
44
49
private final String from ;
45
50
private final List <Object > startWith ;
46
51
private final Field connectFrom ;
@@ -65,15 +70,15 @@ private GraphLookupOperation(String from, List<Object> startWith, Field connectF
65
70
66
71
/**
67
72
* Creates a new {@link FromBuilder} to build {@link GraphLookupOperation}.
68
- *
73
+ *
69
74
* @return a new {@link FromBuilder}.
70
75
*/
71
76
public static FromBuilder builder () {
72
77
return new GraphLookupOperationFromBuilder ();
73
78
}
74
79
75
80
/* (non-Javadoc)
76
- * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject (org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
81
+ * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDocument (org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
77
82
*/
78
83
@ Override
79
84
public Document toDocument (AggregationOperationContext context ) {
@@ -82,24 +87,20 @@ public Document toDocument(AggregationOperationContext context) {
82
87
83
88
graphLookup .put ("from" , from );
84
89
85
- List <Object > list = new ArrayList <>(startWith .size ());
90
+ List <Object > mappedStartWith = new ArrayList <Object >(startWith .size ());
86
91
87
92
for (Object startWithElement : startWith ) {
88
93
89
94
if (startWithElement instanceof AggregationExpression ) {
90
- list .add (((AggregationExpression ) startWithElement ).toDocument (context ));
91
- }
92
-
93
- if ( startWithElement instanceof Field ) {
94
- list .add (context . getReference (( Field ) startWithElement ). toString () );
95
+ mappedStartWith .add (((AggregationExpression ) startWithElement ).toDocument (context ));
96
+ } else if ( startWithElement instanceof Field ) {
97
+ mappedStartWith . add ( context . getReference (( Field ) startWithElement ). toString ());
98
+ } else {
99
+ mappedStartWith .add (startWithElement );
95
100
}
96
101
}
97
102
98
- if (list .size () == 1 ) {
99
- graphLookup .put ("startWith" , list .get (0 ));
100
- } else {
101
- graphLookup .put ("startWith" , list );
102
- }
103
+ graphLookup .put ("startWith" , mappedStartWith .size () == 1 ? mappedStartWith .iterator ().next () : mappedStartWith );
103
104
104
105
graphLookup .put ("connectFromField" , connectFrom .getName ());
105
106
graphLookup .put ("connectToField" , connectTo .getName ());
@@ -145,6 +146,7 @@ public interface FromBuilder {
145
146
146
147
/**
147
148
* @author Mark Paluch
149
+ * @author Christoph Strobl
148
150
*/
149
151
public interface StartWithBuilder {
150
152
@@ -163,6 +165,16 @@ public interface StartWithBuilder {
163
165
* @return
164
166
*/
165
167
ConnectFromBuilder startWith (AggregationExpression ... expressions );
168
+
169
+ /**
170
+ * Set the startWith as either {@literal fieldReferences}, {@link Fields}, {@link Document} or
171
+ * {@link AggregationExpression} to apply the {@code $graphLookup} to.
172
+ *
173
+ * @param expressions must not be {@literal null}.
174
+ * @return
175
+ * @throws IllegalArgumentException
176
+ */
177
+ ConnectFromBuilder startWith (Object ... expressions );
166
178
}
167
179
168
180
/**
@@ -196,7 +208,7 @@ public interface ConnectToBuilder {
196
208
/**
197
209
* Builder to build the initial {@link GraphLookupOperationBuilder} that configures the initial mandatory set of
198
210
* {@link GraphLookupOperation} properties.
199
- *
211
+ *
200
212
* @author Mark Paluch
201
213
*/
202
214
static final class GraphLookupOperationFromBuilder
@@ -215,7 +227,6 @@ public StartWithBuilder from(String collectionName) {
215
227
Assert .hasText (collectionName , "CollectionName must not be null or empty!" );
216
228
217
229
this .from = collectionName ;
218
-
219
230
return this ;
220
231
}
221
232
@@ -235,7 +246,6 @@ public ConnectFromBuilder startWith(String... fieldReferences) {
235
246
}
236
247
237
248
this .startWith = fields ;
238
-
239
249
return this ;
240
250
}
241
251
@@ -249,10 +259,50 @@ public ConnectFromBuilder startWith(AggregationExpression... expressions) {
249
259
Assert .noNullElements (expressions , "AggregationExpressions must not contain null elements!" );
250
260
251
261
this .startWith = Arrays .asList (expressions );
262
+ return this ;
263
+ }
264
+
265
+ @ Override
266
+ public ConnectFromBuilder startWith (Object ... expressions ) {
252
267
268
+ Assert .notNull (expressions , "Expressions must not be null!" );
269
+ Assert .noNullElements (expressions , "Expressions must not contain null elements!" );
270
+
271
+ this .startWith = verifyAndPotentiallyTransformStartsWithTypes (expressions );
253
272
return this ;
254
273
}
255
274
275
+ private List <Object > verifyAndPotentiallyTransformStartsWithTypes (Object ... expressions ) {
276
+
277
+ List <Object > expressionsToUse = new ArrayList <Object >(expressions .length );
278
+
279
+ for (Object expression : expressions ) {
280
+
281
+ assertStartWithType (expression );
282
+
283
+ if (expression instanceof String ) {
284
+ expressionsToUse .add (Fields .field ((String ) expression ));
285
+ } else {
286
+ expressionsToUse .add (expression );
287
+ }
288
+
289
+ }
290
+ return expressionsToUse ;
291
+ }
292
+
293
+ private void assertStartWithType (Object expression ) {
294
+
295
+ for (Class <?> type : ALLOWED_START_TYPES ) {
296
+
297
+ if (ClassUtils .isAssignable (type , expression .getClass ())) {
298
+ return ;
299
+ }
300
+ }
301
+
302
+ throw new IllegalArgumentException (
303
+ String .format ("Expression must be any of %s but was %s" , ALLOWED_START_TYPES , expression .getClass ()));
304
+ }
305
+
256
306
/* (non-Javadoc)
257
307
* @see org.springframework.data.mongodb.core.aggregation.GraphLookupOperation.ConnectFromBuilder#connectFrom(java.lang.String)
258
308
*/
@@ -262,7 +312,6 @@ public ConnectToBuilder connectFrom(String fieldName) {
262
312
Assert .hasText (fieldName , "ConnectFrom must not be null or empty!" );
263
313
264
314
this .connectFrom = fieldName ;
265
-
266
315
return this ;
267
316
}
268
317
@@ -301,8 +350,8 @@ protected GraphLookupOperationBuilder(String from, List<? extends Object> startW
301
350
}
302
351
303
352
/**
304
- * Limit the number of recursions.
305
- *
353
+ * Optionally limit the number of recursions.
354
+ *
306
355
* @param numberOfRecursions must be greater or equal to zero.
307
356
* @return
308
357
*/
@@ -311,13 +360,12 @@ public GraphLookupOperationBuilder maxDepth(long numberOfRecursions) {
311
360
Assert .isTrue (numberOfRecursions >= 0 , "Max depth must be >= 0!" );
312
361
313
362
this .maxDepth = numberOfRecursions ;
314
-
315
363
return this ;
316
364
}
317
365
318
366
/**
319
- * Add a depth field {@literal fieldName} to each traversed document in the search path.
320
- *
367
+ * Optionally add a depth field {@literal fieldName} to each traversed document in the search path.
368
+ *
321
369
* @param fieldName must not be {@literal null} or empty.
322
370
* @return
323
371
*/
@@ -326,13 +374,12 @@ public GraphLookupOperationBuilder depthField(String fieldName) {
326
374
Assert .hasText (fieldName , "Depth field name must not be null or empty!" );
327
375
328
376
this .depthField = Fields .field (fieldName );
329
-
330
377
return this ;
331
378
}
332
379
333
380
/**
334
- * Add a query specifying conditions to the recursive search.
335
- *
381
+ * Optionally add a query specifying conditions to the recursive search.
382
+ *
336
383
* @param criteriaDefinition must not be {@literal null}.
337
384
* @return
338
385
*/
@@ -341,14 +388,13 @@ public GraphLookupOperationBuilder restrict(CriteriaDefinition criteriaDefinitio
341
388
Assert .notNull (criteriaDefinition , "CriteriaDefinition must not be null!" );
342
389
343
390
this .restrictSearchWithMatch = criteriaDefinition ;
344
-
345
391
return this ;
346
392
}
347
393
348
394
/**
349
395
* Set the name of the array field added to each output document and return the final {@link GraphLookupOperation}.
350
396
* Contains the documents traversed in the {@literal $graphLookup} stage to reach the document.
351
- *
397
+ *
352
398
* @param fieldName must not be {@literal null} or empty.
353
399
* @return the final {@link GraphLookupOperation}.
354
400
*/
0 commit comments