|
1 | 1 | # Copyright 2017-present Kensho Technologies, LLC.
|
2 | 2 | """Perform optimizations and lowering of the IR that allows the compiler to emit Gremlin queries.
|
3 | 3 |
|
4 |
| -The compiler IR allows blocks and expressions that cannot be directly compiled to Gremlin or MATCH. |
5 |
| -For example, ContextFieldExistence is an Expression that returns True iff its given vertex exists, |
6 |
| -but the produced Gremlin and MATCH outputs for this purpose are entirely different and not easy |
7 |
| -to generate directly from this Expression object. An output-language-aware IR lowering step allows |
8 |
| -us to convert this Expression into other Expressions, using data already present in the IR, |
9 |
| -to simplify the final code generation step. |
| 4 | +The compiler IR allows blocks and expressions that cannot be directly compiled to the underlying |
| 5 | +database query languages. For example, ContextFieldExistence is an Expression that returns |
| 6 | +True iff its given vertex exists, but the produced Gremlin and MATCH outputs for this purpose |
| 7 | +are entirely different and not easy to generate directly from this Expression object. |
| 8 | +An output-language-aware IR lowering step allows us to convert this Expression into |
| 9 | +other Expressions, using data already present in the IR, to simplify the final code generation step. |
10 | 10 | """
|
11 |
| -from graphql import GraphQLList |
| 11 | +from graphql import GraphQLInt, GraphQLList |
12 | 12 | from graphql.type import GraphQLInterfaceType, GraphQLObjectType, GraphQLUnionType
|
13 | 13 | import six
|
14 | 14 |
|
15 | 15 | from ...exceptions import GraphQLCompilationError
|
16 | 16 | from ...schema import GraphQLDate, GraphQLDateTime
|
17 |
| -from ..blocks import Backtrack, CoerceType, ConstructResult, Filter, MarkLocation, Traverse |
| 17 | +from ..blocks import Backtrack, CoerceType, Filter, GlobalOperationsStart, MarkLocation, Traverse |
18 | 18 | from ..compiler_entities import Expression
|
19 |
| -from ..expressions import BinaryComposition, FoldedContextField, Literal, LocalField, NullLiteral |
| 19 | +from ..expressions import ( |
| 20 | + BinaryComposition, FoldedContextField, Literal, LocalField, NullLiteral, |
| 21 | + make_type_replacement_visitor |
| 22 | +) |
20 | 23 | from ..helpers import (
|
21 | 24 | STANDARD_DATE_FORMAT, STANDARD_DATETIME_FORMAT, FoldScopeLocation,
|
22 | 25 | get_only_element_from_collection, strip_non_null_from_type, validate_safe_string
|
23 | 26 | )
|
24 |
| -from ..ir_lowering_common import extract_folds_from_ir_blocks |
| 27 | +from ..ir_lowering_common.common import extract_folds_from_ir_blocks |
25 | 28 |
|
26 | 29 |
|
27 | 30 | ##################################
|
@@ -148,15 +151,19 @@ def validate(self):
|
148 | 151 | u'Allowed types are {}.'
|
149 | 152 | .format(type(block), self.folded_ir_blocks, allowed_block_types))
|
150 | 153 |
|
151 |
| - if not isinstance(self.field_type, GraphQLList): |
152 |
| - raise ValueError(u'Invalid value of "field_type", expected a list type but got: ' |
153 |
| - u'{}'.format(self.field_type)) |
154 |
| - |
155 |
| - inner_type = strip_non_null_from_type(self.field_type.of_type) |
156 |
| - if isinstance(inner_type, GraphQLList): |
157 |
| - raise GraphQLCompilationError( |
158 |
| - u'Outputting list-valued fields in a @fold context is currently ' |
159 |
| - u'not supported: {} {}'.format(self.fold_scope_location, self.field_type.of_type)) |
| 154 | + bare_field_type = strip_non_null_from_type(self.field_type) |
| 155 | + if isinstance(bare_field_type, GraphQLList): |
| 156 | + inner_type = strip_non_null_from_type(bare_field_type.of_type) |
| 157 | + if isinstance(inner_type, GraphQLList): |
| 158 | + raise GraphQLCompilationError( |
| 159 | + u'Outputting list-valued fields in a @fold context is currently not supported: ' |
| 160 | + u'{} {}'.format(self.fold_scope_location, bare_field_type.of_type)) |
| 161 | + elif GraphQLInt.is_same_type(bare_field_type): |
| 162 | + # This needs to be implemented for @fold _x_count support. |
| 163 | + raise NotImplementedError() |
| 164 | + else: |
| 165 | + raise ValueError(u'Invalid value of "field_type", expected a (possibly non-null) ' |
| 166 | + u'list or int type but got: {}'.format(self.field_type)) |
160 | 167 |
|
161 | 168 | def to_match(self):
|
162 | 169 | """Must never be called."""
|
@@ -316,40 +323,42 @@ def folded_context_visitor(expression):
|
316 | 323 | return new_folded_ir_blocks
|
317 | 324 |
|
318 | 325 |
|
319 |
| -def lower_folded_outputs(ir_blocks): |
320 |
| - """Lower standard folded output fields into GremlinFoldedContextField objects.""" |
| 326 | +def lower_folded_outputs_and_context_fields(ir_blocks): |
| 327 | + """Lower standard folded output / context fields into GremlinFoldedContextField objects.""" |
321 | 328 | folds, remaining_ir_blocks = extract_folds_from_ir_blocks(ir_blocks)
|
322 | 329 |
|
323 | 330 | if not remaining_ir_blocks:
|
324 | 331 | raise AssertionError(u'Expected at least one non-folded block to remain: {} {} '
|
325 | 332 | u'{}'.format(folds, remaining_ir_blocks, ir_blocks))
|
326 |
| - output_block = remaining_ir_blocks[-1] |
327 |
| - if not isinstance(output_block, ConstructResult): |
328 |
| - raise AssertionError(u'Expected the last non-folded block to be ConstructResult, ' |
329 |
| - u'but instead was: {} {} ' |
330 |
| - u'{}'.format(type(output_block), output_block, ir_blocks)) |
331 | 333 |
|
332 | 334 | # Turn folded Filter blocks into GremlinFoldedFilter blocks.
|
333 | 335 | converted_folds = {
|
334 | 336 | base_fold_location.get_location_name()[0]: _convert_folded_blocks(folded_ir_blocks)
|
335 | 337 | for base_fold_location, folded_ir_blocks in six.iteritems(folds)
|
336 | 338 | }
|
337 | 339 |
|
338 |
| - new_output_fields = dict() |
339 |
| - for output_name, output_expression in six.iteritems(output_block.fields): |
340 |
| - new_output_expression = output_expression |
| 340 | + def rewriter_fn(folded_context_field): |
| 341 | + """Rewrite FoldedContextField objects into GremlinFoldedContextField ones.""" |
| 342 | + # Get the matching folded IR blocks and put them in the new context field. |
| 343 | + base_fold_location_name = folded_context_field.fold_scope_location.get_location_name()[0] |
| 344 | + folded_ir_blocks = converted_folds[base_fold_location_name] |
| 345 | + return GremlinFoldedContextField( |
| 346 | + folded_context_field.fold_scope_location, folded_ir_blocks, |
| 347 | + folded_context_field.field_type) |
| 348 | + |
| 349 | + visitor_fn = make_type_replacement_visitor(FoldedContextField, rewriter_fn) |
| 350 | + |
| 351 | + # Start by just appending blocks to the output list. |
| 352 | + new_ir_blocks = [] |
| 353 | + block_collection = new_ir_blocks |
| 354 | + for block in remaining_ir_blocks: |
| 355 | + block_collection.append(block) |
341 | 356 |
|
342 |
| - # Turn FoldedContextField expressions into GremlinFoldedContextField ones. |
343 |
| - if isinstance(output_expression, FoldedContextField): |
344 |
| - # Get the matching folded IR blocks and put them in the new context field. |
345 |
| - base_fold_location_name = output_expression.fold_scope_location.get_location_name()[0] |
346 |
| - folded_ir_blocks = converted_folds[base_fold_location_name] |
347 |
| - new_output_expression = GremlinFoldedContextField( |
348 |
| - output_expression.fold_scope_location, folded_ir_blocks, |
349 |
| - output_expression.field_type) |
| 357 | + if isinstance(block, GlobalOperationsStart): |
| 358 | + # Once we see the GlobalOperationsStart, start accumulating the blocks for rewriting. |
| 359 | + block_collection = [] |
350 | 360 |
|
351 |
| - new_output_fields[output_name] = new_output_expression |
| 361 | + for block in block_collection: |
| 362 | + new_ir_blocks.append(block.visit_and_update_expressions(visitor_fn)) |
352 | 363 |
|
353 |
| - new_ir_blocks = remaining_ir_blocks[:-1] |
354 |
| - new_ir_blocks.append(ConstructResult(new_output_fields)) |
355 | 364 | return new_ir_blocks
|
0 commit comments