Skip to content

Commit eebc64d

Browse files
authored
Merge pull request kensho-technologies#166 from kensho-technologies/add-sql-filters
Adding support for Filter blocks to SQL backend
2 parents 432ac3b + ba0dd02 commit eebc64d

File tree

9 files changed

+353
-32
lines changed

9 files changed

+353
-32
lines changed

graphql_compiler/compiler/emit_sql.py

Lines changed: 184 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
"""Transform a SqlNode tree into an executable SQLAlchemy query."""
33
from collections import namedtuple
44

5-
from sqlalchemy import select
5+
from sqlalchemy import Column, bindparam, select
6+
from sqlalchemy.sql import expression as sql_expressions
7+
from sqlalchemy.sql.elements import BindParameter, and_
68

79
from . import sql_context_helpers
10+
from ..compiler import expressions
11+
from ..compiler.ir_lowering_sql import constants
812

913

1014
# The compilation context holds state that changes during compilation as the tree is traversed
@@ -20,6 +24,12 @@
2024
# renamed status. This tuple is used to construct the query outputs, and track when a name
2125
# changes due to collapsing into a CTE.
2226
'query_path_to_output_fields',
27+
# 'query_path_to_filters': Dict[Tuple[str, ...], List[Filter]], mapping from each query_path
28+
# to the Filter blocks that apply to that query path
29+
'query_path_to_filters',
30+
# 'query_path_to_node': Dict[Tuple[str, ...], SqlNode], mapping from each
31+
# query_path to the SqlNode located at that query_path.
32+
'query_path_to_node',
2333
# 'compiler_metadata': CompilerMetadata, SQLAlchemy metadata about Table objects, and
2434
# further backend specific configuration.
2535
'compiler_metadata',
@@ -40,6 +50,8 @@ def emit_code_from_ir(sql_query_tree, compiler_metadata):
4050
query_path_to_selectable=dict(),
4151
query_path_to_location_info=sql_query_tree.query_path_to_location_info,
4252
query_path_to_output_fields=sql_query_tree.query_path_to_output_fields,
53+
query_path_to_filters=sql_query_tree.query_path_to_filters,
54+
query_path_to_node=sql_query_tree.query_path_to_node,
4355
compiler_metadata=compiler_metadata,
4456
)
4557

@@ -88,27 +100,187 @@ def _create_query(node, context):
88100
Returns:
89101
Selectable, selectable of the generated query.
90102
"""
91-
output_columns = _get_output_columns(node, context)
103+
visited_nodes = [node]
104+
output_columns = _get_output_columns(visited_nodes, context)
105+
filters = _get_filters(visited_nodes, context)
92106
selectable = sql_context_helpers.get_node_selectable(node, context)
93-
query = select(output_columns).select_from(selectable)
107+
query = select(output_columns).select_from(selectable).where(and_(*filters))
94108
return query
95109

96110

97-
def _get_output_columns(node, context):
98-
"""Get the output columns required by the query.
111+
def _get_output_columns(nodes, context):
112+
"""Get the output columns for a list of SqlNodes.
99113
100114
Args:
101-
node: SqlNode, the current node.
115+
nodes: List[SqlNode], the nodes to get output columns from.
102116
context: CompilationContext, global compilation state and metadata.
103117
104118
Returns:
105119
List[Column], list of SqlAlchemy Columns to output for this query.
106120
"""
107-
sql_outputs = context.query_path_to_output_fields[node.query_path]
108121
columns = []
109-
for sql_output in sql_outputs:
110-
field_name = sql_output.field_name
111-
column = sql_context_helpers.get_column(field_name, node, context)
112-
column = column.label(sql_output.output_name)
113-
columns.append(column)
122+
for node in nodes:
123+
for sql_output in sql_context_helpers.get_outputs(node, context):
124+
field_name = sql_output.field_name
125+
column = sql_context_helpers.get_column(field_name, node, context)
126+
column = column.label(sql_output.output_name)
127+
columns.append(column)
114128
return columns
129+
130+
131+
def _get_filters(nodes, context):
132+
"""Get filters to apply to a list of SqlNodes.
133+
134+
Args:
135+
nodes: List[SqlNode], the SqlNodes to get filters for.
136+
context: CompilationContext, global compilation state and metadata.
137+
138+
Returns:
139+
List[Expression], list of SQLAlchemy expressions.
140+
"""
141+
filters = []
142+
for node in nodes:
143+
for filter_block in sql_context_helpers.get_filters(node, context):
144+
filter_sql_expression = _transform_filter_to_sql(filter_block, node, context)
145+
filters.append(filter_sql_expression)
146+
return filters
147+
148+
149+
def _transform_filter_to_sql(filter_block, node, context):
    """Convert a single Filter block into the equivalent SQLAlchemy expression.

    Only the Filter's predicate is inspected; the conversion itself is delegated to
    _expression_to_sql.

    Args:
        filter_block: Filter, the Filter block to convert.
        node: SqlNode, the node the Filter block applies to.
        context: CompilationContext, global compilation state and metadata.

    Returns:
        Expression, SQLAlchemy expression equivalent to filter_block.predicate.
    """
    return _expression_to_sql(filter_block.predicate, node, context)
