Skip to content

Commit 7c39e7e

Browse files
committed
refactor: collect_fields to separate utility
Replicates graphql/graphql-js@dab4f44
1 parent 112145b commit 7c39e7e

File tree

4 files changed

+159
-119
lines changed

4 files changed

+159
-119
lines changed
+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from typing import Any, Dict, List, Set, Union, cast
2+
3+
from ..language import (
4+
FieldNode,
5+
FragmentDefinitionNode,
6+
FragmentSpreadNode,
7+
InlineFragmentNode,
8+
SelectionSetNode,
9+
)
10+
from ..type import (
11+
GraphQLAbstractType,
12+
GraphQLIncludeDirective,
13+
GraphQLObjectType,
14+
GraphQLSchema,
15+
GraphQLSkipDirective,
16+
is_abstract_type,
17+
)
18+
from ..utilities.type_from_ast import type_from_ast
19+
from .values import get_directive_values
20+
21+
__all__ = ["collect_fields"]
22+
23+
24+
def collect_fields(
25+
schema: GraphQLSchema,
26+
fragments: Dict[str, FragmentDefinitionNode],
27+
variable_values: Dict[str, Any],
28+
runtime_type: GraphQLObjectType,
29+
selection_set: SelectionSetNode,
30+
fields: Dict[str, List[FieldNode]],
31+
visited_fragment_names: Set[str],
32+
) -> Dict[str, List[FieldNode]]:
33+
"""Collect fields.
34+
35+
Given a selection_set, adds all of the fields in that selection to the passed in
36+
map of fields, and returns it at the end.
37+
38+
collect_fields requires the "runtime type" of an object. For a field which
39+
returns an Interface or Union type, the "runtime type" will be the actual
40+
Object type returned by that field.
41+
42+
For internal use only.
43+
"""
44+
for selection in selection_set.selections:
45+
if isinstance(selection, FieldNode):
46+
if not should_include_node(variable_values, selection):
47+
continue
48+
name = get_field_entry_key(selection)
49+
fields.setdefault(name, []).append(selection)
50+
elif isinstance(selection, InlineFragmentNode):
51+
if not should_include_node(
52+
variable_values, selection
53+
) or not does_fragment_condition_match(schema, selection, runtime_type):
54+
continue
55+
collect_fields(
56+
schema,
57+
fragments,
58+
variable_values,
59+
runtime_type,
60+
selection.selection_set,
61+
fields,
62+
visited_fragment_names,
63+
)
64+
elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else
65+
frag_name = selection.name.value
66+
if frag_name in visited_fragment_names or not should_include_node(
67+
variable_values, selection
68+
):
69+
continue
70+
visited_fragment_names.add(frag_name)
71+
fragment = fragments.get(frag_name)
72+
if not fragment or not does_fragment_condition_match(
73+
schema, fragment, runtime_type
74+
):
75+
continue
76+
collect_fields(
77+
schema,
78+
fragments,
79+
variable_values,
80+
runtime_type,
81+
fragment.selection_set,
82+
fields,
83+
visited_fragment_names,
84+
)
85+
return fields
86+
87+
88+
def should_include_node(
89+
variable_values: Dict[str, Any],
90+
node: Union[FragmentSpreadNode, FieldNode, InlineFragmentNode],
91+
) -> bool:
92+
"""Check if node should be included
93+
94+
Determines if a field should be included based on the @include and @skip
95+
directives, where @skip has higher precedence than @include.
96+
"""
97+
skip = get_directive_values(GraphQLSkipDirective, node, variable_values)
98+
if skip and skip["if"]:
99+
return False
100+
101+
include = get_directive_values(GraphQLIncludeDirective, node, variable_values)
102+
if include and not include["if"]:
103+
return False
104+
105+
return True
106+
107+
108+
def does_fragment_condition_match(
109+
schema: GraphQLSchema,
110+
fragment: Union[FragmentDefinitionNode, InlineFragmentNode],
111+
type_: GraphQLObjectType,
112+
) -> bool:
113+
"""Determine if a fragment is applicable to the given type."""
114+
type_condition_node = fragment.type_condition
115+
if not type_condition_node:
116+
return True
117+
conditional_type = type_from_ast(schema, type_condition_node)
118+
if conditional_type is type_:
119+
return True
120+
if is_abstract_type(conditional_type):
121+
return schema.is_sub_type(cast(GraphQLAbstractType, conditional_type), type_)
122+
return False
123+
124+
125+
def get_field_entry_key(node: FieldNode) -> str:
126+
"""Implements the logic to compute the key of a given field's entry"""
127+
return node.alias.value if node.alias else node.name.value

