Skip to content

Commit 350e511

Browse files
evantey14obi1kenobi
authored andcommitted
Add output info to query metadata (kensho-technologies#311)
* Add output info to query metadata * Update explain info tests * Add output tests to explain_info * Remove context output * Add docstring metadata outputs property * Use query_metadta_table variable and name * Replace more names with query_metadata_table * Remove fold from OutputInfo
1 parent 8481ae5 commit 350e511

File tree

3 files changed

+269
-43
lines changed

3 files changed

+269
-43
lines changed

graphql_compiler/compiler/compiler_frontend.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
get_vertex_field_type, invert_dict, is_tagged_parameter, is_vertex_field_name,
9191
strip_non_null_from_type, validate_output_name, validate_safe_string
9292
)
93-
from .metadata import LocationInfo, QueryMetadataTable, RecurseInfo, TagInfo
93+
from .metadata import LocationInfo, OutputInfo, QueryMetadataTable, RecurseInfo, TagInfo
9494

9595

9696
# LocationStackEntry contains the following:
@@ -298,7 +298,7 @@ def _compile_property_ast(schema, current_schema_type, ast, location,
298298
if output_directive:
299299
# Schema validation has ensured that the fields below exist.
300300
output_name = output_directive.arguments[0].value.value
301-
if output_name in context['outputs']:
301+
if context['metadata'].get_output_info(output_name):
302302
raise GraphQLCompilationError(u'Cannot reuse output name: '
303303
u'{}, {}'.format(output_name, context))
304304
validate_safe_string(output_name)
@@ -312,12 +312,12 @@ def _compile_property_ast(schema, current_schema_type, ast, location,
312312
if location.field != COUNT_META_FIELD_NAME:
313313
graphql_type = GraphQLList(graphql_type)
314314

315-
context['outputs'][output_name] = {
316-
'location': location,
317-
'optional': is_in_optional_scope(context),
318-
'type': graphql_type,
319-
'fold': context.get('fold', None),
320-
}
315+
output_info = OutputInfo(
316+
location=location,
317+
type=graphql_type,
318+
optional=is_in_optional_scope(context),
319+
)
320+
context['metadata'].record_output_info(output_name, output_info)
321321

322322

323323
def _get_recurse_directive_depth(field_name, field_directives):
@@ -524,7 +524,7 @@ def _compile_vertex_ast(schema, current_schema_type, ast,
524524
if edge_traversal_is_folded:
525525
has_count_filter = has_fold_count_filter(context)
526526
_validate_fold_has_outputs_or_count_filter(
527-
get_context_fold_info(context), has_count_filter, context['outputs'])
527+
get_context_fold_info(context), has_count_filter, query_metadata_table)
528528
basic_blocks.append(blocks.Unfold())
529529
unmark_context_fold_scope(context)
530530
if has_count_filter:
@@ -560,7 +560,19 @@ def _compile_vertex_ast(schema, current_schema_type, ast,
560560
return basic_blocks
561561

562562

563-
def _validate_fold_has_outputs_or_count_filter(fold_scope_location, fold_has_count_filter, outputs):
563+
def _are_locations_in_same_fold(first_location, second_location):
564+
"""Returns True if locations are contained in the same fold scope."""
565+
return (
566+
isinstance(first_location, FoldScopeLocation) and
567+
isinstance(second_location, FoldScopeLocation) and
568+
first_location.base_location == second_location.base_location and
569+
first_location.get_first_folded_edge() == second_location.get_first_folded_edge()
570+
)
571+
572+
573+
def _validate_fold_has_outputs_or_count_filter(
574+
fold_scope_location, fold_has_count_filter, query_metadata_table
575+
):
564576
"""Ensure the @fold scope has at least one output, or filters on the size of the fold."""
565577
# This function makes sure that the @fold scope has an effect.
566578
# Folds either output data, or filter the data enclosing the fold based on the size of the fold.
@@ -570,8 +582,8 @@ def _validate_fold_has_outputs_or_count_filter(fold_scope_location, fold_has_cou
570582

571583
# At least one output in the outputs list must point to the fold_scope_location,
572584
# or the scope corresponding to fold_scope_location had no @outputs and is illegal.
573-
for output in six.itervalues(outputs):
574-
if output['fold'] == fold_scope_location:
585+
for _, output_info in query_metadata_table.outputs:
586+
if _are_locations_in_same_fold(output_info.location, fold_scope_location):
575587
return True
576588

577589
raise GraphQLCompilationError(u'Found a @fold scope that has no effect on the query. '
@@ -798,13 +810,6 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
798810
# query processing, but apply to the global query scope and should be appended to the
799811
# IR blocks only after the GlobalOperationsStart block has been emitted.
800812
'global_filters': [],
801-
# 'outputs' is a dict mapping each output name to another dict which contains
802-
# - location: Location where to output from
803-
# - optional: boolean representing whether the output was defined within an @optional scope
804-
# - type: GraphQLType of the output
805-
# - fold: FoldScopeLocation object if the current output was defined within a fold scope,
806-
# and None otherwise
807-
'outputs': dict(),
808813
# 'inputs' is a dict mapping input parameter names to their respective expected GraphQL
809814
# types, as automatically inferred by inspecting the query structure
810815
'inputs': dict(),
@@ -853,11 +858,10 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
853858
basic_blocks.extend(context['global_filters'])
854859

855860
# Based on the outputs context data, add an output step and construct the output metadata.
856-
outputs_context = context['outputs']
857-
basic_blocks.append(_compile_output_step(outputs_context))
861+
basic_blocks.append(_compile_output_step(query_metadata_table))
858862
output_metadata = {
859-
name: OutputMetadata(type=value['type'], optional=value['optional'])
860-
for name, value in six.iteritems(outputs_context)
863+
name: OutputMetadata(type=info.type, optional=info.optional)
864+
for name, info in query_metadata_table.outputs
861865
}
862866

863867
return IrAndMetadata(
@@ -867,34 +871,35 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
867871
query_metadata_table=context['metadata'])
868872

869873

870-
def _compile_output_step(outputs):
874+
def _compile_output_step(query_metadata_table):
871875
"""Construct the final ConstructResult basic block that defines the output format of the query.
872876
873877
Args:
874-
outputs: dict, output name (string) -> output data dict, specifying the location
875-
from where to get the data, and whether the data is optional (and therefore
876-
may be missing); missing optional data is replaced with 'null'
878+
query_metadata_table: QueryMetadataTable object, part of which specifies the location from
879+
where to get the output, and whether the output is optional (and
880+
therefore may be missing); missing optional data is replaced with
881+
'null'
877882
878883
Returns:
879884
a ConstructResult basic block that constructs appropriate outputs for the query
880885
"""
881-
if not outputs:
886+
if next(query_metadata_table.outputs, None) is None:
882887
raise GraphQLCompilationError(u'No fields were selected for output! Please mark at least '
883888
u'one field with the @output directive.')
884889

885890
output_fields = {}
886-
for output_name, output_context in six.iteritems(outputs):
887-
location = output_context['location']
888-
optional = output_context['optional']
889-
graphql_type = output_context['type']
891+
for output_name, output_info in query_metadata_table.outputs:
892+
location = output_info.location
893+
optional = output_info.optional
894+
graphql_type = output_info.type
890895

891896
expression = None
892897
existence_check = None
893898
# pylint: disable=redefined-variable-type
894899
if isinstance(location, FoldScopeLocation):
895900
if optional:
896901
raise AssertionError(u'Unreachable state reached, optional in fold: '
897-
u'{}'.format(output_context))
902+
u'{}'.format(output_info))
898903

899904
if location.field == COUNT_META_FIELD_NAME:
900905
expression = expressions.FoldCountContextField(location)

graphql_compiler/compiler/metadata.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@
2424
)
2525

2626

27+
OutputInfo = namedtuple(
28+
'OutputInfo',
29+
(
30+
'location', # Location/FoldScopeLocation, where to output from
31+
'type', # GraphQLType of the output
32+
'optional', # boolean, whether the output was defined within an @optional scope
33+
)
34+
)
35+
2736
TagInfo = namedtuple(
2837
'TagInfo',
2938
(
@@ -161,6 +170,25 @@ def get_location_info(self, location):
161170
u'{}'.format(location))
162171
return location_info
163172

173+
@property
174+
def outputs(self):
175+
"""Return an iterable of (output_name, output_info) tuples for all outputs in the query."""
176+
for output_name, output_info in six.iteritems(self._outputs):
177+
yield output_name, output_info
178+
179+
def record_output_info(self, output_name, output_info):
180+
"""Record information about the output."""
181+
old_info = self._outputs.get(output_name, None)
182+
if old_info is not None:
183+
raise AssertionError(u'Attempting to reuse an already-defined output name {}. '
184+
u'old info {}, new info {}.'
185+
.format(output_name, old_info, output_info))
186+
self._outputs[output_name] = output_info
187+
188+
def get_output_info(self, output_name):
189+
"""Get information about an output."""
190+
return self._outputs.get(output_name, None)
191+
164192
@property
165193
def tags(self):
166194
"""Return an iterable of (tag_name, tag_info) tuples for all tags in the query."""

0 commit comments

Comments
 (0)