
Commit db05a21

simplify CollectFields for @defer and @stream

Replicates graphql/graphql-js@2aedf25

committed Jun 17, 2025 (1 parent: e7da373)

File tree

6 files changed: +369 additions, -408 deletions

docs/conf.py
src/graphql/execution/build_field_plan.py (new)
src/graphql/execution/collect_fields.py
src/graphql/execution/execute.py
src/graphql/execution/incremental_publisher.py
src/graphql/validation/rules/single_field_subscriptions.py
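At a high level, the commit splits field collection into two phases: collect_fields and collect_subfields now return a plain mapping of response keys to FieldDetails lists, and the new build_field_plan partitions that mapping into the grouped field set for the initial result plus any newly deferred grouped field sets. A minimal sketch of the new flow follows; both functions are internal APIs, and the schema, query, and expected output here are illustrative assumptions, not part of the commit:

from graphql import build_schema, parse
from graphql.execution.build_field_plan import build_field_plan
from graphql.execution.collect_fields import collect_fields

schema = build_schema("type Query { hero: String friend: String }")
document = parse('{ hero ... @defer(label: "slow") { friend } }')
operation = document.definitions[0]

# Phase 1: flat collection, one list of FieldDetails per response key.
fields = collect_fields(schema, {}, {}, schema.query_type, operation)

# Phase 2: partition into the initial grouped field set plus deferred groups.
plan = build_field_plan(fields)
print(list(plan.grouped_field_set))  # initial-result keys, e.g. ['hero']
print(len(plan.new_defer_usages))    # 1, for the labelled @defer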
 

docs/conf.py

Lines changed: 77 additions & 77 deletions
@@ -138,79 +138,77 @@
 }
 
 # ignore the following undocumented or internal references:
-ignore_references = set(
-    [
-        "GNT",
-        "GT",
-        "KT",
-        "T",
-        "VT",
-        "TContext",
-        "Enum",
-        "traceback",
-        "types.TracebackType",
-        "TypeMap",
-        "AwaitableOrValue",
-        "DeferredFragmentRecord",
-        "DeferUsage",
-        "EnterLeaveVisitor",
-        "ExperimentalIncrementalExecutionResults",
-        "FieldGroup",
-        "FormattedIncrementalResult",
-        "FormattedPendingResult",
-        "FormattedSourceLocation",
-        "GraphQLAbstractType",
-        "GraphQLCompositeType",
-        "GraphQLEnumValueMap",
-        "GraphQLErrorExtensions",
-        "GraphQLFieldResolver",
-        "GraphQLInputType",
-        "GraphQLNullableType",
-        "GraphQLOutputType",
-        "GraphQLTypeResolver",
-        "GroupedFieldSet",
-        "IncrementalDataRecord",
-        "IncrementalResult",
-        "InitialResultRecord",
-        "Middleware",
-        "PendingResult",
-        "StreamItemsRecord",
-        "StreamRecord",
-        "SubsequentDataRecord",
-        "asyncio.events.AbstractEventLoop",
-        "collections.abc.MutableMapping",
-        "collections.abc.MutableSet",
-        "enum.Enum",
-        "graphql.execution.collect_fields.DeferUsage",
-        "graphql.execution.collect_fields.CollectFieldsResult",
-        "graphql.execution.collect_fields.FieldGroup",
-        "graphql.execution.execute.StreamArguments",
-        "graphql.execution.execute.StreamUsage",
-        "graphql.execution.map_async_iterable.map_async_iterable",
-        "graphql.execution.incremental_publisher.CompletedResult",
-        "graphql.execution.incremental_publisher.DeferredFragmentRecord",
-        "graphql.execution.incremental_publisher.DeferredGroupedFieldSetRecord",
-        "graphql.execution.incremental_publisher.FormattedCompletedResult",
-        "graphql.execution.incremental_publisher.FormattedPendingResult",
-        "graphql.execution.incremental_publisher.IncrementalPublisher",
-        "graphql.execution.incremental_publisher.InitialResultRecord",
-        "graphql.execution.incremental_publisher.PendingResult",
-        "graphql.execution.incremental_publisher.StreamItemsRecord",
-        "graphql.execution.incremental_publisher.StreamRecord",
-        "graphql.execution.Middleware",
-        "graphql.language.lexer.EscapeSequence",
-        "graphql.language.visitor.EnterLeaveVisitor",
-        "graphql.pyutils.ref_map.K",
-        "graphql.pyutils.ref_map.V",
-        "graphql.type.definition.GT_co",
-        "graphql.type.definition.GNT_co",
-        "graphql.type.definition.TContext",
-        "graphql.type.schema.InterfaceImplementations",
-        "graphql.validation.validation_context.VariableUsage",
-        "graphql.validation.rules.known_argument_names.KnownArgumentNamesOnDirectivesRule",
-        "graphql.validation.rules.provided_required_arguments.ProvidedRequiredArgumentsOnDirectivesRule",
-    ]
-)
+ignore_references = {
+    "GNT",
+    "GT",
+    "KT",
+    "T",
+    "VT",
+    "TContext",
+    "Enum",
+    "traceback",
+    "types.TracebackType",
+    "TypeMap",
+    "AwaitableOrValue",
+    "DeferredFragmentRecord",
+    "DeferUsage",
+    "EnterLeaveVisitor",
+    "ExperimentalIncrementalExecutionResults",
+    "FieldGroup",
+    "FormattedIncrementalResult",
+    "FormattedPendingResult",
+    "FormattedSourceLocation",
+    "GraphQLAbstractType",
+    "GraphQLCompositeType",
+    "GraphQLEnumValueMap",
+    "GraphQLErrorExtensions",
+    "GraphQLFieldResolver",
+    "GraphQLInputType",
+    "GraphQLNullableType",
+    "GraphQLOutputType",
+    "GraphQLTypeResolver",
+    "GroupedFieldSet",
+    "IncrementalDataRecord",
+    "IncrementalResult",
+    "InitialResultRecord",
+    "Middleware",
+    "PendingResult",
+    "StreamItemsRecord",
+    "StreamRecord",
+    "SubsequentDataRecord",
+    "asyncio.events.AbstractEventLoop",
+    "collections.abc.MutableMapping",
+    "collections.abc.MutableSet",
+    "enum.Enum",
+    "graphql.execution.build_field_plan.FieldGroup",
+    "graphql.execution.build_field_plan.FieldPlan",
+    "graphql.execution.collect_fields.DeferUsage",
+    "graphql.execution.execute.StreamArguments",
+    "graphql.execution.execute.StreamUsage",
+    "graphql.execution.map_async_iterable.map_async_iterable",
+    "graphql.execution.incremental_publisher.CompletedResult",
+    "graphql.execution.incremental_publisher.DeferredFragmentRecord",
+    "graphql.execution.incremental_publisher.DeferredGroupedFieldSetRecord",
+    "graphql.execution.incremental_publisher.FormattedCompletedResult",
+    "graphql.execution.incremental_publisher.FormattedPendingResult",
+    "graphql.execution.incremental_publisher.IncrementalPublisher",
+    "graphql.execution.incremental_publisher.InitialResultRecord",
+    "graphql.execution.incremental_publisher.PendingResult",
+    "graphql.execution.incremental_publisher.StreamItemsRecord",
+    "graphql.execution.incremental_publisher.StreamRecord",
+    "graphql.execution.Middleware",
+    "graphql.language.lexer.EscapeSequence",
+    "graphql.language.visitor.EnterLeaveVisitor",
+    "graphql.pyutils.ref_map.K",
+    "graphql.pyutils.ref_map.V",
+    "graphql.type.definition.GT_co",
+    "graphql.type.definition.GNT_co",
+    "graphql.type.definition.TContext",
+    "graphql.type.schema.InterfaceImplementations",
+    "graphql.validation.validation_context.VariableUsage",
+    "graphql.validation.rules.known_argument_names.KnownArgumentNamesOnDirectivesRule",
+    "graphql.validation.rules.provided_required_arguments.ProvidedRequiredArgumentsOnDirectivesRule",
+}
 
 ignore_references.update(__builtins__.keys())
 
@@ -228,10 +226,12 @@ def on_missing_reference(app, env, node, contnode):
     name = target.rsplit(".", 1)[-1]
     if name in ("GT", "GNT", "KT", "T", "VT"):
         return contnode
-    if typ == "obj":
-        if target.startswith("typing."):
-            if name in ("Any", "Optional", "Union"):
-                return contnode
+    if (
+        typ == "obj"
+        and target.startswith("typing.")
+        and name in ("Any", "Optional", "Union")
+    ):
+        return contnode
     if typ != "class":
         return None
     if "." in target:  # maybe too specific
src/graphql/execution/build_field_plan.py (new file)

Lines changed: 148 additions & 0 deletions

@@ -0,0 +1,148 @@
+"""Build field plan"""
+
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Dict, NamedTuple
+
+from ..pyutils import RefMap, RefSet
+from .collect_fields import DeferUsage, FieldDetails
+
+if TYPE_CHECKING:
+    from ..language import FieldNode
+
+try:
+    from typing import TypeAlias
+except ImportError:  # Python < 3.10
+    from typing_extensions import TypeAlias
+
+__all__ = [
+    "DeferUsageSet",
+    "FieldGroup",
+    "FieldPlan",
+    "GroupedFieldSet",
+    "NewGroupedFieldSetDetails",
+    "build_field_plan",
+]
+
+
+DeferUsageSet: TypeAlias = RefSet[DeferUsage]
+
+
+class FieldGroup(NamedTuple):
+    """A group of fields with defer usages."""
+
+    fields: list[FieldDetails]
+    defer_usages: DeferUsageSet | None = None
+    known_defer_usages: DeferUsageSet | None = None
+
+    def to_nodes(self) -> list[FieldNode]:
+        """Return the field nodes in this group."""
+        return [field_details.node for field_details in self.fields]
+
+
+if sys.version_info < (3, 9):
+    GroupedFieldSet: TypeAlias = Dict[str, FieldGroup]
+else:  # Python >= 3.9
+    GroupedFieldSet: TypeAlias = dict[str, FieldGroup]
+
+
+class NewGroupedFieldSetDetails(NamedTuple):
+    """Details of a new grouped field set."""
+
+    grouped_field_set: GroupedFieldSet
+    should_initiate_defer: bool
+
+
+class FieldPlan(NamedTuple):
+    """A plan for executing fields."""
+
+    grouped_field_set: GroupedFieldSet
+    new_grouped_field_set_details_map: RefMap[DeferUsageSet, NewGroupedFieldSetDetails]
+    new_defer_usages: list[DeferUsage]
+
+
+def build_field_plan(
+    fields: dict[str, list[FieldDetails]],
+    parent_defer_usages: DeferUsageSet | None = None,
+    known_defer_usages: DeferUsageSet | None = None,
+) -> FieldPlan:
+    """Build a plan for executing fields."""
+    if parent_defer_usages is None:
+        parent_defer_usages = RefSet()
+    if known_defer_usages is None:
+        known_defer_usages = RefSet()
+
+    new_defer_usages: RefSet[DeferUsage] = RefSet()
+    new_known_defer_usages: RefSet[DeferUsage] = RefSet(known_defer_usages)
+
+    grouped_field_set: GroupedFieldSet = {}
+
+    new_grouped_field_set_details_map: RefMap[
+        DeferUsageSet, NewGroupedFieldSetDetails
+    ] = RefMap()
+
+    map_: dict[str, tuple[DeferUsageSet, list[FieldDetails]]] = {}
+
+    for response_key, field_details_list in fields.items():
+        defer_usage_set: RefSet[DeferUsage] = RefSet()
+        in_original_result = False
+        for field_details in field_details_list:
+            defer_usage = field_details.defer_usage
+            if defer_usage is None:
+                in_original_result = True
+                continue
+            defer_usage_set.add(defer_usage)
+            if defer_usage not in known_defer_usages:
+                new_defer_usages.add(defer_usage)
+                new_known_defer_usages.add(defer_usage)
+        if in_original_result:
+            defer_usage_set.clear()
+        else:
+            defer_usage_set -= {
+                defer_usage
+                for defer_usage in defer_usage_set
+                if any(
+                    ancestor in defer_usage_set for ancestor in defer_usage.ancestors
+                )
+            }
+        map_[response_key] = (defer_usage_set, field_details_list)
+
+    for response_key, [defer_usage_set, field_details_list] in map_.items():
+        if defer_usage_set == parent_defer_usages:
+            field_group = grouped_field_set.get(response_key)
+            if field_group is None:  # pragma: no cover else
+                field_group = FieldGroup([], defer_usage_set, new_known_defer_usages)
+                grouped_field_set[response_key] = field_group
+            field_group.fields.extend(field_details_list)
+            continue
+
+        for (
+            new_grouped_field_set_defer_usage_set,
+            new_grouped_field_set_details,
+        ) in new_grouped_field_set_details_map.items():
+            if new_grouped_field_set_defer_usage_set == defer_usage_set:
+                new_grouped_field_set = new_grouped_field_set_details.grouped_field_set
+                break
+        else:
+            new_grouped_field_set = {}
+            new_grouped_field_set_details = NewGroupedFieldSetDetails(
+                new_grouped_field_set,
+                any(
+                    defer_usage not in parent_defer_usages
+                    for defer_usage in defer_usage_set
+                ),
+            )
+            new_grouped_field_set_details_map[defer_usage_set] = (
+                new_grouped_field_set_details
+            )
+
+        field_group = new_grouped_field_set.get(response_key)
+        if field_group is None:  # pragma: no cover else
+            field_group = FieldGroup([], defer_usage_set, new_known_defer_usages)
+            new_grouped_field_set[response_key] = field_group
+        field_group.fields.extend(field_details_list)
+
+    return FieldPlan(
+        grouped_field_set, new_grouped_field_set_details_map, list(new_defer_usages)
+    )
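A subtlety worth noting in build_field_plan above is the masking step: a defer usage adds nothing new when one of its ancestors is already in the same set, so such usages are filtered out before grouping. A small sketch of that filter in isolation, using a hypothetical two-level @defer nesting:

from graphql.execution.collect_fields import DeferUsage
from graphql.pyutils import RefSet

outer = DeferUsage("outer", None)   # top-level @defer
inner = DeferUsage("inner", outer)  # @defer nested inside it

defer_usage_set = RefSet([outer, inner])
masked = {
    defer_usage
    for defer_usage in defer_usage_set
    if any(ancestor in defer_usage_set for ancestor in defer_usage.ancestors)
}
print([d.label for d in masked])  # ['inner']: shadowed by its parent, so dropped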

src/graphql/execution/collect_fields.py

Lines changed: 67 additions & 255 deletions
@@ -2,8 +2,8 @@
 
 from __future__ import annotations
 
-import sys
-from typing import Any, Dict, NamedTuple, Union, cast
+from collections import defaultdict
+from typing import Any, NamedTuple
 
 from ..language import (
     FieldNode,
@@ -14,7 +14,6 @@
     OperationType,
     SelectionSetNode,
 )
-from ..pyutils import RefMap, RefSet
 from ..type import (
     GraphQLDeferDirective,
     GraphQLIncludeDirective,
@@ -26,81 +25,37 @@
 from ..utilities.type_from_ast import type_from_ast
 from .values import get_directive_values
 
-try:
-    from typing import TypeAlias
-except ImportError:  # Python < 3.10
-    from typing_extensions import TypeAlias
-
-
 __all__ = [
-    "NON_DEFERRED_TARGET_SET",
     "CollectFieldsContext",
-    "CollectFieldsResult",
     "DeferUsage",
-    "DeferUsageSet",
     "FieldDetails",
-    "FieldGroup",
-    "GroupedFieldSetDetails",
-    "Target",
-    "TargetSet",
     "collect_fields",
     "collect_subfields",
 ]
 
 
 class DeferUsage(NamedTuple):
-    """An optionally labelled list of ancestor targets."""
+    """An optionally labelled linked list of defer usages."""
 
     label: str | None
-    ancestors: list[Target]
-
-
-Target: TypeAlias = Union[DeferUsage, None]
+    parent_defer_usage: DeferUsage | None
 
-TargetSet: TypeAlias = RefSet[Target]
-DeferUsageSet: TypeAlias = RefSet[DeferUsage]
-
-
-NON_DEFERRED_TARGET_SET: TargetSet = RefSet([None])
+    @property
+    def ancestors(self) -> list[DeferUsage]:
+        """Get the ancestors of this defer usage."""
+        ancestors: list[DeferUsage] = []
+        parent_defer_usage = self.parent_defer_usage
+        while parent_defer_usage is not None:
+            ancestors.append(parent_defer_usage)
+            parent_defer_usage = parent_defer_usage.parent_defer_usage
+        return ancestors[::-1]
 
 
 class FieldDetails(NamedTuple):
-    """A field node and its target."""
+    """A field node and its defer usage."""
 
     node: FieldNode
-    target: Target
-
-
-class FieldGroup(NamedTuple):
-    """A group of fields that share the same target set."""
-
-    fields: list[FieldDetails]
-    targets: TargetSet
-
-    def to_nodes(self) -> list[FieldNode]:
-        """Return the field nodes in this group."""
-        return [field_details.node for field_details in self.fields]
-
-
-if sys.version_info < (3, 9):
-    GroupedFieldSet: TypeAlias = Dict[str, FieldGroup]
-else:  # Python >= 3.9
-    GroupedFieldSet: TypeAlias = dict[str, FieldGroup]
-
-
-class GroupedFieldSetDetails(NamedTuple):
-    """A grouped field set with defer info."""
-
-    grouped_field_set: GroupedFieldSet
-    should_initiate_defer: bool
-
-
-class CollectFieldsResult(NamedTuple):
-    """Collected fields and deferred usages."""
-
-    grouped_field_set: GroupedFieldSet
-    new_grouped_field_set_details: RefMap[DeferUsageSet, GroupedFieldSetDetails]
-    new_defer_usages: list[DeferUsage]
+    defer_usage: DeferUsage | None
 
 
 class CollectFieldsContext(NamedTuple):
@@ -111,9 +66,6 @@ class CollectFieldsContext(NamedTuple):
     variable_values: dict[str, Any]
     operation: OperationDefinitionNode
     runtime_type: GraphQLObjectType
-    targets_by_key: dict[str, TargetSet]
-    fields_by_target: RefMap[Target, dict[str, list[FieldNode]]]
-    new_defer_usages: list[DeferUsage]
    visited_fragment_names: set[str]
 
 
@@ -123,7 +75,7 @@ def collect_fields(
     variable_values: dict[str, Any],
     runtime_type: GraphQLObjectType,
     operation: OperationDefinitionNode,
-) -> CollectFieldsResult:
+) -> dict[str, list[FieldDetails]]:
     """Collect fields.
 
     Given a selection_set, collects all the fields and returns them.
@@ -134,23 +86,18 @@
 
     For internal use only.
     """
+    grouped_field_set: dict[str, list[FieldDetails]] = defaultdict(list)
     context = CollectFieldsContext(
         schema,
         fragments,
         variable_values,
         operation,
         runtime_type,
-        {},
-        RefMap(),
-        [],
         set(),
     )
-    collect_fields_impl(context, operation.selection_set)
 
-    return CollectFieldsResult(
-        *build_grouped_field_sets(context.targets_by_key, context.fields_by_target),
-        context.new_defer_usages,
-    )
+    collect_fields_impl(context, operation.selection_set, grouped_field_set)
+    return grouped_field_set
 
 
 def collect_subfields(
@@ -159,8 +106,8 @@ def collect_subfields(
     variable_values: dict[str, Any],
     operation: OperationDefinitionNode,
     return_type: GraphQLObjectType,
-    field_group: FieldGroup,
-) -> CollectFieldsResult:
+    field_details: list[FieldDetails],
+) -> dict[str, list[FieldDetails]]:
     """Collect subfields.
 
     Given a list of field nodes, collects all the subfields of the passed in fields,
@@ -178,30 +125,29 @@
         variable_values,
         operation,
         return_type,
-        {},
-        RefMap(),
-        [],
         set(),
     )
+    sub_grouped_field_set: dict[str, list[FieldDetails]] = defaultdict(list)
 
-    for field_details in field_group.fields:
-        node = field_details.node
+    for field_detail in field_details:
+        node = field_detail.node
         if node.selection_set:
-            collect_fields_impl(context, node.selection_set, field_details.target)
+            collect_fields_impl(
+                context,
+                node.selection_set,
+                sub_grouped_field_set,
+                field_detail.defer_usage,
+            )
 
-    return CollectFieldsResult(
-        *build_grouped_field_sets(
-            context.targets_by_key, context.fields_by_target, field_group.targets
-        ),
-        context.new_defer_usages,
-    )
+    return sub_grouped_field_set
 
 
 def collect_fields_impl(
     context: CollectFieldsContext,
     selection_set: SelectionSetNode,
-    parent_target: Target | None = None,
-    new_target: Target | None = None,
+    grouped_field_set: dict[str, list[FieldDetails]],
+    parent_defer_usage: DeferUsage | None = None,
+    defer_usage: DeferUsage | None = None,
 ) -> None:
     """Collect fields (internal implementation)."""
     (
@@ -210,97 +156,71 @@
         variable_values,
         operation,
         runtime_type,
-        targets_by_key,
-        fields_by_target,
-        new_defer_usages,
         visited_fragment_names,
     ) = context
 
-    ancestors: list[Target]
-
     for selection in selection_set.selections:
         if isinstance(selection, FieldNode):
             if not should_include_node(variable_values, selection):
                 continue
             key = get_field_entry_key(selection)
-            target = new_target or parent_target
-            key_targets = targets_by_key.get(key)
-            if key_targets is None:
-                key_targets = RefSet([target])
-                targets_by_key[key] = key_targets
-            else:
-                key_targets.add(target)
-            target_fields = fields_by_target.get(target)
-            if target_fields is None:
-                fields_by_target[target] = {key: [selection]}
-            else:
-                field_nodes = target_fields.get(key)
-                if field_nodes is None:
-                    target_fields[key] = [selection]
-                else:
-                    field_nodes.append(selection)
+            grouped_field_set[key].append(
+                FieldDetails(selection, defer_usage or parent_defer_usage)
+            )
         elif isinstance(selection, InlineFragmentNode):
             if not should_include_node(
                 variable_values, selection
             ) or not does_fragment_condition_match(schema, selection, runtime_type):
                 continue
 
-            defer = get_defer_values(operation, variable_values, selection)
-
-            if defer:
-                ancestors = (
-                    [None]
-                    if parent_target is None
-                    else [parent_target, *parent_target.ancestors]
-                )
-                target = DeferUsage(defer.label, ancestors)
-                new_defer_usages.append(target)
-            else:
-                target = new_target
+            new_defer_usage = get_defer_usage(
+                operation, variable_values, selection, parent_defer_usage
+            )
 
-            collect_fields_impl(context, selection.selection_set, parent_target, target)
+            collect_fields_impl(
+                context,
+                selection.selection_set,
+                grouped_field_set,
+                parent_defer_usage,
+                new_defer_usage or defer_usage,
+            )
         elif isinstance(selection, FragmentSpreadNode):  # pragma: no cover else
             frag_name = selection.name.value
 
-            if not should_include_node(variable_values, selection):
-                continue
+            new_defer_usage = get_defer_usage(
+                operation, variable_values, selection, parent_defer_usage
+            )
 
-            defer = get_defer_values(operation, variable_values, selection)
-            if frag_name in visited_fragment_names and not defer:
+            if new_defer_usage is None and (
+                frag_name in visited_fragment_names
+                or not should_include_node(variable_values, selection)
+            ):
                 continue
 
             fragment = fragments.get(frag_name)
-            if not fragment or not does_fragment_condition_match(
+            if fragment is None or not does_fragment_condition_match(
                 schema, fragment, runtime_type
            ):
                 continue
 
-            if defer:
-                ancestors = (
-                    [None]
-                    if parent_target is None
-                    else [parent_target, *parent_target.ancestors]
-                )
-                target = DeferUsage(defer.label, ancestors)
-                new_defer_usages.append(target)
-            else:
+            if new_defer_usage is None:
                 visited_fragment_names.add(frag_name)
-                target = new_target
 
-            collect_fields_impl(context, fragment.selection_set, parent_target, target)
-
-
-class DeferValues(NamedTuple):
-    """Values of an active defer directive."""
-
-    label: str | None
+            collect_fields_impl(
+                context,
+                fragment.selection_set,
+                grouped_field_set,
+                parent_defer_usage,
+                new_defer_usage or defer_usage,
+            )
 
 
-def get_defer_values(
+def get_defer_usage(
     operation: OperationDefinitionNode,
     variable_values: dict[str, Any],
     node: FragmentSpreadNode | InlineFragmentNode,
-) -> DeferValues | None:
+    parent_defer_usage: DeferUsage | None,
+) -> DeferUsage | None:
     """Get values of defer directive if active.
 
     Returns an object containing the `@defer` arguments if a field should be
@@ -319,7 +239,7 @@
         )
         raise TypeError(msg)
 
-    return DeferValues(defer.get("label"))
+    return DeferUsage(defer.get("label"), parent_defer_usage)
 
 
 def should_include_node(
@@ -360,111 +280,3 @@
 def get_field_entry_key(node: FieldNode) -> str:
     """Implement the logic to compute the key of a given field's entry"""
     return node.alias.value if node.alias else node.name.value
-
-
-def build_grouped_field_sets(
-    targets_by_key: dict[str, TargetSet],
-    fields_by_target: RefMap[Target, dict[str, list[FieldNode]]],
-    parent_targets: TargetSet = NON_DEFERRED_TARGET_SET,
-) -> tuple[GroupedFieldSet, RefMap[DeferUsageSet, GroupedFieldSetDetails]]:
-    """Build grouped field sets."""
-    parent_target_keys, target_set_details_map = get_target_set_details(
-        targets_by_key, parent_targets
-    )
-
-    grouped_field_set = (
-        get_ordered_grouped_field_set(
-            parent_target_keys, parent_targets, targets_by_key, fields_by_target
-        )
-        if parent_target_keys
-        else {}
-    )
-
-    new_grouped_field_set_details: RefMap[DeferUsageSet, GroupedFieldSetDetails] = (
-        RefMap()
-    )
-
-    for masking_targets, target_set_details in target_set_details_map.items():
-        keys, should_initiate_defer = target_set_details
-
-        new_grouped_field_set = get_ordered_grouped_field_set(
-            keys, masking_targets, targets_by_key, fields_by_target
-        )
-
-        # All TargetSets that causes new grouped field sets consist only of DeferUsages
-        # and have should_initiate_defer defined
-
-        new_grouped_field_set_details[cast("DeferUsageSet", masking_targets)] = (
-            GroupedFieldSetDetails(new_grouped_field_set, should_initiate_defer)
-        )
-
-    return grouped_field_set, new_grouped_field_set_details
-
-
-class TargetSetDetails(NamedTuple):
-    """A set of target keys with defer info."""
-
-    keys: set[str]
-    should_initiate_defer: bool
-
-
-def get_target_set_details(
-    targets_by_key: dict[str, TargetSet], parent_targets: TargetSet
-) -> tuple[set[str], RefMap[TargetSet, TargetSetDetails]]:
-    """Get target set details."""
-    parent_target_keys: set[str] = set()
-    target_set_details_map: RefMap[TargetSet, TargetSetDetails] = RefMap()
-
-    for response_key, targets in targets_by_key.items():
-        masking_target_list: list[Target] = []
-        for target in targets:
-            if not target or all(
-                ancestor not in targets for ancestor in target.ancestors
-            ):
-                masking_target_list.append(target)
-
-        masking_targets: TargetSet = RefSet(masking_target_list)
-        if masking_targets == parent_targets:
-            parent_target_keys.add(response_key)
-            continue
-
-        for target_set, target_set_details in target_set_details_map.items():
-            if target_set == masking_targets:
-                target_set_details.keys.add(response_key)
-                break
-        else:
-            target_set_details = TargetSetDetails(
-                {response_key},
-                any(
-                    defer_usage not in parent_targets for defer_usage in masking_targets
-                ),
-            )
-            target_set_details_map[masking_targets] = target_set_details
-
-    return parent_target_keys, target_set_details_map
-
-
-def get_ordered_grouped_field_set(
-    keys: set[str],
-    masking_targets: TargetSet,
-    targets_by_key: dict[str, TargetSet],
-    fields_by_target: RefMap[Target, dict[str, list[FieldNode]]],
-) -> GroupedFieldSet:
-    """Get ordered grouped field set."""
-    grouped_field_set: GroupedFieldSet = {}
-
-    first_target = next(iter(masking_targets))
-    first_fields = fields_by_target[first_target]
-    for key in list(first_fields):
-        if key in keys:
-            field_group = grouped_field_set.get(key)
-            if field_group is None:  # pragma: no cover else
-                field_group = FieldGroup([], masking_targets)
-                grouped_field_set[key] = field_group
-            for target in targets_by_key[key]:
-                fields_for_target = fields_by_target[target]
-                nodes = fields_for_target[key]
-                del fields_for_target[key]
-                field_group.fields.extend(FieldDetails(node, target) for node in nodes)
-
-    return grouped_field_set
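With this change DeferUsage becomes a parent-linked list rather than carrying a precomputed ancestor list: the new ancestors property walks the parent_defer_usage chain and returns it outermost-first. A small illustrative sketch (the labels are hypothetical):

from graphql.execution.collect_fields import DeferUsage

root = DeferUsage("root", None)
middle = DeferUsage("middle", root)
leaf = DeferUsage("leaf", middle)

print([d.label for d in leaf.ancestors])  # ['root', 'middle'], outermost first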

src/graphql/execution/execute.py

Lines changed: 68 additions & 67 deletions
@@ -40,6 +40,7 @@
 from ..error import GraphQLError, located_error
 from ..language import (
     DocumentNode,
+    FieldNode,
     FragmentDefinitionNode,
     OperationDefinitionNode,
     OperationType,
@@ -75,18 +76,15 @@
     is_object_type,
 )
 from .async_iterables import map_async_iterable
-from .collect_fields import (
-    NON_DEFERRED_TARGET_SET,
-    CollectFieldsResult,
-    DeferUsage,
+from .build_field_plan import (
     DeferUsageSet,
-    FieldDetails,
     FieldGroup,
+    FieldPlan,
     GroupedFieldSet,
-    GroupedFieldSetDetails,
-    collect_fields,
-    collect_subfields,
+    NewGroupedFieldSetDetails,
+    build_field_plan,
 )
+from .collect_fields import DeferUsage, FieldDetails, collect_fields, collect_subfields
 from .incremental_publisher import (
     ASYNC_DELAY,
     DeferredFragmentRecord,
@@ -208,7 +206,7 @@ def __init__(
         if is_awaitable:
             self.is_awaitable = is_awaitable
         self._canceled_iterators: set[AsyncIterator] = set()
-        self._subfields_cache: dict[tuple, CollectFieldsResult] = {}
+        self._field_plan_cache: dict[tuple, FieldPlan] = {}
         self._tasks: set[Awaitable] = set()
         self._stream_usages: RefMap[FieldGroup, StreamUsage] = RefMap()
 
@@ -326,10 +324,11 @@ def execute_operation(
             )
             raise GraphQLError(msg, operation)
 
-        grouped_field_set, new_grouped_field_set_details, new_defer_usages = (
-            collect_fields(
-                schema, self.fragments, self.variable_values, root_type, operation
-            )
+        fields = collect_fields(
+            schema, self.fragments, self.variable_values, root_type, operation
+        )
+        grouped_field_set, new_grouped_field_set_details_map, new_defer_usages = (
+            build_field_plan(fields)
         )
 
         incremental_publisher = self.incremental_publisher
@@ -341,7 +340,7 @@
 
         new_deferred_grouped_field_set_records = add_new_deferred_grouped_field_sets(
             incremental_publisher,
-            new_grouped_field_set_details,
+            new_grouped_field_set_details_map,
             new_defer_map,
             path,
         )
@@ -502,7 +501,9 @@ def execute_field(
         if self.middleware_manager:
             resolve_fn = self.middleware_manager.get_field_resolver(resolve_fn)
 
-        info = self.build_resolve_info(field_def, field_group, parent_type, path)
+        info = self.build_resolve_info(
+            field_def, field_group.to_nodes(), parent_type, path
+        )
 
         # Get the resolve function, regardless of if its result is normal or abrupt
         # (error).
@@ -575,7 +576,7 @@ async def await_completed() -> Any:
     def build_resolve_info(
         self,
         field_def: GraphQLField,
-        field_group: FieldGroup,
+        field_nodes: list[FieldNode],
         parent_type: GraphQLObjectType,
         path: Path,
     ) -> GraphQLResolveInfo:
@@ -586,8 +587,8 @@ def build_resolve_info(
         # The resolve function's first argument is a collection of information about
         # the current execution state.
         return GraphQLResolveInfo(
-            field_group.fields[0].node.name.value,
-            field_group.to_nodes(),
+            field_nodes[0].name.value,
+            field_nodes,
             field_def.type,
             parent_type,
             path,
@@ -807,8 +808,7 @@ def get_stream_usage(
             [
                 FieldDetails(field_details.node, None)
                 for field_details in field_group.fields
-            ],
-            NON_DEFERRED_TARGET_SET,
+            ]
         )
 
         stream_usage = StreamUsage(
@@ -1273,8 +1273,8 @@ def collect_and_execute_subfields(
         defer_map: RefMap[DeferUsage, DeferredFragmentRecord],
     ) -> AwaitableOrValue[dict[str, Any]]:
         """Collect sub-fields to execute to complete this value."""
-        grouped_field_set, new_grouped_field_set_details, new_defer_usages = (
-            self.collect_subfields(return_type, field_group)
+        grouped_field_set, new_grouped_field_set_details_map, new_defer_usages = (
+            self.build_sub_field_plan(return_type, field_group)
         )
 
         incremental_publisher = self.incremental_publisher
@@ -1286,7 +1286,10 @@
             path,
         )
         new_deferred_grouped_field_set_records = add_new_deferred_grouped_field_sets(
-            incremental_publisher, new_grouped_field_set_details, new_defer_map, path
+            incremental_publisher,
+            new_grouped_field_set_details_map,
+            new_defer_map,
+            path,
         )
 
         sub_fields = self.execute_fields(
@@ -1308,17 +1311,16 @@
 
         return sub_fields
 
-    def collect_subfields(
+    def build_sub_field_plan(
         self, return_type: GraphQLObjectType, field_group: FieldGroup
-    ) -> CollectFieldsResult:
+    ) -> FieldPlan:
         """Collect subfields.
 
-        A cached collection of relevant subfields with regard to the return type is
-        kept in the execution context as ``_subfields_cache``. This ensures the
-        subfields are not repeatedly calculated, which saves overhead when resolving
-        lists of values.
+        A memoized function for building subfield plans with regard to the return type.
+        Memoizing ensures the subfields are not repeatedly calculated, which saves
+        overhead when resolving lists of values.
         """
-        cache = self._subfields_cache
+        cache = self._field_plan_cache
         # We cannot use the field_group itself as key for the cache, since it
         # is not hashable as a list. We also do not want to use the field_group
         # itself (converted to a tuple) as keys, since hashing them is slow.
@@ -1331,18 +1333,21 @@ def collect_subfields(
             if len(field_group) == 1  # optimize most frequent case
             else (return_type, *map(id, field_group))
         )
-        sub_fields_and_patches = cache.get(key)
-        if sub_fields_and_patches is None:
-            sub_fields_and_patches = collect_subfields(
+        plan = cache.get(key)
+        if plan is None:
+            sub_fields = collect_subfields(
                 self.schema,
                 self.fragments,
                 self.variable_values,
                 self.operation,
                 return_type,
-                field_group,
+                field_group.fields,
+            )
+            plan = build_field_plan(
+                sub_fields, field_group.defer_usages, field_group.known_defer_usages
             )
-            cache[key] = sub_fields_and_patches
-        return sub_fields_and_patches
+            cache[key] = plan
+        return plan
 
     def map_source_to_response(
         self, result_or_stream: ExecutionResult | AsyncIterable[Any]
@@ -1954,14 +1959,10 @@ def add_new_deferred_fragments(
     new_defer_map = RefMap() if defer_map is None else RefMap(defer_map.items())
 
     # For each new DeferUsage object:
-    for defer_usage in new_defer_usages:
-        ancestors = defer_usage.ancestors
-        parent_defer_usage = ancestors[0] if ancestors else None
-
-        # If the parent target is defined, the parent target is a DeferUsage object
-        # and the parent result record is the DeferredFragmentRecord corresponding
-        # to that DeferUsage.
-        # If the parent target is not defined, the parent result record is either:
+    for new_defer_usage in new_defer_usages:
+        parent_defer_usage = new_defer_usage.parent_defer_usage
+
+        # If the parent defer usage is not defined, the parent result record is either:
         #   - the InitialResultRecord, or
         #   - a StreamItemsRecord, as `@defer` may be nested under `@stream`.
         parent = (
@@ -1975,15 +1976,15 @@ def add_new_deferred_fragments(
         )
 
         # Instantiate the new record.
-        deferred_fragment_record = DeferredFragmentRecord(path, defer_usage.label)
+        deferred_fragment_record = DeferredFragmentRecord(path, new_defer_usage.label)
 
         # Report the new record to the Incremental Publisher.
         incremental_publisher.report_new_defer_fragment_record(
             deferred_fragment_record, parent
        )
 
         # Update the map.
-        new_defer_map[defer_usage] = deferred_fragment_record
+        new_defer_map[new_defer_usage] = deferred_fragment_record
 
     return new_defer_map
 
@@ -1997,24 +1998,26 @@ def deferred_fragment_record_from_defer_usage(
 
 def add_new_deferred_grouped_field_sets(
     incremental_publisher: IncrementalPublisher,
-    new_grouped_field_set_details: Mapping[DeferUsageSet, GroupedFieldSetDetails],
+    new_grouped_field_set_details_map: Mapping[
+        DeferUsageSet, NewGroupedFieldSetDetails
+    ],
     defer_map: RefMap[DeferUsage, DeferredFragmentRecord],
     path: Path | None = None,
 ) -> list[DeferredGroupedFieldSetRecord]:
     """Add new deferred grouped field sets to the defer map."""
     new_deferred_grouped_field_set_records: list[DeferredGroupedFieldSetRecord] = []
 
     for (
-        new_grouped_field_set_defer_usages,
-        grouped_field_set_details,
-    ) in new_grouped_field_set_details.items():
+        defer_usage_set,
+        [grouped_field_set, should_initiate_defer],
+    ) in new_grouped_field_set_details_map.items():
         deferred_fragment_records = get_deferred_fragment_records(
-            new_grouped_field_set_defer_usages, defer_map
+            defer_usage_set, defer_map
         )
         deferred_grouped_field_set_record = DeferredGroupedFieldSetRecord(
             deferred_fragment_records,
-            grouped_field_set_details.grouped_field_set,
-            grouped_field_set_details.should_initiate_defer,
+            grouped_field_set,
+            should_initiate_defer,
             path,
         )
         incremental_publisher.report_new_deferred_grouped_filed_set_record(
@@ -2292,34 +2295,34 @@ def execute_subscription(
         msg = "Schema is not configured to execute subscription operation."
         raise GraphQLError(msg, context.operation)
 
-    grouped_field_set = collect_fields(
+    fields = collect_fields(
         schema,
         context.fragments,
         context.variable_values,
         root_type,
         context.operation,
-    ).grouped_field_set
-    first_root_field = next(iter(grouped_field_set.items()))
-    response_name, field_group = first_root_field
-    field_name = field_group.fields[0].node.name.value
+    )
+
+    first_root_field = next(iter(fields.items()))
+    response_name, field_details_list = first_root_field
+    field_name = field_details_list[0].node.name.value
     field_def = schema.get_field(root_type, field_name)
 
+    field_nodes = [field_details.node for field_details in field_details_list]
     if not field_def:
         msg = f"The subscription field '{field_name}' is not defined."
-        raise GraphQLError(msg, field_group.to_nodes())
+        raise GraphQLError(msg, field_nodes)
 
     path = Path(None, response_name, root_type.name)
-    info = context.build_resolve_info(field_def, field_group, root_type, path)
+    info = context.build_resolve_info(field_def, field_nodes, root_type, path)
 
     # Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
     # It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
 
     try:
         # Build a dictionary of arguments from the field.arguments AST, using the
         # variables scope to fulfill any variable references.
-        args = get_argument_values(
-            field_def, field_group.fields[0].node, context.variable_values
-        )
+        args = get_argument_values(field_def, field_nodes[0], context.variable_values)
 
         # Call the `subscribe()` resolver or the default resolver to produce an
         # AsyncIterable yielding raw payloads.
@@ -2332,16 +2335,14 @@ async def await_result() -> AsyncIterable[Any]:
                 try:
                     return assert_event_stream(await result)
                 except Exception as error:
-                    raise located_error(
-                        error, field_group.to_nodes(), path.as_list()
-                    ) from error
+                    raise located_error(error, field_nodes, path.as_list()) from error
 
             return await_result()
 
         return assert_event_stream(result)
 
     except Exception as error:
-        raise located_error(error, field_group.to_nodes(), path.as_list()) from error
+        raise located_error(error, field_nodes, path.as_list()) from error
 
 
 def assert_event_stream(result: Any) -> AsyncIterable:
src/graphql/execution/incremental_publisher.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 if TYPE_CHECKING:
     from ..error import GraphQLError, GraphQLFormattedError
     from ..pyutils import Path
-    from .collect_fields import GroupedFieldSet
+    from .build_field_plan import GroupedFieldSet
 
 __all__ = [
     "ASYNC_DELAY",

src/graphql/validation/rules/single_field_subscriptions.py

Lines changed: 8 additions & 8 deletions
@@ -5,7 +5,7 @@
 from typing import Any
 
 from ...error import GraphQLError
-from ...execution.collect_fields import FieldGroup, collect_fields
+from ...execution.collect_fields import FieldDetails, collect_fields
 from ...language import (
     FieldNode,
     FragmentDefinitionNode,
@@ -17,8 +17,8 @@
 __all__ = ["SingleFieldSubscriptionsRule"]
 
 
-def to_nodes(field_group: FieldGroup) -> list[FieldNode]:
-    return [field_details.node for field_details in field_group.fields]
+def to_nodes(field_details_list: list[FieldDetails]) -> list[FieldNode]:
+    return [field_details.node for field_details in field_details_list]
 
 
 class SingleFieldSubscriptionsRule(ValidationRule):
@@ -46,15 +46,15 @@ def enter_operation_definition(
                     for definition in document.definitions
                     if isinstance(definition, FragmentDefinitionNode)
                 }
-                grouped_field_set = collect_fields(
+                fields = collect_fields(
                     schema,
                     fragments,
                     variable_values,
                     subscription_type,
                     node,
-                ).grouped_field_set
-                if len(grouped_field_set) > 1:
-                    field_groups = list(grouped_field_set.values())
+                )
+                if len(fields) > 1:
+                    field_groups = list(fields.values())
                     extra_field_groups = field_groups[1:]
                     extra_field_selection = [
                         node
@@ -72,7 +72,7 @@ def enter_operation_definition(
                             extra_field_selection,
                         )
                     )
-                for field_group in grouped_field_set.values():
+                for field_group in fields.values():
                     field_name = to_nodes(field_group)[0].name.value
                     if field_name.startswith("__"):
                         self.report_error(
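Since collect_fields now returns the flat mapping directly, the rule simply counts its keys instead of unpacking a CollectFieldsResult. A quick check that the rule still rejects multi-field subscriptions; the schema and query are illustrative:

from graphql import build_schema, parse, validate
from graphql.validation.rules.single_field_subscriptions import (
    SingleFieldSubscriptionsRule,
)

schema = build_schema("""
    type Query { ok: Boolean }
    type Subscription { a: Int b: Int }
""")
document = parse("subscription { a b }")
errors = validate(schema, document, [SingleFieldSubscriptionsRule])
print(errors)  # one error: a subscription must select only one top-level field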
