Skip to content

Commit c33d1e9

Browse files
Merge branch 'main' into uv-migration-fr
2 parents a407ed2 + a13bf17 commit c33d1e9

File tree

7 files changed

+282
-79
lines changed

7 files changed

+282
-79
lines changed

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@ def test_threadpool() -> None:
2121
for r in result:
2222
print(r)
2323

24+
class AlexNet:
    """Toy stand-in for an AlexNet-style model, used as tracer workload.

    ``forward`` chains feature extraction and classification. Feature
    extraction only walks the input indices and always yields an empty list,
    so classification always receives an empty feature list from ``forward``.
    """

    def __init__(self, num_classes=1000):
        """Store the number of classes and the flattened feature size."""
        self.features_size = 256 * 6 * 6
        self.num_classes = num_classes

    def forward(self, x):
        """Run feature extraction followed by classification on ``x``."""
        extracted = self._extract_features(x)

        return self._classify(extracted)

    def _extract_features(self, x):
        """Iterate over the indices of ``x`` and return an empty list."""
        collected = []
        for _index in range(len(x)):
            pass

        return collected

    def _classify(self, features):
        """Return ``sum(features) % num_classes`` once per feature."""
        feature_sum = sum(features)
        labels = []
        # The modulo stays inside the loop so that nothing is computed
        # (and nothing can raise) when there are no features.
        for _ in features:
            labels.append(feature_sum % self.num_classes)
        return labels
45+
46+
class SimpleModel:
    """Minimal model exposing a static predictor and a default factory."""

    @staticmethod
    def predict(data):
        """Return a list with every element of ``data`` multiplied by 2."""
        doubled = []
        for value in data:
            doubled.append(value * 2)
        return doubled

    @classmethod
    def create_default(cls):
        """Build an instance of ``cls`` with no arguments."""
        instance = cls()
        return instance
54+
55+
def test_models():
    """Exercise AlexNet.forward and SimpleModel.predict on a small input."""
    classifier = AlexNet(num_classes=10)
    sample = [1, 2, 3, 4, 5]
    forward_output = classifier.forward(sample)

    default_model = SimpleModel.create_default()
    doubled_output = default_model.predict(sample)
2462

2563
if __name__ == "__main__":
2664
test_threadpool()
65+
test_models()

codeflash/code_utils/edit_generated_tests.py

Lines changed: 133 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1+
from __future__ import annotations
2+
3+
import ast
14
import os
25
import re
36
from pathlib import Path
7+
from textwrap import dedent
8+
from typing import TYPE_CHECKING, Union
49

510
import libcst as cst
611

712
from codeflash.cli_cmds.console import logger
813
from codeflash.code_utils.time_utils import format_perf, format_time
9-
from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
14+
from codeflash.models.models import GeneratedTests, GeneratedTestsList
1015
from codeflash.result.critic import performance_gain
11-
from codeflash.verification.verification_utils import TestConfig
16+
17+
if TYPE_CHECKING:
18+
from codeflash.models.models import InvocationId
19+
from codeflash.verification.verification_utils import TestConfig
1220

1321