162+
163+
164+
def _expression_to_sql(expression, node, context):
    """Dispatch a compiler expression to the transformer that converts it to SQLAlchemy.

    The lookup uses the exact class of the expression, so subclasses of the supported
    expression types are not matched.

    Args:
        expression: expression, the compiler expression to transform.
        node: SqlNode, the SqlNode the expression applies to.
        context: CompilationContext, global compilation state and metadata.

    Returns:
        Expression, SQLAlchemy Expression equivalent to the passed compiler expression.

    Raises:
        NotImplementedError: if the expression's type has no registered transformer.
    """
    transformers_by_type = {
        expressions.LocalField: _transform_local_field_to_expression,
        expressions.Variable: _transform_variable_to_expression,
        expressions.Literal: _transform_literal_to_expression,
        expressions.BinaryComposition: _transform_binary_composition_to_expression,
    }
    transformer = transformers_by_type.get(type(expression))
    if transformer is None:
        raise NotImplementedError(
            u'Unsupported compiler expression "{}" of type "{}" cannot be converted to SQL '
            u'expression.'.format(expression, type(expression)))
    return transformer(expression, node, context)
187+
188+
189+
def _transform_binary_composition_to_expression(expression, node, context):
    """Convert a BinaryComposition compiler expression into a SQLAlchemy expression.

    Both operands are converted recursively through _expression_to_sql before the operator
    is applied. How the operator is applied depends on its cardinality as declared in
    constants.SUPPORTED_OPERATORS.

    Args:
        expression: expression, BinaryComposition compiler expression.
        node: SqlNode, the SqlNode the expression applies to.
        context: CompilationContext, global compilation state and metadata.

    Returns:
        Expression, SQLAlchemy expression.

    Raises:
        NotImplementedError: if the operator is not in constants.SUPPORTED_OPERATORS.
        AssertionError: if the operator's cardinality is not one of the known constants.
    """
    if expression.operator not in constants.SUPPORTED_OPERATORS:
        raise NotImplementedError(
            u'Filter operation "{}" is not supported by the SQL backend.'.format(
                expression.operator))

    sql_operator = constants.SUPPORTED_OPERATORS[expression.operator]
    left = _expression_to_sql(expression.left, node, context)
    right = _expression_to_sql(expression.right, node, context)
    cardinality = sql_operator.cardinality

    if cardinality == constants.CARDINALITY_BINARY:
        # Binary operators are looked up as functions on sqlalchemy.sql.expression.
        combine = getattr(sql_expressions, sql_operator.name)
        return combine(left, right)

    if cardinality in (constants.CARDINALITY_UNARY, constants.CARDINALITY_LIST_VALUED):
        # These operators are looked up as methods on the Column operand and are called
        # with the BindParameter operand.
        column, parameter = _get_column_and_bindparam(left, right, sql_operator)
        if cardinality == constants.CARDINALITY_LIST_VALUED:
            # ensure that SQLAlchemy treats the bind parameter as list valued
            parameter.expanding = True
        return getattr(column, sql_operator.name)(parameter)

    raise AssertionError(u'Unreachable, operator cardinality {} for compiler expression {} is '
                         u'unknown'.format(sql_operator.cardinality, expression))
