From a73d2d4cf4598e25b7dfdc053cf53f4440de2aa4 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 22:24:05 -0700 Subject: [PATCH 1/5] remove unused import_results object --- codeflash/discovery/discover_unit_tests.py | 26 +++++++++------------- tests/test_unit_test_discovery.py | 9 ++------ 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 7a5bf6f9..815a8486 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -251,7 +251,7 @@ def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: s def filter_test_files_by_imports( file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str] -) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]: +) -> dict[Path, list[TestsInFile]]: """Filter test files based on import analysis to reduce Jedi processing. Args: @@ -259,26 +259,26 @@ def filter_test_files_by_imports( target_functions: Set of function names we're optimizing Returns: - Tuple of (filtered_file_map, import_analysis_results) + Filtered mapping of test files to test functions """ if not target_functions: - return file_to_test_map, {} + return file_to_test_map + + logger.debug(f"Target functions for import filtering: {target_functions}") filtered_map = {} - import_results = {} + num_analyzed = 0 for test_file, test_functions in file_to_test_map.items(): should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) - import_results[test_file] = found_functions + num_analyzed += 1 if should_process: filtered_map[test_file] = test_functions - else: - logger.debug(f"Skipping {test_file} - no relevant imports found") - logger.debug(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files") - return filtered_map, import_results + logger.debug(f"analyzed {num_analyzed} test files for imports, filtered down to {len(filtered_map)} relevant files") + return filtered_map def discover_unit_tests( @@ -455,12 +455,8 @@ def process_test_files( test_framework = cfg.test_framework if functions_to_optimize: - target_function_names = set() - for func in functions_to_optimize: - target_function_names.add(func.qualified_name) - logger.debug(f"Target functions for import filtering: {target_function_names}") - file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names) - logger.debug(f"Import analysis results: {len(import_results)} files analyzed") + target_function_names = {func.qualified_name for func in functions_to_optimize} + file_to_test_map = filter_test_files_by_imports(file_to_test_map, target_function_names) function_to_test_map = defaultdict(set) num_discovered_tests = 0 diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 4465acf7..d0ba0261 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -988,17 +988,13 @@ def test_star(): } target_functions = {"target_function"} - filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, target_functions) + filtered_map = filter_test_files_by_imports(file_to_test_map, target_functions) # Should filter out irrelevant_test assert len(filtered_map) == 1 assert relevant_test in filtered_map assert irrelevant_test not in filtered_map - # Check import analysis results - assert "target_function" in import_results[relevant_test] - assert len(import_results[irrelevant_test]) == 0 - assert len(import_results[star_test]) == 0 def test_filter_test_files_no_target_functions(): @@ -1014,11 +1010,10 @@ def test_filter_test_files_no_target_functions(): } # No target functions provided - filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, set()) + filtered_map = filter_test_files_by_imports(file_to_test_map, set()) # Should return original map unchanged assert filtered_map == file_to_test_map - assert import_results == {} def test_discover_unit_tests_with_import_filtering(): From fee08a65dde383034ab5c55867d0520a5e64bbb5 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 22:27:12 -0700 Subject: [PATCH 2/5] we don't use found_target_functions --- codeflash/discovery/discover_unit_tests.py | 20 +++++------- tests/test_unit_test_discovery.py | 38 +++++++--------------- 2 files changed, 20 insertions(+), 38 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 815a8486..c59d6e54 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -217,7 +217,7 @@ def visit_Attribute(self, node: ast.Attribute) -> None: self.generic_visit(node) -def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> tuple[bool, set[str]]: +def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool: """Analyze imports in a test file to determine if it might test any target functions. Args: @@ -225,7 +225,7 @@ def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: s target_functions: Set of function names we're looking for Returns: - Tuple of (should_process_with_jedi, found_function_names) + bool: True if the test file should be processed (contains relevant imports), False otherwise """ if isinstance(test_file_path, str): @@ -239,14 +239,11 @@ def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: s analyzer = ImportAnalyzer(target_functions) analyzer.visit(tree) - if analyzer.found_target_functions: - return True, analyzer.found_target_functions - - return False, set() # noqa: TRY300 + return bool(analyzer.found_target_functions) except (SyntaxError, UnicodeDecodeError, OSError) as e: logger.debug(f"Failed to analyze imports in {test_file_path}: {e}") - return True, set() + return True def filter_test_files_by_imports( @@ -268,16 +265,15 @@ def filter_test_files_by_imports( logger.debug(f"Target functions for import filtering: {target_functions}") filtered_map = {} - num_analyzed = 0 for test_file, test_functions in file_to_test_map.items(): - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) - num_analyzed += 1 - + should_process = analyze_imports_in_test_file(test_file, target_functions) if should_process: filtered_map[test_file] = test_functions - logger.debug(f"analyzed {num_analyzed} test files for imports, filtered down to {len(filtered_map)} relevant files") + logger.debug( + f"analyzed {len(file_to_test_map)} test files for imports, filtered down to {len(filtered_map)} relevant files" + ) return filtered_map diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index d0ba0261..8fd0b68c 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -811,11 +811,9 @@ def test_target(): test_file.write_text(test_content) target_functions = {"target_function", "missing_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_function" in found_functions - assert "missing_function" not in found_functions def test_analyze_imports_star_import(): @@ -831,10 +829,9 @@ def test_something(): test_file.write_text(test_content) target_functions = {"target_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is False - assert found_functions == set() def test_analyze_imports_module_import(): @@ -850,10 +847,9 @@ def test_target(): test_file.write_text(test_content) target_functions = {"target_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_function" in found_functions def test_analyze_imports_dynamic_import(): @@ -870,10 +866,9 @@ def test_dynamic(): test_file.write_text(test_content) target_functions = {"target_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_function" in found_functions def test_analyze_imports_builtin_import(): @@ -888,10 +883,9 @@ def test_builtin_import(): test_file.write_text(test_content) target_functions = {"target_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_function" in found_functions def test_analyze_imports_no_matching_imports(): @@ -907,9 +901,8 @@ def test_unrelated(): test_file.write_text(test_content) target_functions = {"target_function", "another_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is False - assert found_functions == set() def test_analyze_qualified_names(): @@ -924,9 +917,8 @@ def test_target(): test_file.write_text(test_content) target_functions = {"target_module.some_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_module.some_function" in found_functions @@ -943,11 +935,10 @@ def test_target( test_file.write_text(test_content) target_functions = {"target_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) # Should be conservative with unparseable files assert should_process is True - assert found_functions == set() def test_filter_test_files_by_imports(): @@ -1085,10 +1076,9 @@ def test_conditional(): test_file.write_text(test_content) target_functions = {"target_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_function" in found_functions def test_analyze_imports_function_name_in_code(): @@ -1108,10 +1098,9 @@ def test_indirect(): test_file.write_text(test_content) target_functions = {"target_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_function" in found_functions def test_analyze_imports_aliased_imports(): @@ -1128,11 +1117,9 @@ def test_aliased(): test_file.write_text(test_content) target_functions = {"target_function", "missing_function"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - assert "target_function" in found_functions - assert "missing_function" not in found_functions def test_analyze_imports_underscore_function_names(): @@ -1147,10 +1134,9 @@ def test_bubble(): test_file.write_text(test_content) target_functions = {"bubble_sort"} - should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is False - assert "bubble_sort" not in found_functions def test_discover_unit_tests_filtering_different_modules(): """Test import filtering with test files from completely different modules.""" From 36f320daa9fb4e00559b1d190638441930823ab9 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 23:06:10 -0700 Subject: [PATCH 3/5] refactor to early exit --- codeflash/discovery/discover_unit_tests.py | 159 ++++++++++++--------- tests/test_unit_test_discovery.py | 90 ++++++++++++ 2 files changed, 181 insertions(+), 68 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index c59d6e54..e69305e3 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -141,109 +141,132 @@ def close(self) -> None: class ImportAnalyzer(ast.NodeVisitor): - """AST-based analyzer to find all imports in a test file.""" + """AST-based analyzer to check if any qualified names from function_names_to_find are imported or used in a test file.""" def __init__(self, function_names_to_find: set[str]) -> None: self.function_names_to_find = function_names_to_find - self.imported_names: set[str] = set() - self.imported_modules: set[str] = set() - self.found_target_functions: set[str] = set() - self.qualified_names_called: set[str] = set() + self.found_any_target_function: bool = False + self.imported_modules: set[str] = set() # Track imported modules for usage analysis + self.has_dynamic_imports: bool = False + self.wildcard_modules: set[str] = set() def visit_Import(self, node: ast.Import) -> None: """Handle 'import module' statements.""" + if self.found_any_target_function: + return + for alias in node.names: module_name = alias.asname if alias.asname else alias.name self.imported_modules.add(module_name) - self.imported_names.add(module_name) - self.generic_visit(node) + + # Check for dynamic import modules + if alias.name == "importlib": + self.has_dynamic_imports = True + + # Check if module itself is a target qualified name + if module_name in self.function_names_to_find: + self.found_any_target_function = True + return + # Check if any target qualified name starts with this module + for target_func in self.function_names_to_find: + if target_func.startswith(f"{module_name}."): + self.found_any_target_function = True + return def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Handle 'from module import name' statements.""" - if node.module: - self.imported_modules.add(node.module) + if self.found_any_target_function: + return + + if not node.module: + return for alias in node.names: if alias.name == "*": - continue - imported_name = alias.asname if alias.asname else alias.name - self.imported_names.add(imported_name) - if alias.name in self.function_names_to_find: - self.found_target_functions.add(alias.name) - # Check for qualified name matches - if node.module: + self.wildcard_modules.add(node.module) + else: + imported_name = alias.asname if alias.asname else alias.name + self.imported_modules.add(imported_name) + + # Check for dynamic import functions + if node.module == "importlib" and alias.name == "import_module": + self.has_dynamic_imports = True + + # Check if imported name is a target qualified name + if alias.name in self.function_names_to_find: + self.found_any_target_function = True + return + # Check if module.name forms a target qualified name qualified_name = f"{node.module}.{alias.name}" if qualified_name in self.function_names_to_find: - self.found_target_functions.add(qualified_name) - self.generic_visit(node) + self.found_any_target_function = True + return - def visit_Call(self, node: ast.Call) -> None: - """Handle dynamic imports like importlib.import_module() or __import__().""" + def visit_Attribute(self, node: ast.Attribute) -> None: + """Handle attribute access like module.function_name.""" + if self.found_any_target_function: + return + + # Check if this is accessing a target function through an imported module if ( - isinstance(node.func, ast.Name) - and node.func.id == "__import__" - and node.args - and isinstance(node.args[0], ast.Constant) - and isinstance(node.args[0].value, str) + isinstance(node.value, ast.Name) + and node.value.id in self.imported_modules + and node.attr in self.function_names_to_find ): - # __import__("module_name") - self.imported_modules.add(node.args[0].value) - elif ( - isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == "importlib" - and node.func.attr == "import_module" - and node.args - and isinstance(node.args[0], ast.Constant) - and isinstance(node.args[0].value, str) - ): - # importlib.import_module("module_name") - self.imported_modules.add(node.args[0].value) + self.found_any_target_function = True + return + + # Check if this is accessing a target function through a dynamically imported module + # Only if we've detected dynamic imports are being used + if self.has_dynamic_imports and node.attr in self.function_names_to_find: + self.found_any_target_function = True + return + self.generic_visit(node) def visit_Name(self, node: ast.Name) -> None: - """Check if any name usage matches our target functions.""" - if node.id in self.function_names_to_find: - self.found_target_functions.add(node.id) - self.generic_visit(node) + """Handle direct name usage like target_function().""" + if self.found_any_target_function: + return - def visit_Attribute(self, node: ast.Attribute) -> None: - """Handle module.function_name patterns.""" - if node.attr in self.function_names_to_find: - self.found_target_functions.add(node.attr) - if isinstance(node.value, ast.Name): - qualified_name = f"{node.value.id}.{node.attr}" - self.qualified_names_called.add(qualified_name) - self.generic_visit(node) + # Check for __import__ usage + if node.id == "__import__": + self.has_dynamic_imports = True + if node.id in self.function_names_to_find: + self.found_any_target_function = True + return -def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool: - """Analyze imports in a test file to determine if it might test any target functions. + # Check if this name could come from a wildcard import + for wildcard_module in self.wildcard_modules: + for target_func in self.function_names_to_find: + # Check if target_func is from this wildcard module and name matches + if target_func.startswith(f"{wildcard_module}.") and target_func.endswith(f".{node.id}"): + self.found_any_target_function = True + return - Args: - test_file_path: Path to the test file - target_functions: Set of function names we're looking for + self.generic_visit(node) - Returns: - bool: True if the test file should be processed (contains relevant imports), False otherwise + def generic_visit(self, node: ast.AST) -> None: + """Override generic_visit to stop traversal if a target function is found.""" + if self.found_any_target_function: + return + super().generic_visit(node) - """ - if isinstance(test_file_path, str): - test_file_path = Path(test_file_path) +def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool: + """Analyze a test file to see if it imports any of the target functions.""" try: - with test_file_path.open("r", encoding="utf-8") as f: - content = f.read() - - tree = ast.parse(content, filename=str(test_file_path)) + with Path(test_file_path).open("r", encoding="utf-8") as f: + source_code = f.read() + tree = ast.parse(source_code, filename=str(test_file_path)) analyzer = ImportAnalyzer(target_functions) analyzer.visit(tree) - - return bool(analyzer.found_target_functions) - - except (SyntaxError, UnicodeDecodeError, OSError) as e: + except (SyntaxError, FileNotFoundError) as e: logger.debug(f"Failed to analyze imports in {test_file_path}: {e}") return True + else: + return analyzer.found_any_target_function def filter_test_files_by_imports( diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 8fd0b68c..f985f60f 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -834,6 +834,96 @@ def test_something(): assert should_process is False + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import * + +def test_target(): + assert target_function() is True +""" + test_file.group + test_file.write_text(test_content) + + target_functions = {"mymodule.target_function"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + + + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import * + +def test_target(): + assert target_function_extended() is True +""" + test_file.write_text(test_content) + + # Should not match - target_function != target_function_extended + target_functions = {"mymodule.target_function"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + + + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import * + +def test_something(): + x = 42 + assert x == 42 +""" + test_file.write_text(test_content) + + target_functions = {"mymodule.target_function"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + + + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import * + +def test_something(): + message = "calling target_function" + assert "target_function" in message +""" + test_file.write_text(test_content) + + target_functions = {"mymodule.target_function"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + # String literals are ast.Constant nodes, not ast.Name nodes, so they don't match + assert should_process is False + + + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import target_function +from othermodule import * + +def test_target(): + assert target_function() is True + assert other_func() is True +""" + test_file.write_text(test_content) + + target_functions = {"mymodule.target_function", "othermodule.other_func"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + + + def test_analyze_imports_module_import(): """Test module imports with function access patterns.""" with tempfile.TemporaryDirectory() as tmpdirname: From 0fcd59c4200dae31c31e9a45bf353d64bd5b18fb Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 8 Jun 2025 00:14:01 -0700 Subject: [PATCH 4/5] easier debugging. --- codeflash/discovery/discover_unit_tests.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index e69305e3..6bcf8156 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -146,7 +146,8 @@ class ImportAnalyzer(ast.NodeVisitor): def __init__(self, function_names_to_find: set[str]) -> None: self.function_names_to_find = function_names_to_find self.found_any_target_function: bool = False - self.imported_modules: set[str] = set() # Track imported modules for usage analysis + self.found_qualified_name = None + self.imported_modules: set[str] = set() self.has_dynamic_imports: bool = False self.wildcard_modules: set[str] = set() @@ -166,11 +167,13 @@ def visit_Import(self, node: ast.Import) -> None: # Check if module itself is a target qualified name if module_name in self.function_names_to_find: self.found_any_target_function = True + self.found_qualified_name = module_name return # Check if any target qualified name starts with this module for target_func in self.function_names_to_find: if target_func.startswith(f"{module_name}."): self.found_any_target_function = True + self.found_qualified_name = target_func return def visit_ImportFrom(self, node: ast.ImportFrom) -> None: @@ -195,11 +198,13 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # Check if imported name is a target qualified name if alias.name in self.function_names_to_find: self.found_any_target_function = True + self.found_qualified_name = alias.name return # Check if module.name forms a target qualified name qualified_name = f"{node.module}.{alias.name}" if qualified_name in self.function_names_to_find: self.found_any_target_function = True + self.found_qualified_name = qualified_name return def visit_Attribute(self, node: ast.Attribute) -> None: @@ -214,12 +219,14 @@ def visit_Attribute(self, node: ast.Attribute) -> None: and node.attr in self.function_names_to_find ): self.found_any_target_function = True + self.found_qualified_name = node.attr return # Check if this is accessing a target function through a dynamically imported module # Only if we've detected dynamic imports are being used if self.has_dynamic_imports and node.attr in self.function_names_to_find: self.found_any_target_function = True + self.found_qualified_name = node.attr return self.generic_visit(node) @@ -235,6 +242,7 @@ def visit_Name(self, node: ast.Name) -> None: if node.id in self.function_names_to_find: self.found_any_target_function = True + self.found_qualified_name = node.id return # Check if this name could come from a wildcard import @@ -243,6 +251,7 @@ def visit_Name(self, node: ast.Name) -> None: # Check if target_func is from this wildcard module and name matches if target_func.startswith(f"{wildcard_module}.") and target_func.endswith(f".{node.id}"): self.found_any_target_function = True + self.found_qualified_name = target_func return self.generic_visit(node) @@ -266,7 +275,11 @@ def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: s logger.debug(f"Failed to analyze imports in {test_file_path}: {e}") return True else: - return analyzer.found_any_target_function + if analyzer.found_any_target_function: + logger.debug(f"Test file {test_file_path} imports target function: {analyzer.found_qualified_name}") + return True + logger.debug(f"Test file {test_file_path} does not import any target functions.") + return False def filter_test_files_by_imports( @@ -288,7 +301,6 @@ def filter_test_files_by_imports( logger.debug(f"Target functions for import filtering: {target_functions}") filtered_map = {} - for test_file, test_functions in file_to_test_map.items(): should_process = analyze_imports_in_test_file(test_file, target_functions) if should_process: @@ -315,7 +327,6 @@ def discover_unit_tests( functions_to_optimize = None if file_to_funcs_to_optimize: functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list] - function_to_tests, num_discovered_tests = strategy(cfg, discover_only_these_tests, functions_to_optimize) return function_to_tests, num_discovered_tests From a928fc00ef42c73e5fa1b93c936f0d35af77fa54 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 19:38:17 -0700 Subject: [PATCH 5/5] UX changes ux change UX changes --- codeflash/code_utils/code_utils.py | 9 ++++++++- codeflash/code_utils/env_utils.py | 10 +++++----- codeflash/discovery/functions_to_optimize.py | 11 +++++++---- codeflash/optimization/optimizer.py | 15 +++++++++++---- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 6a9de176..22cd817d 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -5,6 +5,7 @@ import re import shutil import site +import sys from contextlib import contextmanager from functools import lru_cache from pathlib import Path @@ -12,7 +13,7 @@ import tomlkit -from codeflash.cli_cmds.console import logger +from codeflash.cli_cmds.console import logger, paneled_text from codeflash.code_utils.config_parser import find_pyproject_toml ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE) @@ -213,3 +214,9 @@ def cleanup_paths(paths: list[Path]) -> None: def restore_conftest(path_to_content_map: dict[Path, str]) -> None: for path, file_content in path_to_content_map.items(): path.write_text(file_content, encoding="utf8") + + +def exit_with_message(message: str, *, error_on_exit: bool = False) -> None: + paneled_text(message, panel_args={"style": "red"}) + + sys.exit(1 if error_on_exit else 0) diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index 26528439..d74c0313 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -1,13 +1,13 @@ from __future__ import annotations import os -import sys import tempfile from functools import lru_cache from pathlib import Path from typing import Optional from codeflash.cli_cmds.console import logger +from codeflash.code_utils.code_utils import exit_with_message from codeflash.code_utils.formatter import format_code from codeflash.code_utils.shell_utils import read_api_key_from_shell_config @@ -24,11 +24,11 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = try: format_code(formatter_cmds, tmp_file, print_status=False) except Exception: - print( - "⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again." + exit_with_message( + "⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.", + error_on_exit=True, ) - if exit_on_failure: - sys.exit(1) + return return_code diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index c50a0ad4..2035fd17 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -17,6 +17,7 @@ from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again from codeflash.cli_cmds.console import DEBUG_MODE, console, logger from codeflash.code_utils.code_utils import ( + exit_with_message, is_class_defined_in_file, module_name_from_file_path, path_belongs_to_site_packages, @@ -179,8 +180,9 @@ def get_functions_to_optimize( if only_get_this_function is not None: split_function = only_get_this_function.split(".") if len(split_function) > 2: - msg = "Function name should be in the format 'function_name' or 'class_name.function_name'" - raise ValueError(msg) + exit_with_message( + "Function name should be in the format 'function_name' or 'class_name.function_name'" + ) if len(split_function) == 2: class_name, only_function_name = split_function else: @@ -193,8 +195,9 @@ def get_functions_to_optimize( ): found_function = fn if found_function is None: - msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement or is a property" - raise ValueError(msg) + exit_with_message( + f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property" + ) functions[file] = [found_function] else: logger.info("Finding all functions modified in the current git diff ...") diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 0401efe3..89737912 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -72,7 +72,6 @@ def create_function_optimizer( def run(self) -> None: from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from codeflash.code_utils.code_replacer import normalize_code, normalize_node - from codeflash.code_utils.code_utils import cleanup_paths from codeflash.code_utils.static_analysis import ( analyze_imported_modules, get_first_top_level_function_or_method_ast, @@ -283,14 +282,22 @@ def run(self) -> None: if function_optimizer: function_optimizer.cleanup_generated_files() - if self.test_cfg.concolic_test_root_dir: - cleanup_paths([self.test_cfg.concolic_test_root_dir]) + self.cleanup_temporary_paths() + + def cleanup_temporary_paths(self) -> None: + from codeflash.code_utils.code_utils import cleanup_paths + + cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir]) def run_with_args(args: Namespace) -> None: + optimizer = None try: optimizer = Optimizer(args) optimizer.run() except KeyboardInterrupt: - logger.warning("Keyboard interrupt received. Exiting, please wait…") + logger.warning("Keyboard interrupt received. Cleaning up and exiting, please wait…") + if optimizer: + optimizer.cleanup_temporary_paths() + raise SystemExit from None