-
-
Notifications
You must be signed in to change notification settings - Fork 138
/
Copy path test_customize.py
152 lines (126 loc) · 4.57 KB
/
test_customize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from inspect import isasyncgen
import pytest
from graphql.execution import ExecutionContext, execute, subscribe
from graphql.language import parse
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
# Python < 3.10 has no ``anext`` builtin; probing the bare name raises
# NameError there, in which case a local backport is defined instead.
try:
    anext  # noqa: B018
except NameError:  # pragma: no cover (Python < 3.10)
    # Sentinel distinguishing "no default given" from an explicit ``None``.
    _NO_DEFAULT = object()

    # noinspection PyShadowingBuiltins
    async def anext(iterator, default=_NO_DEFAULT):
        """Return the next item from an async iterator.

        Backport of the ``anext()`` builtin added in Python 3.10. Like the
        builtin, when *default* is supplied it is returned once the iterator
        is exhausted instead of letting ``StopAsyncIteration`` propagate.
        """
        try:
            return await iterator.__anext__()
        except StopAsyncIteration:
            if default is _NO_DEFAULT:
                raise
            return default
def describe_customize_execution():
    def uses_a_custom_field_resolver():
        # Resolver that ignores the source data and answers every field
        # with that field's own name.
        def custom_resolver(_source, info, **_args):
            return info.field_name

        document = parse("{ foo }")
        query_type = GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)})
        schema = GraphQLSchema(query_type)

        result = execute(schema, document, field_resolver=custom_resolver)
        assert result == ({"foo": "foo"}, None)

    def uses_a_custom_execution_context_class():
        class TestExecutionContext(ExecutionContext):
            """Execution context that doubles every value produced for a field."""

            def __init__(self, *args, **kwargs):
                # ``execute`` is expected to hand the extra ``custom_arg``
                # keyword through to this class; it is popped here so it is
                # not forwarded to the base initializer.
                assert kwargs.pop("custom_arg", None) == "baz"
                super().__init__(*args, **kwargs)

            def execute_field(
                self,
                parent_type,
                source,
                field_group,
                path,
                incremental_data_record,
                defer_map,
            ):
                value = super().execute_field(
                    parent_type,
                    source,
                    field_group,
                    path,
                    incremental_data_record,
                    defer_map,
                )
                # The field resolves to "bar", so doubling gives "barbar".
                return value * 2  # type: ignore

        def resolve_foo(*_args):
            return "bar"

        foo_field = GraphQLField(GraphQLString, resolve=resolve_foo)
        schema = GraphQLSchema(GraphQLObjectType("Query", {"foo": foo_field}))
        document = parse("{ foo }")

        result = execute(
            schema,
            document,
            execution_context_class=TestExecutionContext,
            custom_arg="baz",
        )
        assert result == ({"foo": "barbar"}, None)
def describe_customize_subscription():
    @pytest.mark.asyncio
    async def uses_a_custom_subscribe_field_resolver():
        schema = GraphQLSchema(
            query=GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)}),
            subscription=GraphQLObjectType(
                "Subscription", {"foo": GraphQLField(GraphQLString)}
            ),
        )

        class Root:
            @staticmethod
            async def custom_foo():
                yield {"foo": "FooValue"}

        # The custom subscribe resolver sources the event stream directly
        # from the root value's ``custom_foo`` async generator.
        subscription = subscribe(
            schema,
            document=parse("subscription { foo }"),
            root_value=Root(),
            subscribe_field_resolver=lambda root, _info: root.custom_foo(),
        )
        assert isasyncgen(subscription)

        assert await anext(subscription) == (
            {"foo": "FooValue"},
            None,
        )

        await subscription.aclose()

    @pytest.mark.asyncio
    async def uses_a_custom_execution_context_class():
        class TestExecutionContext(ExecutionContext):
            """Execution context that injects ``foo`` into the context value."""

            def __init__(self, *args, **kwargs):
                # ``subscribe`` is expected to hand the extra ``custom_arg``
                # keyword through to this class; it is popped here so it is
                # not forwarded to the base initializer.
                assert kwargs.pop("custom_arg", None) == "baz"
                super().__init__(*args, **kwargs)

            def build_resolve_info(self, *args, **kwargs):
                resolve_info = super().build_resolve_info(*args, **kwargs)
                # Mutates the dict passed as ``context_value`` below, which
                # the subscribe resolver then reads back.
                resolve_info.context["foo"] = "bar"
                return resolve_info

        async def generate_foo(_obj, info):
            yield info.context["foo"]

        def resolve_foo(message, _info):
            return message

        schema = GraphQLSchema(
            query=GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)}),
            subscription=GraphQLObjectType(
                "Subscription",
                {
                    "foo": GraphQLField(
                        GraphQLString,
                        resolve=resolve_foo,
                        subscribe=generate_foo,
                    )
                },
            ),
        )

        document = parse("subscription { foo }")
        subscription = subscribe(
            schema,
            document,
            context_value={},
            execution_context_class=TestExecutionContext,
            custom_arg="baz",
        )
        assert isasyncgen(subscription)

        assert await anext(subscription) == ({"foo": "bar"}, None)

        # Close the stream explicitly, as the test above does, so the async
        # generator is finalized before the event loop shuts down.
        await subscription.aclose()