Skip to content

Commit b7a18ed

Browse files
committed
Implement OneOf Input Objects via @OneOf directive
Replicates graphql/graphql-js@8cfa3de
1 parent 6e6d5be commit b7a18ed

23 files changed

+720
-13
lines changed

src/graphql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@
259259
GraphQLStreamDirective,
260260
GraphQLDeprecatedDirective,
261261
GraphQLSpecifiedByDirective,
262+
GraphQLOneOfDirective,
262263
# "Enum" of Type Kinds
263264
TypeKind,
264265
# Constant Deprecation Reason
@@ -504,6 +505,7 @@
504505
"GraphQLStreamDirective",
505506
"GraphQLDeprecatedDirective",
506507
"GraphQLSpecifiedByDirective",
508+
"GraphQLOneOfDirective",
507509
"TypeKind",
508510
"DEFAULT_DEPRECATION_REASON",
509511
"introspection_types",

src/graphql/execution/values.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,20 @@ def coerce_variable_values(
128128
continue
129129

130130
def on_input_value_error(
131-
path: list[str | int], invalid_value: Any, error: GraphQLError
131+
path: list[str | int],
132+
invalid_value: Any,
133+
error: GraphQLError,
134+
var_name: str = var_name,
135+
var_def_node: VariableDefinitionNode = var_def_node,
132136
) -> None:
133137
invalid_str = inspect(invalid_value)
134-
prefix = f"Variable '${var_name}' got invalid value {invalid_str}" # noqa: B023
138+
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
135139
if path:
136-
prefix += f" at '{var_name}{print_path_list(path)}'" # noqa: B023
140+
prefix += f" at '{var_name}{print_path_list(path)}'"
137141
on_error(
138142
GraphQLError(
139143
prefix + "; " + error.message,
140-
var_def_node, # noqa: B023
144+
var_def_node,
141145
original_error=error,
142146
)
143147
)
@@ -193,7 +197,8 @@ def get_argument_values(
193197
)
194198
raise GraphQLError(msg, value_node)
195199
continue # pragma: no cover
196-
is_null = variable_values[variable_name] is None
200+
variable_value = variable_values[variable_name]
201+
is_null = variable_value is None or variable_value is Undefined
197202

198203
if is_null and is_non_null_type(arg_type):
199204
msg = f"Argument '{name}' of non-null type '{arg_type}' must not be null."

src/graphql/type/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
GraphQLStreamDirective,
138138
GraphQLDeprecatedDirective,
139139
GraphQLSpecifiedByDirective,
140+
GraphQLOneOfDirective,
140141
# Keyword Args
141142
GraphQLDirectiveKwargs,
142143
# Constant Deprecation Reason
@@ -286,6 +287,7 @@
286287
"GraphQLStreamDirective",
287288
"GraphQLDeprecatedDirective",
288289
"GraphQLSpecifiedByDirective",
290+
"GraphQLOneOfDirective",
289291
"GraphQLDirectiveKwargs",
290292
"DEFAULT_DEPRECATION_REASON",
291293
"is_specified_scalar_type",

src/graphql/type/definition.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,7 @@ class GraphQLInputObjectTypeKwargs(GraphQLNamedTypeKwargs, total=False):
12721272

12731273
fields: GraphQLInputFieldMap
12741274
out_type: GraphQLInputFieldOutType | None
1275+
is_one_of: bool
12751276

12761277

12771278
class GraphQLInputObjectType(GraphQLNamedType):
@@ -1301,6 +1302,7 @@ class GeoPoint(GraphQLInputObjectType):
13011302

13021303
ast_node: InputObjectTypeDefinitionNode | None
13031304
extension_ast_nodes: tuple[InputObjectTypeExtensionNode, ...]
1305+
is_one_of: bool
13041306

13051307
def __init__(
13061308
self,
@@ -1311,6 +1313,7 @@ def __init__(
13111313
extensions: dict[str, Any] | None = None,
13121314
ast_node: InputObjectTypeDefinitionNode | None = None,
13131315
extension_ast_nodes: Collection[InputObjectTypeExtensionNode] | None = None,
1316+
is_one_of: bool = False,
13141317
) -> None:
13151318
super().__init__(
13161319
name=name,
@@ -1322,6 +1325,7 @@ def __init__(
13221325
self._fields = fields
13231326
if out_type is not None:
13241327
self.out_type = out_type # type: ignore
1328+
self.is_one_of = is_one_of
13251329

13261330
@staticmethod
13271331
def out_type(value: dict[str, Any]) -> Any:
@@ -1340,6 +1344,7 @@ def to_kwargs(self) -> GraphQLInputObjectTypeKwargs:
13401344
out_type=None
13411345
if self.out_type is GraphQLInputObjectType.out_type
13421346
else self.out_type,
1347+
is_one_of=self.is_one_of,
13431348
)
13441349

13451350
def __copy__(self) -> GraphQLInputObjectType: # pragma: no cover

src/graphql/type/directives.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,20 @@ def assert_directive(directive: Any) -> GraphQLDirective:
261261
description="Exposes a URL that specifies the behaviour of this scalar.",
262262
)
263263

264+
# Used to declare an Input Object as a OneOf Input Objects.
265+
GraphQLOneOfDirective = GraphQLDirective(
266+
name="oneOf",
267+
locations=[DirectiveLocation.INPUT_OBJECT],
268+
args={},
269+
description="Indicates an Input Object is a OneOf Input Object.",
270+
)
271+
264272
specified_directives: tuple[GraphQLDirective, ...] = (
265273
GraphQLIncludeDirective,
266274
GraphQLSkipDirective,
267275
GraphQLDeprecatedDirective,
268276
GraphQLSpecifiedByDirective,
277+
GraphQLOneOfDirective,
269278
)
270279
"""A tuple with all directives from the GraphQL specification"""
271280

src/graphql/type/introspection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def __new__(cls):
305305
resolve=cls.input_fields,
306306
),
307307
"ofType": GraphQLField(_Type, resolve=cls.of_type),
308+
"isOneOf": GraphQLField(GraphQLBoolean, resolve=cls.is_one_of),
308309
}
309310