224+
225+
226+
def _get_column_and_bindparam(left, right, operator):
    """Return the two operands in (Column, BindParameter) order.

    If the left operand is not a Column the operands are swapped once, so the caller may
    pass them in either order.

    Args:
        left: one operand of the expression, expected to be a Column or a BindParameter.
        right: the other operand of the expression.
        operator: the operator being applied; only used to build error messages.

    Returns:
        Tuple[Column, BindParameter], the operands with the Column first.

    Raises:
        AssertionError: if, after the optional swap, the first operand is not a Column or
                        the second operand is not a BindParameter.
    """
    if not isinstance(left, Column):
        left, right = right, left
    if not isinstance(left, Column):
        raise AssertionError(
            u'SQLAlchemy operator {} expects Column as the left side of the expression, got {} '
            u'of type {} instead.'.format(operator, left, type(left)))
    if not isinstance(right, BindParameter):
        raise AssertionError(
            u'SQLAlchemy operator {} expects BindParameter as the right side of the expression, '
            u'got {} of type {} instead.'.format(operator, right, type(right)))
    return left, right
239+
240+
241+
def _transform_literal_to_expression(expression, node, context):
242+
"""Transform a Literal compiler expression into its SQLAlchemy expression representation.
243+
244+
Args:
245+
expression: expression, Literal compiler expression.
246+
node: SqlNode, the SqlNode the expression applies to.
247+
context: CompilationContext, global compilation state and metadata.
248+
249+
Returns:
250+
Expression, SQLAlchemy expression.
251+
"""
252+
return expression.value
253+
254+
255+
def _transform_variable_to_expression(expression, node, context):
256+
"""Transform a Variable compiler expression into its SQLAlchemy expression representation.
257+
258+
Args:
259+
expression: expression, Variable compiler expression.
260+
node: SqlNode, the SqlNode the expression applies to.
261+
context: CompilationContext, global compilation state and metadata.
262+
263+
Returns:
264+
Expression, SQLAlchemy expression.
265+
"""
266+
variable_name = expression.variable_name
267+
if not variable_name.startswith(u'$'):
268+
raise AssertionError(u'Unexpectedly received variable name {} that is not '
269+
u'prefixed with "$"'.format(variable_name))
270+
return bindparam(variable_name[1:])
271+
272+
273+
def _transform_local_field_to_expression(expression, node, context):
    """Resolve a LocalField compiler expression to a column of the given node.

    Args:
        expression: expression, LocalField compiler expression.
        node: SqlNode, the SqlNode the expression applies to.
        context: CompilationContext, global compilation state and metadata.

    Returns:
        Expression, the result of sql_context_helpers.get_column for the field's name.
    """
    return sql_context_helpers.get_column(expression.field_name, node, context)

graphql_compiler/compiler/ir_lowering_sql/__init__.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .sql_tree import SqlNode, SqlQueryTree
66
from .. import blocks
7+
from ...compiler import expressions
78
from ...compiler.helpers import Location
89
from ..ir_lowering_sql import constants
910
from ..metadata import LocationInfo
@@ -45,10 +46,14 @@ def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None):
4546
query_path_to_location_info = _map_query_path_to_location_info(query_metadata_table)
4647
query_path_to_output_fields = _map_query_path_to_outputs(
4748
construct_result, query_path_to_location_info)
49+
block_index_to_location = _map_block_index_to_location(ir_blocks)
50+
51+
# perform lowering steps
52+
ir_blocks = lower_unary_transformations(ir_blocks)
4853

4954
# iteratively construct SqlTree
5055
query_path_to_node = {}
51-
block_index_to_location = _map_block_index_to_location(ir_blocks)
56+
query_path_to_filters = {}
5257
tree_root = None
5358
for index, block in enumerate(ir_blocks):
5459
if isinstance(block, constants.SKIPPABLE_BLOCK_TYPES):
@@ -64,12 +69,15 @@ def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None):
6469
block, tree_root, ir_blocks, query_metadata_table))
6570
tree_root = SqlNode(block=block, query_path=query_path)
6671
query_path_to_node[query_path] = tree_root
72+
elif isinstance(block, blocks.Filter):
73+
query_path_to_filters.setdefault(query_path, []).append(block)
6774
else:
6875
raise AssertionError(
6976
u'Unsupported block {} unexpectedly passed validation for IR blocks '
7077
u'{} with query metadata table {} .'.format(block, ir_blocks, query_metadata_table))
7178

72-
return SqlQueryTree(tree_root, query_path_to_location_info, query_path_to_output_fields)
79+
return SqlQueryTree(tree_root, query_path_to_location_info, query_path_to_output_fields,
80+
query_path_to_filters, query_path_to_node)
7381