1422
def remove_functions_from_generated_tests(
@@ -36,6 +44,94 @@ def remove_functions_from_generated_tests(
3644
return GeneratedTestsList(generated_tests=new_generated_tests)
3745

3846

47+
class CfoVisitor(ast.NodeVisitor):
    """AST visitor that finds bindings of a variable named 'codeflash_output'.

    Recognised bindings are plain, annotated and augmented assignments
    (including tuple/list unpacking targets), walrus expressions, ``for`` loop
    targets, ``with ... as`` targets and ``except ... as`` names.
    Comprehension targets are traversed but never recorded.

    Every hit is appended to ``results`` as a 0-based line number
    (``node.lineno - 1``) relative to the first line of the parsed source.
    """

    def __init__(self, source_code: str) -> None:
        self.source_lines = source_code.splitlines()
        # 0-based line numbers of recorded bindings, in visiting order.
        self.results: list[int] = []

    def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool:  # type: ignore[type-arg]
        """Check if the assignment target is (or contains) 'codeflash_output'."""
        if isinstance(target, ast.Name):
            return target.id == "codeflash_output"
        if isinstance(target, (ast.Tuple, ast.List)):
            # Handle tuple/list unpacking: a, codeflash_output, b = values
            return any(self._is_codeflash_output_target(elt) for elt in target.elts)
        # Anything else (subscripts, attributes, ...) is not a simple variable binding.
        return False

    def _record_assignment(self, node: ast.AST) -> None:
        """Record the 0-based line of a node that binds codeflash_output."""
        relative_line = node.lineno - 1  # type: ignore[attr-defined]
        self.results.append(relative_line)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Visit assignment statements: codeflash_output = value."""
        # Record the statement once, even if several targets match.
        if any(self._is_codeflash_output_target(target) for target in node.targets):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Visit annotated assignments: codeflash_output: int = value."""
        if self._is_codeflash_output_target(node.target):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_AugAssign(self, node: ast.AugAssign) -> None:
        """Visit augmented assignments: codeflash_output += value."""
        if self._is_codeflash_output_target(node.target):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
        """Visit walrus operator: (codeflash_output := value)."""
        if isinstance(node.target, ast.Name) and node.target.id == "codeflash_output":
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_For(self, node: ast.For) -> None:
        """Visit for loops: for codeflash_output in iterable."""
        if self._is_codeflash_output_target(node.target):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_comprehension(self, node: ast.comprehension) -> None:
        """Visit comprehensions: [x for codeflash_output in iterable].

        Comprehension nodes don't have line numbers, so a matching target is
        deliberately not recorded; only the children are traversed.
        """
        self.generic_visit(node)

    def visit_With(self, node: ast.With) -> None:
        """Visit with statements: with expr as codeflash_output."""
        # Record the statement once, even if several items bind the name.
        if any(
            item.optional_vars and self._is_codeflash_output_target(item.optional_vars) for item in node.items
        ):
            self._record_assignment(node)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
        """Visit except handlers: except Exception as codeflash_output."""
        if node.name == "codeflash_output":
            self._record_assignment(node)
        self.generic_visit(node)
126+
127+
128+
def find_codeflash_output_assignments(source_code: str) -> list[int]:
    """Return the 0-based lines in ``source_code`` that bind ``codeflash_output``.

    The source is parsed with ``ast`` and walked by ``CfoVisitor``; the
    visitor's ``results`` list is returned as-is (in visiting order).
    """
    # Parse first so that invalid source fails before the visitor is built.
    module_ast = ast.parse(source_code)
    collector = CfoVisitor(source_code)
    collector.visit(module_ast)
    return collector.results
133+
134+
39135
def add_runtime_comments_to_generated_tests(
40136
test_cfg: TestConfig,
41137
generated_tests: GeneratedTestsList,
@@ -49,11 +145,15 @@ def add_runtime_comments_to_generated_tests(
49145

50146
# TODO: reduce for loops to one
51147
class RuntimeCommentTransformer(cst.CSTTransformer):
52-
def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
    """Store the transformer inputs and reset the traversal state.

    Args:
        module: parsed module being transformed; kept so function bodies can
            be rendered back to source while visiting.
        test: generated test whose file paths are matched against invocations.
        tests_root: root used to relativize the generated test file paths.
        rel_tests_root: root used to relativize invocation module paths.
    """
    super().__init__()
    # Inputs supplied by the caller.
    self.module = module
    self.test = test
    self.tests_root = tests_root
    self.rel_tests_root = rel_tests_root
    # Names of the classes/functions currently being visited.
    self.context_stack: list[str] = []
    # Lines of codeflash_output assignments in the current function, and the
    # index of the one being processed (-1 means none seen yet).
    self.cfo_locs: list[int] = []
    self.cfo_idx_loc_to_look_at: int = -1
57157

58158
def visit_ClassDef(self, node: cst.ClassDef) -> None:
59159
# Track when we enter a class
@@ -65,6 +165,13 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
65165
return updated_node
66166

67167
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
    """Index the function's codeflash_output assignments, then enter it.

    The body is rendered to source, dedented and round-tripped through
    ``ast`` so that line numbers refer to the normalized body text.
    """
    raw_body = self.module.code_for_node(node.body)
    normalized_body_code = ast.unparse(ast.parse(dedent(raw_body)))
    assignment_lines = find_codeflash_output_assignments(normalized_body_code)
    # Sorted so the locations come in the order they will be encountered.
    self.cfo_locs = sorted(assignment_lines)
    # No assignment of this function has been processed yet.
    self.cfo_idx_loc_to_look_at = -1
    self.context_stack.append(node.name.value)
69176

70177
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
@@ -91,10 +198,12 @@ def leave_SimpleStatementLine(
91198

92199
if codeflash_assignment_found:
93200
# Find matching test cases by looking for this test function name in the test results
201+
self.cfo_idx_loc_to_look_at += 1
94202
matching_original_times = []
95203
matching_optimized_times = []
96-
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
204+
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
97205
for invocation_id, runtimes in original_runtimes.items():
206+
# get position here and match in if condition
98207
qualified_name = (
99208
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
100209
if invocation_id.test_class_name
@@ -105,13 +214,19 @@ def leave_SimpleStatementLine(
105214
.with_suffix(".py")
106215
.relative_to(self.rel_tests_root)
107216
)
108-
if qualified_name == ".".join(self.context_stack) and rel_path in [
109-
self.test.behavior_file_path.relative_to(self.tests_root),
110-
self.test.perf_file_path.relative_to(self.tests_root),
111-
]:
217+
if (
218+
qualified_name == ".".join(self.context_stack)
219+
and rel_path
220+
in [
221+
self.test.behavior_file_path.relative_to(self.tests_root),
222+
self.test.perf_file_path.relative_to(self.tests_root),
223+
]
224+
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
225+
):
112226
matching_original_times.extend(runtimes)
113227

114228
for invocation_id, runtimes in optimized_runtimes.items():
229+
# get position here and match in if condition
115230
qualified_name = (
116231
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
117232
if invocation_id.test_class_name
@@ -122,10 +237,15 @@ def leave_SimpleStatementLine(
122237
.with_suffix(".py")
123238
.relative_to(self.rel_tests_root)
124239
)
125-
if qualified_name == ".".join(self.context_stack) and rel_path in [
126-
self.test.behavior_file_path.relative_to(self.tests_root),
127-
self.test.perf_file_path.relative_to(self.tests_root),
128-
]:
240+
if (
241+
qualified_name == ".".join(self.context_stack)
242+
and rel_path
243+
in [
244+
self.test.behavior_file_path.relative_to(self.tests_root),
245+
self.test.perf_file_path.relative_to(self.tests_root),
246+
]
247+
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
248+
):
129249
matching_optimized_times.extend(runtimes)
130250

131251
if matching_original_times and matching_optimized_times:
@@ -161,9 +281,8 @@ def leave_SimpleStatementLine(
161281
try:
162282
# Parse the test source code
163283
tree = cst.parse_module(test.generated_original_test_source)
164-
165284
# Transform the tree to add runtime comments
166-
transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root)
285+
transformer = RuntimeCommentTransformer(tree, test, tests_root, rel_tests_root)
167286
modified_tree = tree.visit(transformer)
168287

169288
# Convert back to source code

codeflash/discovery/discover_unit_tests.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ def __init__(self, function_names_to_find: set[str]) -> None:
151151
self.has_dynamic_imports: bool = False
152152
self.wildcard_modules: set[str] = set()
153153

154+
# Precompute function_names for prefix search
155+
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
156+
self._exact_names = function_names_to_find
157+
self._prefix_roots = {}
158+
for name in function_names_to_find:
159+
if "." in name:
160+
root = name.split(".", 1)[0]
161+
self._prefix_roots.setdefault(root, []).append(name)
162+
154163
def visit_Import(self, node: ast.Import) -> None:
155164
"""Handle 'import module' statements."""
156165
if self.found_any_target_function:
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
    """Handle 'from module import name' statements.

    Records wildcard modules, imported (possibly aliased) names and the use
    of ``importlib.import_module``. Sets ``found_any_target_function`` and
    ``found_qualified_name`` as soon as an imported name matches a target:
    either exactly (``name`` or ``module.name``) or as a dotted prefix of a
    target (``module.name.`` + something, e.g. an imported class whose method
    is the target).
    """
    if self.found_any_target_function:
        return

    mod = node.module
    if not mod:
        # Relative import without a module part (``from . import x``).
        return

    fnames = self._exact_names
    # ``_prefix_roots`` is keyed by the FIRST dotted component of each target
    # name, so the candidates must be looked up by the root of the module.
    # Looking up by the full ``module.name`` (which always contains a dot)
    # can never hit a key and silently disables the prefix match.
    candidates = self._prefix_roots.get(mod.split(".", 1)[0], ())

    for alias in node.names:
        aname = alias.name
        if aname == "*":
            self.wildcard_modules.add(mod)
            continue

        imported_name = alias.asname if alias.asname else aname
        self.imported_modules.add(imported_name)

        # Check for dynamic import functions
        if mod == "importlib" and aname == "import_module":
            self.has_dynamic_imports = True

        qname = f"{mod}.{aname}"

        # Exact match on the bare imported name
        if aname in fnames:
            self.found_any_target_function = True
            self.found_qualified_name = aname
            return
        # Exact match on module.name
        if qname in fnames:
            self.found_any_target_function = True
            self.found_qualified_name = qname
            return

        # Prefix match: the imported name is a parent of a target
        prefix = qname + "."
        for target_func in candidates:
            if target_func.startswith(prefix):
                self.found_any_target_function = True
                self.found_qualified_name = target_func
                return
209234

210235
def visit_Attribute(self, node: ast.Attribute) -> None:

codeflash/tracer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def __exit__(
263263
if self.function_count[
264264
str(function.file_name)
265265
+ ":"
266-
+ (function.class_name + ":" if function.class_name else "")
266+
+ (function.class_name + "." if function.class_name else "")
267267
+ function.function_name
268268
]
269269
> 0
@@ -353,7 +353,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
353353
return
354354
if function_qualified_name not in self.function_count:
355355
# seeing this function for the first time
356-
self.function_count[function_qualified_name] = 0
356+
self.function_count[function_qualified_name] = 1
357357
file_valid = filter_files_optimized(
358358
file_path=file_name,
359359
tests_root=Path(self.config["tests_root"]),

tests/scripts/end_to_end_test_tracer_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
88
config = TestConfig(
99
trace_mode=True,
1010
min_improvement_x=0.1,
11-
expected_unit_tests=1,
11+
expected_unit_tests=7,
1212
coverage_expectations=[
1313
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 10, 13])
1414
],

tests/scripts/end_to_end_test_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
206206
return False
207207

208208
functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout)
209-
if not functions_traced or int(functions_traced.group(1)) != 4:
210-
logging.error("Expected 4 traced functions")
209+
if not functions_traced or int(functions_traced.group(1)) != 13:
210+
logging.error("Expected 13 traced functions")
211211
return False
212212

213213
replay_test_path = pathlib.Path(functions_traced.group(2))

0 commit comments

Comments
 (0)