Skip to content

Commit f6079c6

Browse files
authored
Include location type in ContextField and related expressions. (kensho-technologies#229)
* Include location type in ContextField and related types.
* Fix the type check in the equality comparison of compiler entities.
1 parent 9251a83 commit f6079c6

File tree

14 files changed

+328
-271
lines changed

14 files changed

+328
-271
lines changed

graphql_compiler/compiler/compiler_entities.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from abc import ABCMeta, abstractmethod
55

6+
from graphql import is_type
67
import six
78

89

@@ -43,9 +44,23 @@ def __repr__(self):
4344
# pylint: disable=protected-access
4445
def __eq__(self, other):
4546
"""Return True if the CompilerEntity objects are equal, and False otherwise."""
46-
return (type(self) == type(other) and
47-
self._print_args == other._print_args and
48-
self._print_kwargs == other._print_kwargs)
47+
if type(self) != type(other):
48+
return False
49+
50+
if len(self._print_args) != len(other._print_args):
51+
return False
52+
53+
# The args sometimes contain GraphQL type objects, which unfortunately do not define "==".
54+
# We have to split them out and compare them using "is_same_type()" instead.
55+
for self_arg, other_arg in six.moves.zip(self._print_args, other._print_args):
56+
if is_type(self_arg):
57+
if not self_arg.is_same_type(other_arg):
58+
return False
59+
else:
60+
if self_arg != other_arg:
61+
return False
62+
63+
return self._print_kwargs == other._print_kwargs
4964
# pylint: enable=protected-access
5065

5166
def __ne__(self, other):

graphql_compiler/compiler/compiler_frontend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,8 @@ def _compile_ast_node_to_ir(schema, current_schema_type, ast, location, context)
688688

689689
visitor_fn = expressions.make_type_replacement_visitor(
690690
expressions.ContextField,
691-
lambda context_field: expressions.GlobalContextField(context_field.location))
691+
lambda context_field: expressions.GlobalContextField(
692+
context_field.location, context_field.field_type))
692693
filter_block = filter_block.visit_and_update_expressions(visitor_fn)
693694

694695
set_fold_count_filter(context)

graphql_compiler/compiler/expressions.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@ def to_gremlin(self):
270270
class GlobalContextField(Expression):
271271
"""A field drawn from the global context, for use in a global operations WHERE statement."""
272272

273-
__slots__ = ('location',)
273+
__slots__ = ('location', 'field_type')
274274

275-
def __init__(self, location):
275+
def __init__(self, location, field_type):
276276
"""Construct a new GlobalContextField object that references a field at a given location.
277277
278278
Args:
@@ -281,8 +281,9 @@ def __init__(self, location):
281281
Returns:
282282
new GlobalContextField object
283283
"""
284-
super(GlobalContextField, self).__init__(location)
284+
super(GlobalContextField, self).__init__(location, field_type)
285285
self.location = location
286+
self.field_type = field_type
286287
self.validate()
287288

288289
def validate(self):
@@ -295,6 +296,9 @@ def validate(self):
295296
raise AssertionError(u'Received Location without a field: {}'
296297
.format(self.location))
297298

299+
if not is_graphql_type(self.field_type):
300+
raise ValueError(u'Invalid value of "field_type": {}'.format(self.field_type))
301+
298302
def to_match(self):
299303
"""Return a unicode object with the MATCH representation of this GlobalContextField."""
300304
self.validate()
@@ -314,22 +318,24 @@ def to_gremlin(self):
314318
class ContextField(Expression):
315319
"""A field drawn from the global context, e.g. if selected earlier in the query."""
316320

317-
__slots__ = ('location',)
321+
__slots__ = ('location', 'field_type')
318322

319-
def __init__(self, location):
323+
def __init__(self, location, field_type):
320324
"""Construct a new ContextField object that references a field from the global context.
321325
322326
Args:
323327
location: Location, specifying where the field was declared. If the Location points
324328
to a vertex, the field refers to the data captured at the location vertex.
325329
Otherwise, if the Location points to a property, the field refers to
326330
the particular value of that property.
331+
field_type: GraphQLType object, specifying the type of the data at the given location
327332
328333
Returns:
329334
new ContextField object
330335
"""
331-
super(ContextField, self).__init__(location)
336+
super(ContextField, self).__init__(location, field_type)
332337
self.location = location
338+
self.field_type = field_type
333339
self.validate()
334340

335341
def validate(self):
@@ -338,6 +344,9 @@ def validate(self):
338344
raise TypeError(u'Expected Location location, got: {} {}'.format(
339345
type(self.location).__name__, self.location))
340346