310311
@staticmethod
@@ -396,6 +397,10 @@ def input_fields(type_, _info, includeDeprecated=False):
396397
def of_type(type_, _info):
397398
return getattr(type_, "of_type", None)
398399

400+
@staticmethod
401+
def is_one_of(type_, _info):
402+
return type_.is_one_of if is_input_object_type(type_) else None
403+
399404

400405
_Type: GraphQLObjectType = GraphQLObjectType(
401406
name="__Type",

src/graphql/type/validate.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
SchemaDefinitionNode,
1717
SchemaExtensionNode,
1818
)
19-
from ..pyutils import and_list, inspect
19+
from ..pyutils import Undefined, and_list, inspect
2020
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
2121
from .definition import (
2222
GraphQLEnumType,
@@ -482,6 +482,28 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
482482
],
483483
)
484484

485+
if input_obj.is_one_of:
486+
self.validate_one_of_input_object_field(input_obj, field_name, field)
487+
488+
def validate_one_of_input_object_field(
489+
self,
490+
type_: GraphQLInputObjectType,
491+
field_name: str,
492+
field: GraphQLInputField,
493+
) -> None:
494+
if is_non_null_type(field.type):
495+
self.report_error(
496+
f"OneOf input field {type_.name}.{field_name} must be nullable.",
497+
field.ast_node and field.ast_node.type,
498+
)
499+
500+
if field.default_value is not Undefined:
501+
self.report_error(
502+
f"OneOf input field {type_.name}.{field_name}"
503+
" cannot have a default value.",
504+
field.ast_node,
505+
)
506+
485507

486508
def get_operation_type_node(
487509
schema: GraphQLSchema, operation: OperationType

src/graphql/utilities/coerce_input_value.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,30 @@ def coerce_input_value(
130130
+ did_you_mean(suggestions)
131131
),
132132
)
133+
134+
if type_.is_one_of:
135+
keys = list(coerced_dict)
136+
if len(keys) != 1:
137+
on_error(
138+
path.as_list() if path else [],
139+
input_value,
140+
GraphQLError(
141+
"Exactly one key must be specified"
142+
f" for OneOf type '{type_.name}'.",
143+
),
144+
)
145+
else:
146+
key = keys[0]
147+
value = coerced_dict[key]
148+
if value is None:
149+
on_error(
150+
(path.as_list() if path else []) + [key],
151+
value,
152+
GraphQLError(
153+
f"Field '{key}' must be non-null.",
154+
),
155+
)
156+
133157
return type_.out_type(coerced_dict)
134158

135159
if is_leaf_type(type_):

src/graphql/utilities/extend_schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
GraphQLNullableType,
6666
GraphQLObjectType,
6767
GraphQLObjectTypeKwargs,
68+
GraphQLOneOfDirective,
6869
GraphQLOutputType,
6970
GraphQLScalarType,
7071
GraphQLSchema,
@@ -777,6 +778,7 @@ def build_input_object_type(
777778
fields=partial(self.build_input_field_map, all_nodes),
778779
ast_node=ast_node,
779780
extension_ast_nodes=extension_nodes,
781+
is_one_of=is_one_of(ast_node),
780782
)
781783

