Skip to content

Commit 3acc466

Browse files
authored
Refactor and clean up IR lowering code. Add common IR lowering module. (kensho-technologies#324)
* Refactor and clean up IR lowering code. Add common IR lowering operations module. * Delint. * Improve expression comment around null values. * Add defensive strip_non_null_from_type() call.
1 parent ce81ebb commit 3acc466

File tree

13 files changed

+198
-214
lines changed

13 files changed

+198
-214
lines changed

graphql_compiler/compiler/expressions.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ def to_match(self):
205205

206206
def to_gremlin(self):
207207
"""Return a unicode object with the Gremlin representation of this expression."""
208+
self.validate()
209+
208210
# We can't directly pass a Date or a DateTime object, so we have to pass it as a string
209211
# and then parse it inline. For date format parameter meanings, see:
210212
# http://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html
@@ -559,8 +561,9 @@ def to_match(self):
559561
return template % template_data
560562

561563
def to_gremlin(self):
562-
"""Must never be called."""
563-
raise NotImplementedError()
564+
"""Not implemented, should not be used."""
565+
raise AssertionError(u'FoldedContextField are not used during the query emission process '
566+
u'in Gremlin, so this is a bug. This function should not be called.')
564567

565568
def __eq__(self, other):
566569
"""Return True if the given object is equal to this one, and False otherwise."""
@@ -618,7 +621,7 @@ def to_match(self):
618621
return template % template_data
619622

620623
def to_gremlin(self):
621-
"""Must never be called."""
624+
"""Not supported yet."""
622625
raise NotImplementedError()
623626

624627

@@ -799,9 +802,9 @@ def to_match(self):
799802
intersects_operator_format = '(%(operator)s(%(left)s, %(right)s).asList().size() > 0)'
800803
# pylint: enable=unused-variable
801804

802-
# Null literals use 'is/is not' as (in)equality operators, while other values use '=/<>'.
803-
if any((isinstance(self.left, Literal) and self.left.value is None,
804-
isinstance(self.right, Literal) and self.right.value is None)):
805+
# Null literals use the OrientDB 'IS/IS NOT' (in)equality operators,
806+
# while other values use the OrientDB '=/<>' operators.
807+
if self.left == NullLiteral or self.right == NullLiteral:
805808
translation_table = {
806809
u'=': (u'IS', regular_operator_format),
807810
u'!=': (u'IS NOT', regular_operator_format),
@@ -947,6 +950,7 @@ def visitor_fn(expression):
947950
def to_gremlin(self):
948951
"""Return a unicode object with the Gremlin representation of this expression."""
949952
self.validate()
953+
950954
return u'({predicate} ? {if_true} : {if_false})'.format(
951955
predicate=self.predicate.to_gremlin(),
952956
if_true=self.if_true.to_gremlin(),
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright 2019-present Kensho Technologies, LLC.

graphql_compiler/compiler/ir_lowering_common.py renamed to graphql_compiler/compiler/ir_lowering_common/common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
"""Language-independent IR lowering and optimization functions."""
33
import six
44

5-
from .blocks import (
5+
from ..blocks import (
66
ConstructResult, EndOptional, Filter, Fold, MarkLocation, Recurse, Traverse, Unfold
77
)
8-
from .expressions import (
8+
from ..expressions import (
99
BinaryComposition, ContextField, ContextFieldExistence, FalseLiteral, NullLiteral, TrueLiteral
1010
)
11-
from .helpers import validate_safe_string
11+
from ..helpers import validate_safe_string
1212

1313

1414
def merge_consecutive_filter_clauses(ir_blocks):
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2019-present Kensho Technologies, LLC.
2+
"""Utilities for rewriting IR to replace one set of locations with another."""
3+
import six
4+
5+
from ..helpers import FoldScopeLocation, Location
6+
7+
8+
def make_revisit_location_translations(query_metadata_table):
    """Return a dict mapping location revisits to the location being revisited, for rewriting.

    Args:
        query_metadata_table: object exposing "registered_locations", an iterable of
                              (location, info) pairs, and "get_revisit_origin(location)"

    Returns:
        dict of location -> revisit origin, containing only the locations whose revisit origin
        differs from the location itself
    """
    location_translations = {}

    for location, _ in query_metadata_table.registered_locations:
        location_being_revisited = query_metadata_table.get_revisit_origin(location)
        # A location that is its own revisit origin needs no rewriting, so it is left out.
        if location_being_revisited != location:
            location_translations[location] = location_being_revisited

    return location_translations
18+
19+
20+
def translate_potential_location(location_translations, potential_location):
    """Translate the input if it is a Location or FoldScopeLocation, else return it as-is.

    Args:
        location_translations: dict mapping locations to the locations that replace them
        potential_location: any object; only Location and FoldScopeLocation inputs are rewritten

    Returns:
        the translated location, or the unchanged input if it is not a location or
        if no translation is registered for it
    """
    if isinstance(potential_location, Location):
        # Translations are looked up by the vertex-level location, without the field component.
        old_location_at_vertex = potential_location.at_vertex()
        field = potential_location.field

        new_location = location_translations.get(old_location_at_vertex, None)
        if new_location is None:
            # No translation needed.
            return potential_location

        # If necessary, add the field component to the new location before returning it.
        if field is None:
            return new_location
        return new_location.navigate_to_field(field)

    if isinstance(potential_location, FoldScopeLocation):
        # Only the base location of a fold scope is translated; the fold path and field are
        # carried over unchanged. A new FoldScopeLocation is constructed even when
        # the base location has no translation.
        old_base_location = potential_location.base_location
        new_base_location = location_translations.get(old_base_location, old_base_location)
        fold_path = potential_location.fold_path
        fold_field = potential_location.field
        return FoldScopeLocation(new_base_location, fold_path, field=fold_field)

    return potential_location
44+
45+
46+
def make_location_rewriter_visitor_fn(location_translations):
    """Return a visitor function that is able to replace locations with equivalent locations.

    Args:
        location_translations: dict mapping locations to the locations that replace them,
                               passed through to translate_potential_location()

    Returns:
        function (expression) -> expression that rebuilds the given expression with every
        constructor argument passed through translate_potential_location()
    """
    def visitor_fn(expression):
        """Expression visitor function used to rewrite expressions with updated Location data."""
        # All CompilerEntity objects store their exact constructor input args/kwargs.
        # To minimize the chances that we forget to update a location somewhere in an expression,
        # we rewrite all locations that we find as arguments to expression constructors.
        # pylint: disable=protected-access
        new_args = [
            translate_potential_location(location_translations, arg)
            for arg in expression._print_args
        ]
        new_kwargs = {
            kwarg_name: translate_potential_location(location_translations, kwarg_value)
            for kwarg_name, kwarg_value in six.iteritems(expression._print_kwargs)
        }
        # pylint: enable=protected-access

        # Rebuild an object of the same concrete type from the translated constructor inputs.
        expression_cls = type(expression)
        return expression_cls(*new_args, **new_kwargs)

    return visitor_fn

graphql_compiler/compiler/ir_lowering_gremlin/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Copyright 2018-present Kensho Technologies, LLC.
22
from .ir_lowering import (lower_coerce_type_block_type_data, lower_coerce_type_blocks,
3-
lower_folded_outputs, rewrite_filters_in_optional_blocks)
3+
lower_folded_outputs_and_context_fields,
4+
rewrite_filters_in_optional_blocks)
45
from ..ir_sanity_checks import sanity_check_ir_blocks_from_frontend
5-
from ..ir_lowering_common import (lower_context_field_existence, merge_consecutive_filter_clauses,
6-
optimize_boolean_expression_comparisons)
6+
from ..ir_lowering_common.common import (lower_context_field_existence,
7+
merge_consecutive_filter_clauses,
8+
optimize_boolean_expression_comparisons)
79

810

911
##############
@@ -48,6 +50,6 @@ def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None):
4850
ir_blocks = lower_coerce_type_blocks(ir_blocks)
4951
ir_blocks = rewrite_filters_in_optional_blocks(ir_blocks)
5052
ir_blocks = merge_consecutive_filter_clauses(ir_blocks)
51-
ir_blocks = lower_folded_outputs(ir_blocks)
53+
ir_blocks = lower_folded_outputs_and_context_fields(ir_blocks)
5254

5355
return ir_blocks

graphql_compiler/compiler/ir_lowering_gremlin/ir_lowering.py

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,30 @@
11
# Copyright 2017-present Kensho Technologies, LLC.
22
"""Perform optimizations and lowering of the IR that allows the compiler to emit Gremlin queries.
33
4-
The compiler IR allows blocks and expressions that cannot be directly compiled to Gremlin or MATCH.
5-
For example, ContextFieldExistence is an Expression that returns True iff its given vertex exists,
6-
but the produced Gremlin and MATCH outputs for this purpose are entirely different and not easy
7-
to generate directly from this Expression object. An output-language-aware IR lowering step allows
8-
us to convert this Expression into other Expressions, using data already present in the IR,
9-
to simplify the final code generation step.
4+
The compiler IR allows blocks and expressions that cannot be directly compiled to the underlying
5+
database query languages. For example, ContextFieldExistence is an Expression that returns
6+
True iff its given vertex exists, but the produced Gremlin and MATCH outputs for this purpose
7+
are entirely different and not easy to generate directly from this Expression object.
8+
An output-language-aware IR lowering step allows us to convert this Expression into
9+
other Expressions, using data already present in the IR, to simplify the final code generation step.
1010
"""
11-
from graphql import GraphQLList
11+
from graphql import GraphQLInt, GraphQLList
1212
from graphql.type import GraphQLInterfaceType, GraphQLObjectType, GraphQLUnionType
1313
import six
1414

1515
from ...exceptions import GraphQLCompilationError
1616
from ...schema import GraphQLDate, GraphQLDateTime
17-
from ..blocks import Backtrack, CoerceType, ConstructResult, Filter, MarkLocation, Traverse
17+
from ..blocks import Backtrack, CoerceType, Filter, GlobalOperationsStart, MarkLocation, Traverse
1818
from ..compiler_entities import Expression
19-
from ..expressions import BinaryComposition, FoldedContextField, Literal, LocalField, NullLiteral
19+
from ..expressions import (
20+
BinaryComposition, FoldedContextField, Literal, LocalField, NullLiteral,
21+
make_type_replacement_visitor
22+
)
2023
from ..helpers import (
2124
STANDARD_DATE_FORMAT, STANDARD_DATETIME_FORMAT, FoldScopeLocation,
2225
get_only_element_from_collection, strip_non_null_from_type, validate_safe_string
2326
)
24-
from ..ir_lowering_common import extract_folds_from_ir_blocks
27+
from ..ir_lowering_common.common import extract_folds_from_ir_blocks
2528

2629

2730
##################################
@@ -148,15 +151,19 @@ def validate(self):
148151
u'Allowed types are {}.'
149152
.format(type(block), self.folded_ir_blocks, allowed_block_types))
150153

151-
if not isinstance(self.field_type, GraphQLList):
152-
raise ValueError(u'Invalid value of "field_type", expected a list type but got: '
153-
u'{}'.format(self.field_type))
154-
155-
inner_type = strip_non_null_from_type(self.field_type.of_type)
156-
if isinstance(inner_type, GraphQLList):
157-
raise GraphQLCompilationError(
158-
u'Outputting list-valued fields in a @fold context is currently '
159-
u'not supported: {} {}'.format(self.fold_scope_location, self.field_type.of_type))
154+
bare_field_type = strip_non_null_from_type(self.field_type)
155+
if isinstance(bare_field_type, GraphQLList):
156+
inner_type = strip_non_null_from_type(bare_field_type.of_type)
157+
if isinstance(inner_type, GraphQLList):
158+
raise GraphQLCompilationError(
159+
u'Outputting list-valued fields in a @fold context is currently not supported: '
160+
u'{} {}'.format(self.fold_scope_location, bare_field_type.of_type))
161+
elif GraphQLInt.is_same_type(bare_field_type):
162+
# This needs to be implemented for @fold _x_count support.
163+
raise NotImplementedError()
164+
else:
165+
raise ValueError(u'Invalid value of "field_type", expected a (possibly non-null) '
166+
u'list or int type but got: {}'.format(self.field_type))
160167

161168
def to_match(self):
162169
"""Must never be called."""
@@ -316,40 +323,42 @@ def folded_context_visitor(expression):
316323
return new_folded_ir_blocks
317324

318325

319-
def lower_folded_outputs_and_context_fields(ir_blocks):
    """Lower standard folded output / context fields into GremlinFoldedContextField objects.

    Args:
        ir_blocks: list of IR blocks to lower

    Returns:
        list of the IR blocks that remain after extract_folds_from_ir_blocks(), where every
        FoldedContextField in the blocks following the GlobalOperationsStart block is replaced
        with a GremlinFoldedContextField holding the converted blocks of its fold

    Raises:
        AssertionError if no blocks remain after the folds are extracted
    """
    folds, remaining_ir_blocks = extract_folds_from_ir_blocks(ir_blocks)

    if not remaining_ir_blocks:
        raise AssertionError(u'Expected at least one non-folded block to remain: {} {} '
                             u'{}'.format(folds, remaining_ir_blocks, ir_blocks))

    # Turn folded Filter blocks into GremlinFoldedFilter blocks.
    converted_folds = {
        base_fold_location.get_location_name()[0]: _convert_folded_blocks(folded_ir_blocks)
        for base_fold_location, folded_ir_blocks in six.iteritems(folds)
    }

    def rewriter_fn(folded_context_field):
        """Rewrite FoldedContextField objects into GremlinFoldedContextField ones."""
        # Get the matching folded IR blocks and put them in the new context field.
        base_fold_location_name = folded_context_field.fold_scope_location.get_location_name()[0]
        folded_ir_blocks = converted_folds[base_fold_location_name]
        return GremlinFoldedContextField(
            folded_context_field.fold_scope_location, folded_ir_blocks,
            folded_context_field.field_type)

    visitor_fn = make_type_replacement_visitor(FoldedContextField, rewriter_fn)

    # Blocks up to and including GlobalOperationsStart are copied over unchanged; every block
    # after it has its expressions rewritten. This is done in a single pass with a flag,
    # rather than by re-pointing an accumulator list: if no GlobalOperationsStart block
    # is present, such an accumulator would still alias the output list, and rewriting
    # "its" blocks would append to the list being iterated, never terminating.
    new_ir_blocks = []
    seen_global_operations_start = False
    for block in remaining_ir_blocks:
        if seen_global_operations_start:
            new_ir_blocks.append(block.visit_and_update_expressions(visitor_fn))
        else:
            new_ir_blocks.append(block)
            if isinstance(block, GlobalOperationsStart):
                seen_global_operations_start = True

    return new_ir_blocks

graphql_compiler/compiler/ir_lowering_match/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import six
33

44
from ..blocks import Filter
5-
from ..ir_lowering_common import (extract_optional_location_root_info,
6-
extract_simple_optional_location_info,
7-
lower_context_field_existence, merge_consecutive_filter_clauses,
8-
optimize_boolean_expression_comparisons, remove_end_optionals)
5+
from ..ir_lowering_common.common import (extract_optional_location_root_info,
6+
extract_simple_optional_location_info,
7+
lower_context_field_existence,
8+
merge_consecutive_filter_clauses,
9+
optimize_boolean_expression_comparisons,
10+
remove_end_optionals)
911
from .ir_lowering import (lower_backtrack_blocks,
1012
lower_folded_coerce_types_into_filter_blocks,
1113
lower_has_substring_binary_compositions,
@@ -100,7 +102,7 @@ def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None):
100102

101103
match_query = lower_comparisons_to_between(match_query)
102104

103-
match_query = lower_backtrack_blocks(match_query, location_types)
105+
match_query = lower_backtrack_blocks(match_query, query_metadata_table)
104106
match_query = truncate_repeated_single_step_traversals(match_query)
105107
match_query = orientdb_class_with_while.workaround_type_coercions_in_recursions(match_query)
106108

0 commit comments

Comments
 (0)