Skip to content

Commit 87551f5

Browse files
committed
Support returning async iterables from resolver functions
Replicates graphql/graphql-js@59c87c3
1 parent dd08366 commit 87551f5

File tree

3 files changed

+299
-56
lines changed

3 files changed

+299
-56
lines changed

src/graphql/execution/execute.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@
6464
from .values import get_argument_values, get_variable_values
6565

6666

67+
try: # pragma: no cover
68+
anext
69+
except NameError: # pragma: no cover (Python < 3.10)
70+
# noinspection PyShadowingBuiltins
71+
async def anext(iterator: AsyncIterator) -> Any:
72+
"""Return the next item from an async iterator."""
73+
return await iterator.__anext__()
74+
75+
6776
__all__ = [
6877
"create_source_event_stream",
6978
"default_field_resolver",
@@ -684,6 +693,67 @@ def complete_value(
684693
f" '{inspect(return_type)}'."
685694
)
686695

696+
async def complete_async_iterator_value(
697+
self,
698+
item_type: GraphQLOutputType,
699+
field_nodes: List[FieldNode],
700+
info: GraphQLResolveInfo,
701+
path: Path,
702+
iterator: AsyncIterator[Any],
703+
) -> List[Any]:
704+
"""Complete an async iterator.
705+
706+
Complete a async iterator value by completing the result and calling
707+
recursively until all the results are completed.
708+
"""
709+
is_awaitable = self.is_awaitable
710+
awaitable_indices: List[int] = []
711+
append_awaitable = awaitable_indices.append
712+
completed_results: List[Any] = []
713+
append_result = completed_results.append
714+
index = 0
715+
while True:
716+
field_path = path.add_key(index, None)
717+
try:
718+
try:
719+
value = await anext(iterator)
720+
except StopAsyncIteration:
721+
break
722+
try:
723+
completed_item = self.complete_value(
724+
item_type, field_nodes, info, field_path, value
725+
)
726+
if is_awaitable(completed_item):
727+
append_awaitable(index)
728+
append_result(completed_item)
729+
except Exception as raw_error:
730+
append_result(None)
731+
error = located_error(raw_error, field_nodes, field_path.as_list())
732+
self.handle_field_error(error, item_type)
733+
except Exception as raw_error:
734+
append_result(None)
735+
error = located_error(raw_error, field_nodes, field_path.as_list())
736+
self.handle_field_error(error, item_type)
737+
break
738+
index += 1
739+
740+
if not awaitable_indices:
741+
return completed_results
742+
743+
if len(awaitable_indices) == 1:
744+
# If there is only one index, avoid the overhead of parallelization.
745+
index = awaitable_indices[0]
746+
completed_results[index] = await completed_results[index]
747+
else:
748+
for index, result in zip(
749+
awaitable_indices,
750+
await gather(
751+
*(completed_results[index] for index in awaitable_indices)
752+
),
753+
):
754+
completed_results[index] = result
755+
return completed_results
756+
687757
def complete_list_value(
688758
self,
689759
return_type: GraphQLList[GraphQLOutputType],
@@ -696,20 +766,16 @@ def complete_list_value(
696766
697767
Complete a list value by completing each item in the list with the inner type.
698768
"""
699-
if not is_iterable(result):
700-
# experimental: allow async iterables
701-
if isinstance(result, AsyncIterable):
702-
# noinspection PyShadowingNames
703-
async def async_iterable_to_list(
704-
async_result: AsyncIterable[Any],
705-
) -> Any:
706-
sync_result = [item async for item in async_result]
707-
return self.complete_list_value(
708-
return_type, field_nodes, info, path, sync_result
709-
)
769+
item_type = return_type.of_type
710770

711-
return async_iterable_to_list(result)
771+
if isinstance(result, AsyncIterable):
772+
iterator = result.__aiter__()
712773

774+
return self.complete_async_iterator_value(
775+
item_type, field_nodes, info, path, iterator
776+
)
777+
778+
if not is_iterable(result):
713779
raise GraphQLError(
714780
"Expected Iterable, but did not find one for field"
715781
f" '{info.parent_type.name}.{info.field_name}'."
@@ -718,7 +784,6 @@ async def async_iterable_to_list(
718784
# This is specified as a simple map, however we're optimizing the path where
719785
# the list contains no coroutine objects by avoiding creating another coroutine
720786
# object.
721-
item_type = return_type.of_type
722787
is_awaitable = self.is_awaitable
723788
awaitable_indices: List[int] = []
724789
append_awaitable = awaitable_indices.append
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import asyncio
2+
from inspect import isawaitable
3+
4+
from graphql import ExecutionResult, build_schema, execute, parse
5+
6+
7+
schema = build_schema("type Query { listField: [String] }")
8+
document = parse("{ listField }")
9+
10+
11+
class Data:
12+
# noinspection PyPep8Naming
13+
@staticmethod
14+
async def listField(info_):
15+
for index in range(1000):
16+
yield index
17+
18+
19+
async def execute_async() -> ExecutionResult:
20+
result = execute(schema, document, Data())
21+
assert isawaitable(result)
22+
return await result
23+
24+
25+
def test_execute_async_iterable_list_field(benchmark):
26+
# Note: we are creating the async loop outside of the benchmark code so that
27+
# the setup is not included in the benchmark timings
28+
loop = asyncio.events.new_event_loop()
29+
asyncio.events.set_event_loop(loop)
30+
result = benchmark(lambda: loop.run_until_complete(execute_async()))
31+
asyncio.events.set_event_loop(None)
32+
loop.close()
33+
assert not result.errors
34+
assert result.data == {"listField": [str(index) for index in range(1000)]}

0 commit comments

Comments
 (0)