Skip to content

Commit a13bf17

Browse files
Merge pull request #358 from codeflash-ai/fix-test-reporting
Granular test runtime comment for every codeflash_output assignment
2 parents a8a591b + ef010e5 commit a13bf17

File tree

2 files changed

+193
-54
lines changed

2 files changed

+193
-54
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 133 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1+
from __future__ import annotations
2+
3+
import ast
14
import os
25
import re
36
from pathlib import Path
7+
from textwrap import dedent
8+
from typing import TYPE_CHECKING, Union
49

510
import libcst as cst
611

712
from codeflash.cli_cmds.console import logger
813
from codeflash.code_utils.time_utils import format_perf, format_time
9-
from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
14+
from codeflash.models.models import GeneratedTests, GeneratedTestsList
1015
from codeflash.result.critic import performance_gain
11-
from codeflash.verification.verification_utils import TestConfig
16+
17+
if TYPE_CHECKING:
18+
from codeflash.models.models import InvocationId
19+
from codeflash.verification.verification_utils import TestConfig
1220

1321

1422
def remove_functions_from_generated_tests(
@@ -36,6 +44,94 @@ def remove_functions_from_generated_tests(
3644
return GeneratedTestsList(generated_tests=new_generated_tests)
3745

3846

47+
class CfoVisitor(ast.NodeVisitor):
    """AST visitor that records every binding of a variable named ``codeflash_output``.

    Each hit is appended to ``results`` as a 0-based line index (``node.lineno - 1``)
    relative to the source that was parsed. The recorded binding forms are plain,
    annotated and augmented assignments, the walrus operator, ``for`` targets,
    ``with ... as`` targets and ``except ... as`` names. Tuple/list unpacking
    (including nested and starred elements) is searched recursively. Comprehension
    targets are traversed but never recorded.
    """

    def __init__(self, source_code: str) -> None:
        # Retained for backward compatibility; the visitor itself never reads it.
        self.source_lines = source_code.splitlines()
        # 0-based line indices (node.lineno - 1) of each recorded binding, in visit order.
        self.results: list[int] = []

    def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool:  # type: ignore[type-arg]
        """Return True if ``target`` binds the plain name ``codeflash_output``.

        Attribute and subscript targets (``obj.codeflash_output``, ``d["codeflash_output"]``)
        are not simple variable bindings and therefore yield False.
        """
        if isinstance(target, ast.Name):
            return target.id == "codeflash_output"
        if isinstance(target, ast.Starred):
            # Starred unpacking: a, *codeflash_output = values
            return self._is_codeflash_output_target(target.value)
        if isinstance(target, (ast.Tuple, ast.List)):
            # Tuple/list unpacking: a, codeflash_output, b = values
            return any(self._is_codeflash_output_target(elt) for elt in target.elts)
        return False

    def _record_assignment(self, node: ast.AST) -> None:
        """Append the 0-based line index of ``node`` to ``results``."""
        relative_line = node.lineno - 1  # type: ignore[attr-defined]
        self.results.append(relative_line)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Visit assignment statements: codeflash_output = value."""
        # Record the statement once even when several chained targets match.
        if any(self._is_codeflash_output_target(target) for target in node.targets):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Visit annotated assignments: codeflash_output: int = value."""
        if self._is_codeflash_output_target(node.target):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_AugAssign(self, node: ast.AugAssign) -> None:
        """Visit augmented assignments: codeflash_output += value."""
        if self._is_codeflash_output_target(node.target):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
        """Visit walrus operator: (codeflash_output := value)."""
        if isinstance(node.target, ast.Name) and node.target.id == "codeflash_output":
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_For(self, node: ast.For) -> None:
        """Visit for loops: for codeflash_output in iterable."""
        if self._is_codeflash_output_target(node.target):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_comprehension(self, node: ast.comprehension) -> None:
        """Visit comprehensions: [x for codeflash_output in iterable].

        ``ast.comprehension`` nodes carry no line number, so a matching target is
        intentionally not recorded; children are still traversed.
        """
        self.generic_visit(node)

    def visit_With(self, node: ast.With) -> None:
        """Visit with statements: with expr as codeflash_output."""
        # Record the statement once even when several context items match.
        if any(
            item.optional_vars and self._is_codeflash_output_target(item.optional_vars) for item in node.items
        ):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
        """Visit except handlers: except Exception as codeflash_output."""
        if node.name == "codeflash_output":
            self._record_assignment(node)
        self.generic_visit(node)
126+
127+
128+
def find_codeflash_output_assignments(source_code: str) -> list[int]:
    """Return the line indices at which ``codeflash_output`` is bound in ``source_code``.

    The source is parsed with ``ast`` and walked by ``CfoVisitor``; the visitor's
    ``results`` list (``lineno - 1`` for each recorded node, in visit order) is
    returned unchanged. A ``SyntaxError`` from ``ast.parse`` propagates to the caller.
    """
    collector = CfoVisitor(source_code)
    collector.visit(ast.parse(source_code))
    return collector.results
133+
134+
39135
def add_runtime_comments_to_generated_tests(
40136
test_cfg: TestConfig,
41137
generated_tests: GeneratedTestsList,
@@ -49,11 +145,15 @@ def add_runtime_comments_to_generated_tests(
49145

50146
# TODO: reduce for loops to one
51147
class RuntimeCommentTransformer(cst.CSTTransformer):
52-
def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
148+
def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
149+
super().__init__()
53150
self.test = test
54151
self.context_stack: list[str] = []
55152
self.tests_root = tests_root
56153
self.rel_tests_root = rel_tests_root
154+
self.module = module
155+
self.cfo_locs: list[int] = []
156+
self.cfo_idx_loc_to_look_at: int = -1
57157

58158
def visit_ClassDef(self, node: cst.ClassDef) -> None:
59159
# Track when we enter a class
@@ -65,6 +165,13 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
65165
return updated_node
66166

67167
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
168+
# convert function body to ast normalized string and find occurrences of codeflash_output
169+
body_code = dedent(self.module.code_for_node(node.body))
170+
normalized_body_code = ast.unparse(ast.parse(body_code))
171+
self.cfo_locs = sorted(
172+
find_codeflash_output_assignments(normalized_body_code)
173+
) # sorted in order we will encounter them
174+
self.cfo_idx_loc_to_look_at = -1
68175
self.context_stack.append(node.name.value)
69176

70177
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
@@ -91,10 +198,12 @@ def leave_SimpleStatementLine(
91198

92199
if codeflash_assignment_found:
93200
# Find matching test cases by looking for this test function name in the test results
201+
self.cfo_idx_loc_to_look_at += 1
94202
matching_original_times = []
95203
matching_optimized_times = []
96-
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
204+
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
97205
for invocation_id, runtimes in original_runtimes.items():
206+
# get position here and match in if condition
98207
qualified_name = (
99208
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
100209
if invocation_id.test_class_name
@@ -105,13 +214,19 @@ def leave_SimpleStatementLine(
105214
.with_suffix(".py")
106215
.relative_to(self.rel_tests_root)
107216
)
108-
if qualified_name == ".".join(self.context_stack) and rel_path in [
109-
self.test.behavior_file_path.relative_to(self.tests_root),
110-
self.test.perf_file_path.relative_to(self.tests_root),
111-
]:
217+
if (
218+
qualified_name == ".".join(self.context_stack)
219+
and rel_path
220+
in [
221+
self.test.behavior_file_path.relative_to(self.tests_root),
222+
self.test.perf_file_path.relative_to(self.tests_root),
223+
]
224+
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
225+
):
112226
matching_original_times.extend(runtimes)
113227

114228
for invocation_id, runtimes in optimized_runtimes.items():
229+
# get position here and match in if condition
115230
qualified_name = (
116231
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
117232
if invocation_id.test_class_name
@@ -122,10 +237,15 @@ def leave_SimpleStatementLine(
122237
.with_suffix(".py")
123238
.relative_to(self.rel_tests_root)
124239
)
125-
if qualified_name == ".".join(self.context_stack) and rel_path in [
126-
self.test.behavior_file_path.relative_to(self.tests_root),
127-
self.test.perf_file_path.relative_to(self.tests_root),
128-
]:
240+
if (
241+
qualified_name == ".".join(self.context_stack)
242+
and rel_path
243+
in [
244+
self.test.behavior_file_path.relative_to(self.tests_root),
245+
self.test.perf_file_path.relative_to(self.tests_root),
246+
]
247+
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
248+
):
129249
matching_optimized_times.extend(runtimes)
130250

131251
if matching_original_times and matching_optimized_times:
@@ -161,9 +281,8 @@ def leave_SimpleStatementLine(
161281
try:
162282
# Parse the test source code
163283
tree = cst.parse_module(test.generated_original_test_source)
164-
165284
# Transform the tree to add runtime comments
166-
transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root)
285+
transformer = RuntimeCommentTransformer(tree, test, tests_root, rel_tests_root)
167286
modified_tree = tree.visit(transformer)
168287

169288
# Convert back to source code

0 commit comments

Comments
 (0)