347+
if not is_graphql_type(self.field_type):
348+
raise ValueError(u'Invalid value of "field_type": {}'.format(self.field_type))
349+
341350
def to_match(self):
342351
"""Return a unicode object with the MATCH representation of this ContextField."""
343352
self.validate()

graphql_compiler/compiler/filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _represent_argument(directive_location, context, argument, inferred_type):
148148
if field_is_local:
149149
representation = expressions.LocalField(argument_name)
150150
else:
151-
representation = expressions.ContextField(location)
151+
representation = expressions.ContextField(location, tag_inferred_type)
152152

153153
return (representation, non_existence_expression)
154154
else:

graphql_compiler/compiler/ir_lowering_common.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,30 +53,34 @@ def to_match(self):
5353
return mark_name
5454

5555

56-
def lower_context_field_existence(ir_blocks):
56+
def lower_context_field_existence(ir_blocks, query_metadata_table):
5757
"""Lower ContextFieldExistence expressions into lower-level expressions."""
5858
def regular_visitor_fn(expression):
5959
"""Expression visitor function that rewrites ContextFieldExistence expressions."""
6060
if not isinstance(expression, ContextFieldExistence):
6161
return expression
6262

63+
location_type = query_metadata_table.get_location_info(expression.location).type
64+
6365
# Since this function is only used in blocks that aren't ConstructResult,
6466
# the location check is performed using a regular ContextField expression.
6567
return BinaryComposition(
6668
u'!=',
67-
ContextField(expression.location),
69+
ContextField(expression.location, location_type),
6870
NullLiteral)
6971

7072
def construct_result_visitor_fn(expression):
7173
"""Expression visitor function that rewrites ContextFieldExistence expressions."""
7274
if not isinstance(expression, ContextFieldExistence):
7375
return expression
7476

77+
location_type = query_metadata_table.get_location_info(expression.location).type
78+
7579
# Since this function is only used in ConstructResult blocks,
7680
# the location check is performed using the special OutputContextVertex expression.
7781
return BinaryComposition(
7882
u'!=',
79-
OutputContextVertex(expression.location),
83+
OutputContextVertex(expression.location, location_type),
8084
NullLiteral)
8185

8286
new_ir_blocks = []

