Skip to content

Commit b23fcff

Browse files
committed
Merge branch 'main' into can-specify-multiple-replay-tests
2 parents 986a2b1 + e09839f commit b23fcff

File tree

6 files changed

+252
-138
lines changed

6 files changed

+252
-138
lines changed

codeflash/code_utils/code_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import re
66
import shutil
77
import site
8+
import sys
89
from contextlib import contextmanager
910
from functools import lru_cache
1011
from pathlib import Path
1112
from tempfile import TemporaryDirectory
1213

1314
import tomlkit
1415

15-
from codeflash.cli_cmds.console import logger
16+
from codeflash.cli_cmds.console import logger, paneled_text
1617
from codeflash.code_utils.config_parser import find_pyproject_toml
1718

1819
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
@@ -213,3 +214,9 @@ def cleanup_paths(paths: list[Path]) -> None:
213214
def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
214215
for path, file_content in path_to_content_map.items():
215216
path.write_text(file_content, encoding="utf8")
217+
218+
219+
def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
220+
paneled_text(message, panel_args={"style": "red"})
221+
222+
sys.exit(1 if error_on_exit else 0)

codeflash/code_utils/env_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
import os
4-
import sys
54
import tempfile
65
from functools import lru_cache
76
from pathlib import Path
87
from typing import Optional
98

109
from codeflash.cli_cmds.console import logger
10+
from codeflash.code_utils.code_utils import exit_with_message
1111
from codeflash.code_utils.formatter import format_code
1212
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
1313