src/graphql/execution/execute.py

+15-104
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,8 @@
2121
DocumentNode,
2222
FieldNode,
2323
FragmentDefinitionNode,
24-
FragmentSpreadNode,
25-
InlineFragmentNode,
2624
OperationDefinitionNode,
2725
OperationType,
28-
SelectionSetNode,
2926
)
3027
from ..pyutils import (
3128
inspect,
@@ -37,18 +34,15 @@
3734
Undefined,
3835
)
3936
from ..utilities.get_operation_root_type import get_operation_root_type
40-
from ..utilities.type_from_ast import type_from_ast
4137
from ..type import (
4238
GraphQLAbstractType,
4339
GraphQLField,
44-
GraphQLIncludeDirective,
4540
GraphQLLeafType,
4641
GraphQLList,
4742
GraphQLNonNull,
4843
GraphQLObjectType,
4944
GraphQLOutputType,
5045
GraphQLSchema,
51-
GraphQLSkipDirective,
5246
GraphQLFieldResolver,
5347
GraphQLTypeResolver,
5448
GraphQLResolveInfo,
@@ -62,8 +56,9 @@
6256
is_non_null_type,
6357
is_object_type,
6458
)
59+
from .collect_fields import collect_fields
6560
from .middleware import MiddlewareManager
66-
from .values import get_argument_values, get_directive_values, get_variable_values
61+
from .values import get_argument_values, get_variable_values
6762