graphql_compiler/compiler/ir_lowering_gremlin/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None):
3939
"""
4040
sanity_check_ir_blocks_from_frontend(ir_blocks, query_metadata_table)
4141

42-
ir_blocks = lower_context_field_existence(ir_blocks)
42+
ir_blocks = lower_context_field_existence(ir_blocks, query_metadata_table)
4343
ir_blocks = optimize_boolean_expression_comparisons(ir_blocks)
4444

4545
if type_equivalence_hints:

graphql_compiler/compiler/ir_lowering_match/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,18 @@ def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None):
8080
# Append global operation block(s) to filter out incorrect results
8181
# from simple optional match traverses (using a WHERE statement)
8282
if len(simple_optional_root_info) > 0:
83-
where_filter_predicate = construct_where_filter_predicate(simple_optional_root_info)
83+
where_filter_predicate = construct_where_filter_predicate(
84+
query_metadata_table, simple_optional_root_info)
8485
ir_blocks.insert(-1, GlobalOperationsStart())
8586
ir_blocks.insert(-1, Filter(where_filter_predicate))
8687

8788
# These lowering / optimization passes work on IR blocks.
88-
ir_blocks = lower_context_field_existence(ir_blocks)
89+
ir_blocks = lower_context_field_existence(ir_blocks, query_metadata_table)
8990
ir_blocks = optimize_boolean_expression_comparisons(ir_blocks)
9091
ir_blocks = rewrite_binary_composition_inside_ternary_conditional(ir_blocks)
9192
ir_blocks = merge_consecutive_filter_clauses(ir_blocks)
9293
ir_blocks = lower_has_substring_binary_compositions(ir_blocks)
93-
ir_blocks = orientdb_eval_scheduling.workaround_lowering_pass(ir_blocks)
94+
ir_blocks = orientdb_eval_scheduling.workaround_lowering_pass(ir_blocks, query_metadata_table)
9495

9596
# Here, we lower from raw IR blocks into a MatchQuery object.
9697
# From this point on, the lowering / optimization passes work on the MatchQuery representation.

graphql_compiler/compiler/ir_lowering_match/ir_lowering.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..blocks import Backtrack, CoerceType, MarkLocation, QueryRoot
1414
from ..expressions import (
1515
BinaryComposition, ContextField, ContextFieldExistence, FalseLiteral, FoldedContextField,
16-
Literal, TernaryConditional, TrueLiteral
16+
GlobalContextField, Literal, TernaryConditional, TrueLiteral
1717
)
1818
from ..helpers import FoldScopeLocation
1919
from .utils import convert_coerce_type_to_instanceof_filter
@@ -254,17 +254,24 @@ def _translate_equivalent_locations(match_query, location_translations):
254254

255255
def visitor_fn(expression):
256256
"""Expression visitor function used to rewrite expressions with updated Location data."""
257-
if isinstance(expression, (ContextField, ContextFieldExistence)):
258-
old_location = expression.location
257+
if isinstance(expression, (ContextField, GlobalContextField)):
258+
old_location = expression.location.at_vertex()
259259
new_location = location_translations.get(old_location, old_location)
260+
if expression.location.field is not None:
261+
new_location = new_location.navigate_to_field(expression.location.field)
260262

261263
# The Expression could be one of many types, including:
262264
# - ContextField
263-
# - ContextFieldExistence
265+
# - GlobalContextField
264266
# We determine its exact class to make sure we return an object of the same class
265-
# as the replacement expression.
267+
# as the expression being replaced.
266268
expression_cls = type(expression)
267-
return expression_cls(new_location)
269+
return expression_cls(new_location, expression.field_type)
270+
elif isinstance(expression, ContextFieldExistence):
271+
old_location = expression.location
272+
new_location = location_translations.get(old_location, old_location)
273+
274+
return ContextFieldExistence(new_location)
268275
elif isinstance(expression, FoldedContextField):
269276
# Update the Location within FoldedContextField
270277
old_location = expression.fold_scope_location.base_location
@@ -322,8 +329,13 @@ def visitor_fn(expression):
322329
# Rewrite the Locations in the ConstructResult output block.
323330
new_output_block = match_query.output_block.visit_and_update_expressions(visitor_fn)
324331

332+
# Rewrite the Locations in the global where block.
333+
new_where_block = None
334+
if match_query.where_block is not None:
335+
new_where_block = match_query.where_block.visit_and_update_expressions(visitor_fn)
336+
325337
return match_query._replace(match_traversals=new_match_traversals, folds=new_folds,
326-
output_block=new_output_block)
338+
output_block=new_output_block, where_block=new_where_block)
327339

328340

329341
def lower_folded_coerce_types_into_filter_blocks(folded_ir_blocks):

graphql_compiler/compiler/ir_lowering_match/utils.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
BinaryComposition, Expression, GlobalContextField, Literal, LocalField, NullLiteral,
1010
TrueLiteral, UnaryTransformation, ZeroLiteral
1111
)
12-
from ..helpers import Location, get_only_element_from_collection, is_vertex_field_name
12+
from ..helpers import get_only_element_from_collection, is_vertex_field_name
1313

1414

1515
def convert_coerce_type_to_instanceof_filter(coerce_type_block):
@@ -134,7 +134,8 @@ def filter_edge_field_non_existence(edge_expression):
134134
return BinaryComposition(u'||', field_null_check, field_size_check)
135135

136136

137-
def _filter_orientdb_simple_optional_edge(optional_edge_location, inner_location_name):
137+
def _filter_orientdb_simple_optional_edge(
138+
query_metadata_table, optional_edge_location, inner_location_name):
138139
"""Return an Expression that is False for rows that don't follow the @optional specification.
139140
140141
OrientDB does not filter correctly within optionals. Namely, a result where the optional edge
@@ -162,6 +163,9 @@ def _filter_orientdb_simple_optional_edge(optional_edge_location, inner_location
162163
Here, the `optional_edge_location` is `Animal___1.out_Animal_ParentOf`.
163164
164165
Args:
166+
query_metadata_table: QueryMetadataTable object containing all metadata collected during
167+
query processing, including location metadata (e.g. which locations
168+
are folded or optional).
165169
optional_edge_location: Location object representing the optional edge field
166170
inner_location_name: string representing location within the corresponding optional traverse
167171
@@ -171,13 +175,21 @@ def _filter_orientdb_simple_optional_edge(optional_edge_location, inner_location
171175
inner_local_field = LocalField(inner_location_name)
172176
inner_location_existence = BinaryComposition(u'!=', inner_local_field, NullLiteral)
173177

174-
edge_context_field = GlobalContextField(optional_edge_location)
178+
# The optional_edge_location here is actually referring to the edge field itself.
179+
# This is definitely non-standard, but required to get the proper semantics.
180+
# To get its type, we construct the location of the vertex field on the other side of the edge.
181+
vertex_location = (
182+
optional_edge_location.at_vertex().navigate_to_subpath(optional_edge_location.field)
183+
)
184+
location_type = query_metadata_table.get_location_info(vertex_location).type
185+
186+
edge_context_field = GlobalContextField(optional_edge_location, location_type)
175187
edge_field_non_existence = filter_edge_field_non_existence(edge_context_field)
176188

177189
return BinaryComposition(u'||', edge_field_non_existence, inner_location_existence)
178190

179191

180-
def construct_where_filter_predicate(simple_optional_root_info):
192+
def construct_where_filter_predicate(query_metadata_table, simple_optional_root_info):
181193
"""Return an Expression that is True if and only if each simple optional filter is True.
182194
183195
Construct filters for each simple optional, that are True if and only if `edge_field` does
@@ -186,6 +198,9 @@ def construct_where_filter_predicate(simple_optional_root_info):
186198
evaluate to True (conjunction).
187199
188200
Args:
201+
query_metadata_table: QueryMetadataTable object containing all metadata collected during
202+
query processing, including location metadata (e.g. which locations
203+
are folded or optional).
189204
simple_optional_root_info: dict mapping from simple_optional_root_location -> dict
190205
containing keys
191206
- 'inner_location_name': Location object corresponding to the
@@ -204,9 +219,9 @@ def construct_where_filter_predicate(simple_optional_root_info):
204219
inner_location_name = root_info_dict['inner_location_name']
205220
edge_field = root_info_dict['edge_field']
206221

207-
optional_edge_location = Location(root_location.query_path, field=edge_field)
222+
optional_edge_location = root_location.navigate_to_field(edge_field)
208223
optional_edge_where_filter = _filter_orientdb_simple_optional_edge(
209-
optional_edge_location, inner_location_name)
224+
query_metadata_table, optional_edge_location, inner_location_name)
210225
inner_location_name_to_where_filter[inner_location_name] = optional_edge_where_filter
211226

212227
# Sort expressions by inner_location_name to obtain deterministic order

graphql_compiler/compiler/workarounds/orientdb_eval_scheduling.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@
1616
)
1717

1818

19-
def workaround_lowering_pass(ir_blocks):
19+
def workaround_lowering_pass(ir_blocks, query_metadata_table):
2020
"""Extract locations from TernaryConditionals and rewrite their Filter blocks as necessary."""
2121
new_ir_blocks = []
2222

2323
for block in ir_blocks:
2424
if isinstance(block, Filter):
25-
new_block = _process_filter_block(block)
25+
new_block = _process_filter_block(query_metadata_table, block)
2626
else:
2727
new_block = block
2828
new_ir_blocks.append(new_block)
2929

3030
return new_ir_blocks
3131

3232

33-
def _process_filter_block(block):
33+
def _process_filter_block(query_metadata_table, block):
3434
"""Rewrite the provided Filter block if necessary."""
3535
# For a given Filter block with BinaryComposition predicate expression X,
3636
# let L be the set of all Locations referenced in any TernaryConditional
@@ -85,7 +85,7 @@ def extract_locations_visitor(expression):
8585
u'{} {}'.format(ternary, return_value))
8686

8787
tautologies = [
88-
_create_tautological_expression_for_location(location)
88+
_create_tautological_expression_for_location(query_metadata_table, location)
8989
for location in problematic_locations
9090
]
9191

@@ -98,8 +98,12 @@ def extract_locations_visitor(expression):
9898
return Filter(final_predicate)
9999

100100

101-
def _create_tautological_expression_for_location(location):
101+
def _create_tautological_expression_for_location(query_metadata_table, location):
102102
"""For a given location, create a BinaryComposition that always evaluates to 'true'."""
103-
location_exists = BinaryComposition(u'!=', ContextField(location), NullLiteral)
104-
location_does_not_exist = BinaryComposition(u'=', ContextField(location), NullLiteral)
103+
location_type = query_metadata_table.get_location_info(location).type
104+
105+
location_exists = BinaryComposition(
106+
u'!=', ContextField(location, location_type), NullLiteral)
107+
location_does_not_exist = BinaryComposition(
108+
u'=', ContextField(location, location_type), NullLiteral)
105109
return BinaryComposition(u'||', location_exists, location_does_not_exist)

0 commit comments

Comments
 (0)