@@ -24,11 +24,11 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
2424
try:
2525
format_code(formatter_cmds, tmp_file, print_status=False)
2626
except Exception:
27-
print(
28-
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again."
27+
exit_with_message(
28+
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
29+
error_on_exit=True,
2930
)
30-
if exit_on_failure:
31-
sys.exit(1)
31+
3232
return return_code
3333

3434

codeflash/discovery/discover_unit_tests.py

Lines changed: 117 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -141,144 +141,175 @@ def close(self) -> None:
141141

142142

143143
class ImportAnalyzer(ast.NodeVisitor):
144-
"""AST-based analyzer to find all imports in a test file."""
144+
"""AST-based analyzer to check if any qualified names from function_names_to_find are imported or used in a test file."""
145145

146146
def __init__(self, function_names_to_find: set[str]) -> None:
147147
self.function_names_to_find = function_names_to_find
148-
self.imported_names: set[str] = set()
148+
self.found_any_target_function: bool = False
149+
self.found_qualified_name = None
149150
self.imported_modules: set[str] = set()
150-
self.found_target_functions: set[str] = set()
151-
self.qualified_names_called: set[str] = set()
151+
self.has_dynamic_imports: bool = False
152+
self.wildcard_modules: set[str] = set()
152153

153154
def visit_Import(self, node: ast.Import) -> None:
154155
"""Handle 'import module' statements."""
156+
if self.found_any_target_function:
157+
return
158+
155159
for alias in node.names:
156160
module_name = alias.asname if alias.asname else alias.name
157161
self.imported_modules.add(module_name)
158-
self.imported_names.add(module_name)
159-
self.generic_visit(node)
162+
163+
# Check for dynamic import modules
164+
if alias.name == "importlib":
165+
self.has_dynamic_imports = True
166+
167+
# Check if module itself is a target qualified name
168+
if module_name in self.function_names_to_find:
169+
self.found_any_target_function = True
170+
self.found_qualified_name = module_name
171+
return
172+
# Check if any target qualified name starts with this module
173+
for target_func in self.function_names_to_find:
174+
if target_func.startswith(f"{module_name}."):
175+
self.found_any_target_function = True
176+
self.found_qualified_name = target_func
177+
return
160178

161179
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
162180
"""Handle 'from module import name' statements."""
163-
if node.module:
164-
self.imported_modules.add(node.module)
181+
if self.found_any_target_function:
182+
return
183+
184+
if not node.module:
185+
return
165186

166187
for alias in node.names:
167188
if alias.name == "*":
168-
continue
169-
imported_name = alias.asname if alias.asname else alias.name
170-
self.imported_names.add(imported_name)
171-
if alias.name in self.function_names_to_find:
172-
self.found_target_functions.add(alias.name)
173-
# Check for qualified name matches
174-
if node.module:
189+
self.wildcard_modules.add(node.module)
190+
else:
191+
imported_name = alias.asname if alias.asname else alias.name
192+
self.imported_modules.add(imported_name)
193+
194+
# Check for dynamic import functions
195+
if node.module == "importlib" and alias.name == "import_module":
196+
self.has_dynamic_imports = True
197+
198+
# Check if imported name is a target qualified name
199+
if alias.name in self.function_names_to_find:
200+
self.found_any_target_function = True
201+
self.found_qualified_name = alias.name
202+
return
203+
# Check if module.name forms a target qualified name
175204
qualified_name = f"{node.module}.{alias.name}"
176205
if qualified_name in self.function_names_to_find:
177-
self.found_target_functions.add(qualified_name)
178-
self.generic_visit(node)
206+
self.found_any_target_function = True
207+
self.found_qualified_name = qualified_name
208+
return
209+
210+
def visit_Attribute(self, node: ast.Attribute) -> None:
211+
"""Handle attribute access like module.function_name."""
212+
if self.found_any_target_function:
213+
return
179214

180-
def visit_Call(self, node: ast.Call) -> None:
181-
"""Handle dynamic imports like importlib.import_module() or __import__()."""
215+
# Check if this is accessing a target function through an imported module
182216
if (
183-
isinstance(node.func, ast.Name)
184-
and node.func.id == "__import__"
185-
and node.args
186-
and isinstance(node.args[0], ast.Constant)
187-
and isinstance(node.args[0].value, str)
217+
isinstance(node.value, ast.Name)
218+
and node.value.id in self.imported_modules
219+
and node.attr in self.function_names_to_find
188220
):
189-
# __import__("module_name")
190-
self.imported_modules.add(node.args[0].value)
191-
elif (
192-
isinstance(node.func, ast.Attribute)
193-
and isinstance(node.func.value, ast.Name)
194-
and node.func.value.id == "importlib"
195-
and node.func.attr == "import_module"
196-
and node.args
197-
and isinstance(node.args[0], ast.Constant)
198-
and isinstance(node.args[0].value, str)
199-
):
200-
# importlib.import_module("module_name")
201-
self.imported_modules.add(node.args[0].value)
202-
self.generic_visit(node)
221+
self.found_any_target_function = True
222+
self.found_qualified_name = node.attr
223+
return
203224

204-
def visit_Name(self, node: ast.Name) -> None:
205-
"""Check if any name usage matches our target functions."""
206-
if node.id in self.function_names_to_find:
207-
self.found_target_functions.add(node.id)
208-
self.generic_visit(node)
225+
# Check if this is accessing a target function through a dynamically imported module
226+
# Only if we've detected dynamic imports are being used
227+
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
228+
self.found_any_target_function = True
229+
self.found_qualified_name = node.attr
230+
return
209231

210-
def visit_Attribute(self, node: ast.Attribute) -> None:
211-
"""Handle module.function_name patterns."""
212-
if node.attr in self.function_names_to_find:
213-
self.found_target_functions.add(node.attr)
214-
if isinstance(node.value, ast.Name):
215-
qualified_name = f"{node.value.id}.{node.attr}"
216-
self.qualified_names_called.add(qualified_name)
217232
self.generic_visit(node)
218233

234+
def visit_Name(self, node: ast.Name) -> None:
235+
"""Handle direct name usage like target_function()."""
236+
if self.found_any_target_function:
237+
return
219238

220-
def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> tuple[bool, set[str]]:
221-
"""Analyze imports in a test file to determine if it might test any target functions.
239+
# Check for __import__ usage
240+
if node.id == "__import__":
241+
self.has_dynamic_imports = True
222242

223-
Args:
224-
test_file_path: Path to the test file
225-
target_functions: Set of function names we're looking for
243+
if node.id in self.function_names_to_find:
244+
self.found_any_target_function = True
245+
self.found_qualified_name = node.id
246+
return
247+
248+
# Check if this name could come from a wildcard import
249+
for wildcard_module in self.wildcard_modules:
250+
for target_func in self.function_names_to_find:
251+
# Check if target_func is from this wildcard module and name matches
252+
if target_func.startswith(f"{wildcard_module}.") and target_func.endswith(f".{node.id}"):
253+
self.found_any_target_function = True
254+
self.found_qualified_name = target_func
255+
return
226256

227-
Returns:
228-
Tuple of (should_process_with_jedi, found_function_names)
257+
self.generic_visit(node)
229258

230-
"""
231-
if isinstance(test_file_path, str):
232-
test_file_path = Path(test_file_path)
259+
def generic_visit(self, node: ast.AST) -> None:
260+
"""Override generic_visit to stop traversal if a target function is found."""
261+
if self.found_any_target_function:
262+
return
263+
super().generic_visit(node)
233264

234-
try:
235-
with test_file_path.open("r", encoding="utf-8") as f:
236-
content = f.read()
237265

238-
tree = ast.parse(content, filename=str(test_file_path))
266+
def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
267+
"""Analyze a test file to see if it imports any of the target functions."""
268+
try:
269+
with Path(test_file_path).open("r", encoding="utf-8") as f:
270+
source_code = f.read()
271+
tree = ast.parse(source_code, filename=str(test_file_path))
239272
analyzer = ImportAnalyzer(target_functions)
240273
analyzer.visit(tree)
241-
242-
if analyzer.found_target_functions:
243-
return True, analyzer.found_target_functions
244-
245-
return False, set() # noqa: TRY300
246-
247-
except (SyntaxError, UnicodeDecodeError, OSError) as e:
274+
except (SyntaxError, FileNotFoundError) as e:
248275
logger.debug(f"Failed to analyze imports in {test_file_path}: {e}")
249-
return True, set()
276+
return True
277+
else:
278+
if analyzer.found_any_target_function:
279+
logger.debug(f"Test file {test_file_path} imports target function: {analyzer.found_qualified_name}")
280+
return True
281+
logger.debug(f"Test file {test_file_path} does not import any target functions.")
282+
return False
250283

251284

252285
def filter_test_files_by_imports(
253286
file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str]
254-
) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]:
287+
) -> dict[Path, list[TestsInFile]]:
255288
"""Filter test files based on import analysis to reduce Jedi processing.
256289
257290
Args:
258291
file_to_test_map: Original mapping of test files to test functions
259292
target_functions: Set of function names we're optimizing
260293
261294
Returns:
262-
Tuple of (filtered_file_map, import_analysis_results)
295+
Filtered mapping of test files to test functions
263296
264297
"""
265298
if not target_functions:
266-
return file_to_test_map, {}
299+
return file_to_test_map
267300

