2
2
"""Transform a SqlNode tree into an executable SQLAlchemy query."""
3
3
from collections import namedtuple
4
4
5
- from sqlalchemy import select
5
+ from sqlalchemy import Column , bindparam , select
6
+ from sqlalchemy .sql import expression as sql_expressions
7
+ from sqlalchemy .sql .elements import BindParameter , and_
6
8
7
9
from . import sql_context_helpers
10
+ from ..compiler import expressions
11
+ from ..compiler .ir_lowering_sql import constants
8
12
9
13
10
14
# The compilation context holds state that changes during compilation as the tree is traversed
20
24
# renamed status. This tuple is used to construct the query outputs, and track when a name
21
25
# changes due to collapsing into a CTE.
22
26
'query_path_to_output_fields' ,
27
+ # 'query_path_to_filters': Dict[Tuple[str, ...], List[Filter]], mapping from each query_path
28
+ # to the Filter blocks that apply to that query path
29
+ 'query_path_to_filters' ,
30
+ # 'query_path_to_node': Dict[Tuple[str, ...], SqlNode], mapping from each
31
+ # query_path to the SqlNode located at that query_path.
32
+ 'query_path_to_node' ,
23
33
# 'compiler_metadata': CompilerMetadata, SQLAlchemy metadata about Table objects, and
24
34
# further backend specific configuration.
25
35
'compiler_metadata' ,
@@ -40,6 +50,8 @@ def emit_code_from_ir(sql_query_tree, compiler_metadata):
40
50
query_path_to_selectable = dict (),
41
51
query_path_to_location_info = sql_query_tree .query_path_to_location_info ,
42
52
query_path_to_output_fields = sql_query_tree .query_path_to_output_fields ,
53
+ query_path_to_filters = sql_query_tree .query_path_to_filters ,
54
+ query_path_to_node = sql_query_tree .query_path_to_node ,
43
55
compiler_metadata = compiler_metadata ,
44
56
)
45
57
@@ -88,27 +100,187 @@ def _create_query(node, context):
88
100
Returns:
89
101
Selectable, selectable of the generated query.
90
102
"""
91
- output_columns = _get_output_columns (node , context )
103
+ visited_nodes = [node ]
104
+ output_columns = _get_output_columns (visited_nodes , context )
105
+ filters = _get_filters (visited_nodes , context )
92
106
selectable = sql_context_helpers .get_node_selectable (node , context )
93
- query = select (output_columns ).select_from (selectable )
107
+ query = select (output_columns ).select_from (selectable ). where ( and_ ( * filters ))
94
108
return query
95
109
96
110
97
- def _get_output_columns (node , context ):
98
- """Get the output columns required by the query .
111
+ def _get_output_columns (nodes , context ):
112
+ """Get the output columns for a list of SqlNodes .
99
113
100
114
Args:
101
- node: SqlNode, the current node .
115
+ nodes: List[ SqlNode] , the nodes to get output columns from .
102
116
context: CompilationContext, global compilation state and metadata.
103
117
104
118
Returns:
105
119
List[Column], list of SqlAlchemy Columns to output for this query.
106
120
"""
107
- sql_outputs = context .query_path_to_output_fields [node .query_path ]
108
121
columns = []
109
- for sql_output in sql_outputs :
110
- field_name = sql_output .field_name
111
- column = sql_context_helpers .get_column (field_name , node , context )
112
- column = column .label (sql_output .output_name )
113
- columns .append (column )
122
+ for node in nodes :
123
+ for sql_output in sql_context_helpers .get_outputs (node , context ):
124
+ field_name = sql_output .field_name
125
+ column = sql_context_helpers .get_column (field_name , node , context )
126
+ column = column .label (sql_output .output_name )
127
+ columns .append (column )
114
128
return columns
129
+
130
+
131
+ def _get_filters (nodes , context ):
132
+ """Get filters to apply to a list of SqlNodes.
133
+
134
+ Args:
135
+ nodes: List[SqlNode], the SqlNodes to get filters for.
136
+ context: CompilationContext, global compilation state and metadata.
137
+
138
+ Returns:
139
+ List[Expression], list of SQLAlchemy expressions.
140
+ """
141
+ filters = []
142
+ for node in nodes :
143
+ for filter_block in sql_context_helpers .get_filters (node , context ):
144
+ filter_sql_expression = _transform_filter_to_sql (filter_block , node , context )
145
+ filters .append (filter_sql_expression )
146
+ return filters
147
+
148
+
149
+ def _transform_filter_to_sql (filter_block , node , context ):
150
+ """Transform a Filter block to its corresponding SQLAlchemy expression.
151
+
152
+ Args:
153
+ filter_block: Filter, the Filter block to transform.
154
+ node: SqlNode, the node Filter block applies to.
155
+ context: CompilationContext, global compilation state and metadata.
156
+
157
+ Returns:
158
+ Expression, SQLAlchemy expression equivalent to the Filter.predicate expression.
159
+ """
160
+ expression = filter_block .predicate
161
+ return _expression_to_sql (expression , node , context )
162
+
163
+
164
+ def _expression_to_sql (expression , node , context ):
165
+ """Recursively transform a Filter block predicate to its SQLAlchemy expression representation.
166
+
167
+ Args:
168
+ expression: expression, the compiler expression to transform.
169
+ node: SqlNode, the SqlNode the expression applies to.
170
+ context: CompilationContext, global compilation state and metadata.
171
+
172
+ Returns:
173
+ Expression, SQLAlchemy Expression equivalent to the passed compiler expression.
174
+ """
175
+ _expression_transformers = {
176
+ expressions .LocalField : _transform_local_field_to_expression ,
177
+ expressions .Variable : _transform_variable_to_expression ,
178
+ expressions .Literal : _transform_literal_to_expression ,
179
+ expressions .BinaryComposition : _transform_binary_composition_to_expression ,
180
+ }
181
+ expression_type = type (expression )
182
+ if expression_type not in _expression_transformers :
183
+ raise NotImplementedError (
184
+ u'Unsupported compiler expression "{}" of type "{}" cannot be converted to SQL '
185
+ u'expression.' .format (expression , type (expression )))
186
+ return _expression_transformers [expression_type ](expression , node , context )
187
+
188
+
189
+ def _transform_binary_composition_to_expression (expression , node , context ):
190
+ """Transform a BinaryComposition compiler expression into a SQLAlchemy expression.
191
+
192
+ Recursively calls _expression_to_sql to convert its left and right sub-expressions.
193
+
194
+ Args:
195
+ expression: expression, BinaryComposition compiler expression.
196
+ node: SqlNode, the SqlNode the expression applies to.
197
+ context: CompilationContext, global compilation state and metadata.
198
+
199
+ Returns:
200
+ Expression, SQLAlchemy expression.
201
+ """
202
+ if expression .operator not in constants .SUPPORTED_OPERATORS :
203
+ raise NotImplementedError (
204
+ u'Filter operation "{}" is not supported by the SQL backend.' .format (
205
+ expression .operator ))
206
+ sql_operator = constants .SUPPORTED_OPERATORS [expression .operator ]
207
+ left = _expression_to_sql (expression .left , node , context )
208
+ right = _expression_to_sql (expression .right , node , context )
209
+ if sql_operator .cardinality == constants .CARDINALITY_UNARY :
210
+ left , right = _get_column_and_bindparam (left , right , sql_operator )
211
+ clause = getattr (left , sql_operator .name )(right )
212
+ return clause
213
+ elif sql_operator .cardinality == constants .CARDINALITY_BINARY :
214
+ clause = getattr (sql_expressions , sql_operator .name )(left , right )
215
+ return clause
216
+ elif sql_operator .cardinality == constants .CARDINALITY_LIST_VALUED :
217
+ left , right = _get_column_and_bindparam (left , right , sql_operator )
218
+ # ensure that SQLAlchemy treats the right bind parameter as list valued
219
+ right .expanding = True
220
+ clause = getattr (left , sql_operator .name )(right )
221
+ return clause
222
+ raise AssertionError (u'Unreachable, operator cardinality {} for compiler expression {} is '
223
+ u'unknown' .format (sql_operator .cardinality , expression ))
224
+
225
+
226
+ def _get_column_and_bindparam (left , right , operator ):
227
+ """Return left and right expressions in (Column, BindParameter) order."""
228
+ if not isinstance (left , Column ):
229
+ left , right = right , left
230
+ if not isinstance (left , Column ):
231
+ raise AssertionError (
232
+ u'SQLAlchemy operator {} expects Column as left side the of expression, got {} '
233
+ u'of type {} instead.' .format (operator , left , type (left )))
234
+ if not isinstance (right , BindParameter ):
235
+ raise AssertionError (
236
+ u'SQLAlchemy operator {} expects BindParameter as the right side of the expression, '
237
+ u'got {} of type {} instead.' .format (operator , right , type (right )))
238
+ return left , right
239
+
240
+
241
+ def _transform_literal_to_expression (expression , node , context ):
242
+ """Transform a Literal compiler expression into its SQLAlchemy expression representation.
243
+
244
+ Args:
245
+ expression: expression, Literal compiler expression.
246
+ node: SqlNode, the SqlNode the expression applies to.
247
+ context: CompilationContext, global compilation state and metadata.
248
+
249
+ Returns:
250
+ Expression, SQLAlchemy expression.
251
+ """
252
+ return expression .value
253
+
254
+
255
+ def _transform_variable_to_expression (expression , node , context ):
256
+ """Transform a Variable compiler expression into its SQLAlchemy expression representation.
257
+
258
+ Args:
259
+ expression: expression, Variable compiler expression.
260
+ node: SqlNode, the SqlNode the expression applies to.
261
+ context: CompilationContext, global compilation state and metadata.
262
+
263
+ Returns:
264
+ Expression, SQLAlchemy expression.
265
+ """
266
+ variable_name = expression .variable_name
267
+ if not variable_name .startswith (u'$' ):
268
+ raise AssertionError (u'Unexpectedly received variable name {} that is not '
269
+ u'prefixed with "$"' .format (variable_name ))
270
+ return bindparam (variable_name [1 :])
271
+
272
+
273
+ def _transform_local_field_to_expression (expression , node , context ):
274
+ """Transform a LocalField compiler expression into its SQLAlchemy expression representation.
275
+
276
+ Args:
277
+ expression: expression, LocalField compiler expression.
278
+ node: SqlNode, the SqlNode the expression applies to.
279
+ context: CompilationContext, global compilation state and metadata.
280
+
281
+ Returns:
282
+ Expression, SQLAlchemy expression.
283
+ """
284
+ column_name = expression .field_name
285
+ column = sql_context_helpers .get_column (column_name , node , context )
286
+ return column
0 commit comments