6863
__all__ = [
6964
"assert_valid_execution_arguments",
@@ -328,7 +323,15 @@ def execute_operation(
328323
Implements the "Executing operations" section of the spec.
329324
"""
330325
type_ = get_operation_root_type(self.schema, operation)
331-
fields = self.collect_fields(type_, operation.selection_set, {}, set())
326+
fields = collect_fields(
327+
self.schema,
328+
self.fragments,
329+
self.variable_values,
330+
type_,
331+
operation.selection_set,
332+
{},
333+
set(),
334+
)
332335

333336
path = None
334337

@@ -462,96 +465,6 @@ async def get_results() -> Dict[str, Any]:
462465

463466
return get_results()
464467

465-
def collect_fields(
466-
self,
467-
runtime_type: GraphQLObjectType,
468-
selection_set: SelectionSetNode,
469-
fields: Dict[str, List[FieldNode]],
470-
visited_fragment_names: Set[str],
471-
) -> Dict[str, List[FieldNode]]:
472-
"""Collect fields.
473-
474-
Given a selection_set, adds all of the fields in that selection to the passed in
475-
map of fields, and returns it at the end.
476-
477-
collect_fields requires the "runtime type" of an object. For a field which
478-
returns an Interface or Union type, the "runtime type" will be the actual
479-
Object type returned by that field.
480-
481-
For internal use only.
482-
"""
483-
for selection in selection_set.selections:
484-
if isinstance(selection, FieldNode):
485-
if not self.should_include_node(selection):
486-
continue
487-
name = get_field_entry_key(selection)
488-
fields.setdefault(name, []).append(selection)
489-
elif isinstance(selection, InlineFragmentNode):
490-
if not self.should_include_node(
491-
selection
492-
) or not self.does_fragment_condition_match(selection, runtime_type):
493-
continue
494-
self.collect_fields(
495-
runtime_type,
496-
selection.selection_set,
497-
fields,
498-
visited_fragment_names,
499-
)
500-
elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else
501-
frag_name = selection.name.value
502-
if frag_name in visited_fragment_names or not self.should_include_node(
503-
selection
504-
):
505-
continue
506-
visited_fragment_names.add(frag_name)
507-
fragment = self.fragments.get(frag_name)
508-
if not fragment or not self.does_fragment_condition_match(
509-
fragment, runtime_type
510-
):
511-
continue
512-
self.collect_fields(
513-
runtime_type, fragment.selection_set, fields, visited_fragment_names
514-
)
515-
return fields
516-
517-
def should_include_node(
518-
self, node: Union[FragmentSpreadNode, FieldNode, InlineFragmentNode]
519-
) -> bool:
520-
"""Check if node should be included
521-
522-
Determines if a field should be included based on the @include and @skip
523-
directives, where @skip has higher precedence than @include.
524-
"""
525-
skip = get_directive_values(GraphQLSkipDirective, node, self.variable_values)
526-
if skip and skip["if"]:
527-
return False
528-
529-
include = get_directive_values(
530-
GraphQLIncludeDirective, node, self.variable_values
531-
)
532-
if include and not include["if"]:
533-
return False
534-
535-
return True
536-
537-
def does_fragment_condition_match(
538-
self,
539-
fragment: Union[FragmentDefinitionNode, InlineFragmentNode],
540-
type_: GraphQLObjectType,
541-
) -> bool:
542-
"""Determine if a fragment is applicable to the given type."""
543-
type_condition_node = fragment.type_condition
544-
if not type_condition_node:
545-
return True
546-
conditional_type = type_from_ast(self.schema, type_condition_node)
547-
if conditional_type is type_:
548-
return True
549-
if is_abstract_type(conditional_type):
550-
return self.schema.is_sub_type(
551-
cast(GraphQLAbstractType, conditional_type), type_
552-
)
553-
return False
554-
555468
def build_resolve_info(
556469
self,
557470
field_def: GraphQLField,
@@ -1039,7 +952,10 @@ def collect_subfields(
1039952
for field_node in field_nodes:
1040953
selection_set = field_node.selection_set
1041954
if selection_set:
1042-
sub_field_nodes = self.collect_fields(
955+
sub_field_nodes = collect_fields(
956+
self.schema,
957+
self.fragments,
958+
self.variable_values,
1043959
return_type,
1044960
selection_set,
1045961
sub_field_nodes,
@@ -1216,11 +1132,6 @@ def get_field_def(
12161132
return parent_type.fields.get(field_name)
12171133

12181134

1219-
def get_field_entry_key(node: FieldNode) -> str:
1220-
"""Implements the logic to compute the key of a given field's entry"""
1221-
return node.alias.value if node.alias else node.name.value
1222-
1223-
12241135
def invalid_return_type_error(
12251136
return_type: GraphQLObjectType, result: Any, field_nodes: List[FieldNode]
12261137
) -> GraphQLError:

src/graphql/subscription/subscribe.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010

1111
from ..error import GraphQLError, located_error
12+
from ..execution.collect_fields import collect_fields
1213
from ..execution.execute import (
1314
assert_valid_execution_arguments,
1415
execute,
@@ -163,7 +164,15 @@ async def create_source_event_stream(
163164
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
164165
schema = context.schema
165166
type_ = get_operation_root_type(schema, context.operation)
166-
fields = context.collect_fields(type_, context.operation.selection_set, {}, set())
167+
fields = collect_fields(
168+
schema,
169+
context.fragments,
170+
context.variable_values,
171+
type_,
172+
context.operation.selection_set,
173+
{},
174+
set(),
175+
)
167176
response_name, field_nodes = next(iter(fields.items()))
168177
field_def = get_field_def(schema, type_, field_nodes[0])
169178

src/graphql/validation/rules/single_field_subscriptions.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, cast
22

33
from ...error import GraphQLError
4-
from ...execution import ExecutionContext, default_field_resolver, default_type_resolver
4+
from ...execution.collect_fields import collect_fields
55
from ...language import (
66
FieldNode,
77
FragmentDefinitionNode,
@@ -35,21 +35,14 @@ def enter_operation_definition(
3535
for definition in document.definitions
3636
if isinstance(definition, FragmentDefinitionNode)
3737
}
38-
fake_execution_context = ExecutionContext(
38+
fields = collect_fields(
3939
schema,
4040
fragments,
41-
root_value=None,
42-
context_value=None,
43-
operation=node,
44-
variable_values=variable_values,
45-
field_resolver=default_field_resolver,
46-
type_resolver=default_type_resolver,
47-
errors=[],
48-
middleware_manager=None,
49-
is_awaitable=None,
50-
)
51-
fields = fake_execution_context.collect_fields(
52-
subscription_type, node.selection_set, {}, set()
41+
variable_values,
42+
subscription_type,
43+
node.selection_set,
44+
{},
45+
set(),
5346
)
5447
if len(fields) > 1:
5548
field_selection_lists = list(fields.values())

0 commit comments

Comments
 (0)