782784
def build_type(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType:
@@ -822,3 +824,10 @@ def get_specified_by_url(
822824

823825
specified_by_url = get_directive_values(GraphQLSpecifiedByDirective, node)
824826
return specified_by_url["url"] if specified_by_url else None
827+
828+
829+
def is_one_of(node: InputObjectTypeDefinitionNode) -> bool:
830+
"""Given an input object node, returns if the node should be OneOf."""
831+
from ..execution import get_directive_values
832+
833+
return get_directive_values(GraphQLOneOfDirective, node) is not None

src/graphql/utilities/value_from_ast.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ def value_from_ast(
118118
return Undefined
119119
coerced_obj[field.out_name or field_name] = field_value
120120

121+
if type_.is_one_of:
122+
keys = list(coerced_obj)
123+
if len(keys) != 1:
124+
return Undefined
125+
126+
if coerced_obj[keys[0]] is None:
127+
return Undefined
128+
121129
return type_.out_type(coerced_obj)
122130

123131
if is_leaf_type(type_):

src/graphql/validation/rules/values_of_correct_type.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, cast
5+
from typing import Any, Mapping, cast
66

77
from ...error import GraphQLError
88
from ...language import (
@@ -12,16 +12,20 @@
1212
FloatValueNode,
1313
IntValueNode,
1414
ListValueNode,
15+
NonNullTypeNode,
1516
NullValueNode,
1617
ObjectFieldNode,
1718
ObjectValueNode,
1819
StringValueNode,
1920
ValueNode,
21+
VariableDefinitionNode,
22+
VariableNode,
2023
VisitorAction,
2124
print_ast,
2225
)
2326
from ...pyutils import Undefined, did_you_mean, suggestion_list
2427
from ...type import (
28+
GraphQLInputObjectType,
2529
GraphQLScalarType,
2630
get_named_type,
2731
get_nullable_type,
@@ -31,7 +35,7 @@
3135
is_non_null_type,
3236
is_required_input_field,
3337
)
34-
from . import ValidationRule
38+
from . import ValidationContext, ValidationRule
3539

3640
__all__ = ["ValuesOfCorrectTypeRule"]
3741

@@ -45,6 +49,18 @@ class ValuesOfCorrectTypeRule(ValidationRule):
4549
See https://spec.graphql.org/draft/#sec-Values-of-Correct-Type
4650
"""
4751

52+
def __init__(self, context: ValidationContext) -> None:
53+
super().__init__(context)
54+
self.variable_definitions: dict[str, VariableDefinitionNode] = {}
55+
56+
def enter_operation_definition(self, *_args: Any) -> None:
57+
self.variable_definitions.clear()
58+
59+
def enter_variable_definition(
60+
self, definition: VariableDefinitionNode, *_args: Any
61+
) -> None:
62+
self.variable_definitions[definition.variable.name.value] = definition
63+
4864
def enter_list_value(self, node: ListValueNode, *_args: Any) -> VisitorAction:
4965
# Note: TypeInfo will traverse into a list's item type, so look to the parent
5066
# input type to check if it is a list.
@@ -72,6 +88,10 @@ def enter_object_value(self, node: ObjectValueNode, *_args: Any) -> VisitorActio
7288
node,
7389
)
7490
)
91+
if type_.is_one_of:
92+
validate_one_of_input_object(
93+
self.context, node, type_, field_node_map, self.variable_definitions
94+
)
7595
return None
7696

7797
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
@@ -162,3 +182,51 @@ def is_valid_value_node(self, node: ValueNode) -> None:
162182
)
163183

164184
return
185+
186+
187+
def validate_one_of_input_object(
188+
context: ValidationContext,
189+
node: ObjectValueNode,
190+
type_: GraphQLInputObjectType,
191+
field_node_map: Mapping[str, ObjectFieldNode],
192+
variable_definitions: dict[str, VariableDefinitionNode],
193+
) -> None:
194+
keys = list(field_node_map)
195+
is_not_exactly_one_filed = len(keys) != 1
196+
197+
if is_not_exactly_one_filed:
198+
context.report_error(
199+
GraphQLError(
200+
f"OneOf Input Object '{type_.name}' must specify exactly one key.",
201+
node,
202+
)
203+
)
204+
return
205+
206+
object_field_node = field_node_map.get(keys[0])
207+
value = object_field_node.value if object_field_node else None
208+
is_null_literal = not value or isinstance(value, NullValueNode)
209+
210+
if is_null_literal:
211+
context.report_error(
212+
GraphQLError(
213+
f"Field '{type_.name}.{keys[0]}' must be non-null.",
214+
node,
215+
)
216+
)
217+
return
218+
219+
is_variable = value and isinstance(value, VariableNode)
220+
if is_variable:
221+
variable_name = cast(VariableNode, value).name.value
222+
definition = variable_definitions[variable_name]
223+
is_nullable_variable = not isinstance(definition.type, NonNullTypeNode)
224+
225+
if is_nullable_variable:
226+
context.report_error(
227+
GraphQLError(
228+
f"Variable '{variable_name}' must be non-nullable"
229+
f" to be used for OneOf Input Object '{type_.name}'.",
230+
node,
231+
)
232+
)

0 commit comments

Comments
 (0)