1
+ from __future__ import annotations
2
+
3
+ import ast
1
4
import os
2
5
import re
3
6
from pathlib import Path
7
+ from textwrap import dedent
8
+ from typing import TYPE_CHECKING , Union
4
9
5
10
import libcst as cst
6
11
7
12
from codeflash .cli_cmds .console import logger
8
13
from codeflash .code_utils .time_utils import format_perf , format_time
9
- from codeflash .models .models import GeneratedTests , GeneratedTestsList , InvocationId
14
+ from codeflash .models .models import GeneratedTests , GeneratedTestsList
10
15
from codeflash .result .critic import performance_gain
11
- from codeflash .verification .verification_utils import TestConfig
16
+
17
+ if TYPE_CHECKING :
18
+ from codeflash .models .models import InvocationId
19
+ from codeflash .verification .verification_utils import TestConfig
12
20
13
21
14
22
def remove_functions_from_generated_tests (
@@ -36,6 +44,94 @@ def remove_functions_from_generated_tests(
36
44
return GeneratedTestsList (generated_tests = new_generated_tests )
37
45
38
46
47
+ class CfoVisitor (ast .NodeVisitor ):
48
+ """AST visitor that finds all assignments to a variable named 'codeflash_output'.
49
+
50
+ and reports their location relative to the function they're in.
51
+ """
52
+
53
+ def __init__ (self , source_code : str ) -> None :
54
+ self .source_lines = source_code .splitlines ()
55
+ self .results : list [int ] = [] # map actual line number to line number in ast
56
+
57
+ def _is_codeflash_output_target (self , target : Union [ast .expr , list ]) -> bool : # type: ignore[type-arg]
58
+ """Check if the assignment target is the variable 'codeflash_output'."""
59
+ if isinstance (target , ast .Name ):
60
+ return target .id == "codeflash_output"
61
+ if isinstance (target , (ast .Tuple , ast .List )):
62
+ # Handle tuple/list unpacking: a, codeflash_output, b = values
63
+ return any (self ._is_codeflash_output_target (elt ) for elt in target .elts )
64
+ if isinstance (target , (ast .Subscript , ast .Attribute )):
65
+ # Not a simple variable assignment
66
+ return False
67
+ return False
68
+
69
+ def _record_assignment (self , node : ast .AST ) -> None :
70
+ """Record an assignment to codeflash_output."""
71
+ relative_line = node .lineno - 1 # type: ignore[attr-defined]
72
+ self .results .append (relative_line )
73
+
74
+ def visit_Assign (self , node : ast .Assign ) -> None :
75
+ """Visit assignment statements: codeflash_output = value."""
76
+ for target in node .targets :
77
+ if self ._is_codeflash_output_target (target ):
78
+ self ._record_assignment (node )
79
+ break
80
+ self .generic_visit (node )
81
+
82
+ def visit_AnnAssign (self , node : ast .AnnAssign ) -> None :
83
+ """Visit annotated assignments: codeflash_output: int = value."""
84
+ if self ._is_codeflash_output_target (node .target ):
85
+ self ._record_assignment (node )
86
+ self .generic_visit (node )
87
+
88
+ def visit_AugAssign (self , node : ast .AugAssign ) -> None :
89
+ """Visit augmented assignments: codeflash_output += value."""
90
+ if self ._is_codeflash_output_target (node .target ):
91
+ self ._record_assignment (node )
92
+ self .generic_visit (node )
93
+
94
+ def visit_NamedExpr (self , node : ast .NamedExpr ) -> None :
95
+ """Visit walrus operator: (codeflash_output := value)."""
96
+ if isinstance (node .target , ast .Name ) and node .target .id == "codeflash_output" :
97
+ self ._record_assignment (node )
98
+ self .generic_visit (node )
99
+
100
+ def visit_For (self , node : ast .For ) -> None :
101
+ """Visit for loops: for codeflash_output in iterable."""
102
+ if self ._is_codeflash_output_target (node .target ):
103
+ self ._record_assignment (node )
104
+ self .generic_visit (node )
105
+
106
+ def visit_comprehension (self , node : ast .comprehension ) -> None :
107
+ """Visit comprehensions: [x for codeflash_output in iterable]."""
108
+ if self ._is_codeflash_output_target (node .target ):
109
+ # Comprehensions don't have line numbers, so we skip recording
110
+ pass
111
+ self .generic_visit (node )
112
+
113
+ def visit_With (self , node : ast .With ) -> None :
114
+ """Visit with statements: with expr as codeflash_output."""
115
+ for item in node .items :
116
+ if item .optional_vars and self ._is_codeflash_output_target (item .optional_vars ):
117
+ self ._record_assignment (node )
118
+ break
119
+ self .generic_visit (node )
120
+
121
+ def visit_ExceptHandler (self , node : ast .ExceptHandler ) -> None :
122
+ """Visit except handlers: except Exception as codeflash_output."""
123
+ if node .name == "codeflash_output" :
124
+ self ._record_assignment (node )
125
+ self .generic_visit (node )
126
+
127
+
128
+ def find_codeflash_output_assignments (source_code : str ) -> list [int ]:
129
+ tree = ast .parse (source_code )
130
+ visitor = CfoVisitor (source_code )
131
+ visitor .visit (tree )
132
+ return visitor .results
133
+
134
+
39
135
def add_runtime_comments_to_generated_tests (
40
136
test_cfg : TestConfig ,
41
137
generated_tests : GeneratedTestsList ,
@@ -49,11 +145,15 @@ def add_runtime_comments_to_generated_tests(
49
145
50
146
# TODO: reduce for loops to one
51
147
class RuntimeCommentTransformer (cst .CSTTransformer ):
52
- def __init__ (self , test : GeneratedTests , tests_root : Path , rel_tests_root : Path ) -> None :
148
+ def __init__ (self , module : cst .Module , test : GeneratedTests , tests_root : Path , rel_tests_root : Path ) -> None :
149
+ super ().__init__ ()
53
150
self .test = test
54
151
self .context_stack : list [str ] = []
55
152
self .tests_root = tests_root
56
153
self .rel_tests_root = rel_tests_root
154
+ self .module = module
155
+ self .cfo_locs : list [int ] = []
156
+ self .cfo_idx_loc_to_look_at : int = - 1
57
157
58
158
def visit_ClassDef (self , node : cst .ClassDef ) -> None :
59
159
# Track when we enter a class
@@ -65,6 +165,13 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
65
165
return updated_node
66
166
67
167
def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
168
+ # convert function body to ast normalized string and find occurrences of codeflash_output
169
+ body_code = dedent (self .module .code_for_node (node .body ))
170
+ normalized_body_code = ast .unparse (ast .parse (body_code ))
171
+ self .cfo_locs = sorted (
172
+ find_codeflash_output_assignments (normalized_body_code )
173
+ ) # sorted in order we will encounter them
174
+ self .cfo_idx_loc_to_look_at = - 1
68
175
self .context_stack .append (node .name .value )
69
176
70
177
def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
@@ -91,10 +198,12 @@ def leave_SimpleStatementLine(
91
198
92
199
if codeflash_assignment_found :
93
200
# Find matching test cases by looking for this test function name in the test results
201
+ self .cfo_idx_loc_to_look_at += 1
94
202
matching_original_times = []
95
203
matching_optimized_times = []
96
- # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
204
+ # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
97
205
for invocation_id , runtimes in original_runtimes .items ():
206
+ # get position here and match in if condition
98
207
qualified_name = (
99
208
invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
100
209
if invocation_id .test_class_name
@@ -105,13 +214,19 @@ def leave_SimpleStatementLine(
105
214
.with_suffix (".py" )
106
215
.relative_to (self .rel_tests_root )
107
216
)
108
- if qualified_name == "." .join (self .context_stack ) and rel_path in [
109
- self .test .behavior_file_path .relative_to (self .tests_root ),
110
- self .test .perf_file_path .relative_to (self .tests_root ),
111
- ]:
217
+ if (
218
+ qualified_name == "." .join (self .context_stack )
219
+ and rel_path
220
+ in [
221
+ self .test .behavior_file_path .relative_to (self .tests_root ),
222
+ self .test .perf_file_path .relative_to (self .tests_root ),
223
+ ]
224
+ and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
225
+ ):
112
226
matching_original_times .extend (runtimes )
113
227
114
228
for invocation_id , runtimes in optimized_runtimes .items ():
229
+ # get position here and match in if condition
115
230
qualified_name = (
116
231
invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
117
232
if invocation_id .test_class_name
@@ -122,10 +237,15 @@ def leave_SimpleStatementLine(
122
237
.with_suffix (".py" )
123
238
.relative_to (self .rel_tests_root )
124
239
)
125
- if qualified_name == "." .join (self .context_stack ) and rel_path in [
126
- self .test .behavior_file_path .relative_to (self .tests_root ),
127
- self .test .perf_file_path .relative_to (self .tests_root ),
128
- ]:
240
+ if (
241
+ qualified_name == "." .join (self .context_stack )
242
+ and rel_path
243
+ in [
244
+ self .test .behavior_file_path .relative_to (self .tests_root ),
245
+ self .test .perf_file_path .relative_to (self .tests_root ),
246
+ ]
247
+ and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
248
+ ):
129
249
matching_optimized_times .extend (runtimes )
130
250
131
251
if matching_original_times and matching_optimized_times :
@@ -161,9 +281,8 @@ def leave_SimpleStatementLine(
161
281
try :
162
282
# Parse the test source code
163
283
tree = cst .parse_module (test .generated_original_test_source )
164
-
165
284
# Transform the tree to add runtime comments
166
- transformer = RuntimeCommentTransformer (test , tests_root , rel_tests_root )
285
+ transformer = RuntimeCommentTransformer (tree , test , tests_root , rel_tests_root )
167
286
modified_tree = tree .visit (transformer )
168
287
169
288
# Convert back to source code
0 commit comments