7482

7583
def _validate_all_blocks_supported(ir_blocks, query_metadata_table):
@@ -223,3 +231,21 @@ def _map_block_index_to_location(ir_blocks):
223231
block_index_to_location[ix] = ir_block.location
224232
current_block_ixs = []
225233
return block_index_to_location
234+
235+
236+
def lower_unary_transformations(ir_blocks):
    """Reject any IR block that contains a UnaryTransformation expression.

    Every block is visited with a function that returns each expression unchanged unless
    it is a UnaryTransformation, in which case an error is raised.

    Args:
        ir_blocks: list of IR blocks to check.

    Returns:
        list, the result of visit_and_update_expressions for each block, in input order.

    Raises:
        NotImplementedError: if a UnaryTransformation expression is encountered.
    """
    def visitor_fn(expression):
        """Return the expression unchanged, or raise if it is a UnaryTransformation."""
        if isinstance(expression, expressions.UnaryTransformation):
            raise NotImplementedError(
                u'UnaryTransformation expression "{}" encountered with IR blocks {} is unsupported by '
                u'the SQL backend.'.format(expression, ir_blocks)
            )
        return expression

    checked_blocks = []
    for ir_block in ir_blocks:
        checked_blocks.append(ir_block.visit_and_update_expressions(visitor_fn))
    return checked_blocks

graphql_compiler/compiler/ir_lowering_sql/constants.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,42 @@
2020

2121
SUPPORTED_BLOCK_TYPES = (
2222
blocks.QueryRoot,
23+
blocks.Filter,
2324
)
2425

2526
SUPPORTED_OUTPUT_EXPRESSION_TYPES = (
2627
expressions.OutputContextField,
2728
)
2829

2930

31+
Operator = namedtuple('Operator', ('name', 'cardinality'))
32+
33+
34+
CARDINALITY_UNARY = 'UNARY'
35+
CARDINALITY_BINARY = 'BINARY'
36+
CARDINALITY_LIST_VALUED = 'LIST_VALUED'
37+
38+
39+
# The mapping supplied for SUPPORTED_OPERATORS allows for programmatic resolution of expressions
40+
# to their SQLAlchemy equivalents. As a concrete example, when converting the GraphQL filter
41+
# column_name @filter(op_name: "=", value: ["$variable_name"])
42+
# the corresponding python uses SQLAlchemy operation `__eq__` in a call like
43+
# getattr(Column(column_name), '__eq__')(BindParameter('variable_name'))
44+
# which programmatically generates the equivalent of the desired SQLAlchemy expression
45+
# Column('column_name') == BindParameter('variable_name')
46+
SUPPORTED_OPERATORS = {
47+
u'contains': Operator(u'in_', CARDINALITY_LIST_VALUED),
48+
u'&&': Operator(u'and_', CARDINALITY_BINARY),
49+
u'||': Operator(u'or_', CARDINALITY_BINARY),
50+
u'=': Operator(u'__eq__', CARDINALITY_UNARY),
51+
u'<': Operator(u'__lt__', CARDINALITY_UNARY),
52+
u'>': Operator(u'__gt__', CARDINALITY_UNARY),
53+
u'<=': Operator(u'__le__', CARDINALITY_UNARY),
54+
u'>=': Operator(u'__ge__', CARDINALITY_UNARY),
55+
u'has_substring': Operator(u'contains', CARDINALITY_UNARY),
56+
}
57+
58+
3059
class SqlBackend(object):
3160

3261
def __init__(self, backend):

graphql_compiler/compiler/ir_lowering_sql/sql_tree.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33

44
class SqlQueryTree(object):
5-
def __init__(self, root, query_path_to_location_info, query_path_to_output_fields):
5+
def __init__(self, root, query_path_to_location_info,
6+
query_path_to_output_fields, query_path_to_filters, query_path_to_node):
67
"""Wrap a SqlNode root with additional location_info metadata."""
78
self.root = root
89
self.query_path_to_location_info = query_path_to_location_info
910
self.query_path_to_output_fields = query_path_to_output_fields
11+
self.query_path_to_filters = query_path_to_filters
12+
self.query_path_to_node = query_path_to_node
1013

1114

1215
class SqlNode(object):

0 commit comments

Comments
 (0)