268-
filtered_map = {}
269-
import_results = {}
301+
logger.debug(f"Target functions for import filtering: {target_functions}")
270302

303+
filtered_map = {}
271304
for test_file, test_functions in file_to_test_map.items():
272-
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
273-
import_results[test_file] = found_functions
274-
305+
should_process = analyze_imports_in_test_file(test_file, target_functions)
275306
if should_process:
276307
filtered_map[test_file] = test_functions
277-
else:
278-
logger.debug(f"Skipping {test_file} - no relevant imports found")
279308

280-
logger.debug(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files")
281-
return filtered_map, import_results
309+
logger.debug(
310+
f"analyzed {len(file_to_test_map)} test files for imports, filtered down to {len(filtered_map)} relevant files"
311+
)
312+
return filtered_map
282313

283314

284315
def discover_unit_tests(
@@ -296,7 +327,6 @@ def discover_unit_tests(
296327
functions_to_optimize = None
297328
if file_to_funcs_to_optimize:
298329
functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list]
299-
300330
function_to_tests, num_discovered_tests = strategy(cfg, discover_only_these_tests, functions_to_optimize)
301331
return function_to_tests, num_discovered_tests
302332

@@ -455,12 +485,8 @@ def process_test_files(
455485
test_framework = cfg.test_framework
456486

457487
if functions_to_optimize:
458-
target_function_names = set()
459-
for func in functions_to_optimize:
460-
target_function_names.add(func.qualified_name)
461-
logger.debug(f"Target functions for import filtering: {target_function_names}")
462-
file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names)
463-
logger.debug(f"Import analysis results: {len(import_results)} files analyzed")
488+
target_function_names = {func.qualified_name for func in functions_to_optimize}
489+
file_to_test_map = filter_test_files_by_imports(file_to_test_map, target_function_names)
464490

465491
function_to_test_map = defaultdict(set)
466492
num_discovered_tests = 0

codeflash/discovery/functions_to_optimize.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
1818
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
1919
from codeflash.code_utils.code_utils import (
20+
exit_with_message,
2021
is_class_defined_in_file,
2122
module_name_from_file_path,
2223
path_belongs_to_site_packages,
@@ -179,8 +180,9 @@ def get_functions_to_optimize(
179180
if only_get_this_function is not None:
180181
split_function = only_get_this_function.split(".")
181182
if len(split_function) > 2:
182-
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
183-
raise ValueError(msg)
183+
exit_with_message(
184+
"Function name should be in the format 'function_name' or 'class_name.function_name'"
185+
)
184186
if len(split_function) == 2:
185187
class_name, only_function_name = split_function
186188
else:
@@ -193,8 +195,9 @@ def get_functions_to_optimize(
193195
):
194196
found_function = fn
195197
if found_function is None:
196-
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement or is a property"
197-
raise ValueError(msg)
198+
exit_with_message(
199+
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
200+
)
198201
functions[file] = [found_function]
199202
else:
200203
logger.info("Finding all functions modified in the current git diff ...")

0 commit comments

Comments
 (0)