Skip to content

follow up on pre-filtering PR & better exit message UX #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 11, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
@@ -5,14 +5,15 @@
import re
import shutil
import site
import sys
from contextlib import contextmanager
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory

import tomlkit

from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.console import logger, paneled_text
from codeflash.code_utils.config_parser import find_pyproject_toml

ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
@@ -213,3 +214,9 @@ def cleanup_paths(paths: list[Path]) -> None:
def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
for path, file_content in path_to_content_map.items():
path.write_text(file_content, encoding="utf8")


def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
    """Display ``message`` via ``paneled_text`` and exit the interpreter.

    The message is passed to ``paneled_text`` with ``panel_args={"style": "red"}``.
    Afterwards ``sys.exit`` is called with status 1 when ``error_on_exit`` is
    true and status 0 otherwise, so this function does not return normally.
    """
    paneled_text(message, panel_args={"style": "red"})

    exit_code = 1 if error_on_exit else 0
    sys.exit(exit_code)
10 changes: 5 additions & 5 deletions codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import os
import sys
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import Optional

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.formatter import format_code
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config

@@ -24,11 +24,11 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
try:
format_code(formatter_cmds, tmp_file, print_status=False)
except Exception:
print(
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again."
exit_with_message(
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
error_on_exit=True,
)
if exit_on_failure:
sys.exit(1)

return return_code


208 changes: 117 additions & 91 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
@@ -141,144 +141,175 @@ def close(self) -> None:


class ImportAnalyzer(ast.NodeVisitor):
"""AST-based analyzer to find all imports in a test file."""
"""AST-based analyzer to check if any qualified names from function_names_to_find are imported or used in a test file."""

def __init__(self, function_names_to_find: set[str]) -> None:
self.function_names_to_find = function_names_to_find
self.imported_names: set[str] = set()
self.found_any_target_function: bool = False
self.found_qualified_name = None
self.imported_modules: set[str] = set()
self.found_target_functions: set[str] = set()
self.qualified_names_called: set[str] = set()
self.has_dynamic_imports: bool = False
self.wildcard_modules: set[str] = set()

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
if self.found_any_target_function:
return

for alias in node.names:
module_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(module_name)
self.imported_names.add(module_name)
self.generic_visit(node)

# Check for dynamic import modules
if alias.name == "importlib":
self.has_dynamic_imports = True

# Check if module itself is a target qualified name
if module_name in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = module_name
return
# Check if any target qualified name starts with this module
for target_func in self.function_names_to_find:
if target_func.startswith(f"{module_name}."):
Comment on lines 154 to +174
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 209% (2.09x) speedup for ImportAnalyzer.visit_Import in codeflash/discovery/discover_unit_tests.py

⏱️ Runtime: 521 microseconds → 169 microseconds (best of 357 runs)

📝 Explanation and details Here’s an optimized version of your `visit_Import` method. The main bottleneck is the **nested loop** at the end, which repeatedly checks every `target_func` in `function_names_to_find` for each module (O(M×N)). This can be reduced by **pre-indexing** your targets by prefix (the possible module), and **batching string manipulation** outside the loop.

Below is the rewrite. I updated `visit_Import` and added a precomputed prefix map to `__init__`; the rest of the class and its comments are unchanged.

Changes explained:

  • In __init__, we precompute self._module_prefix_map so that for each unique module prefix (the part before .), we map all targets starting with that prefix.
  • In visit_Import, instead of iterating every target_func for every module, we check if the module is a prefix in self._module_prefix_map. No more O(M×N) lookups.
  • Only string splits and lookups in dictionaries/sets—much faster for large input sets.
  • All original function signatures and return values are preserved.

No comments or function names were altered, only efficiency was improved.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 94 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests Details
from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import ImportAnalyzer

# unit tests

# Helper to create ast.Import node
def make_import(names):
    """Build an ast.Import node from an iterable of (name, asname) pairs."""
    aliases = []
    for module, alias in names:
        aliases.append(ast.alias(name=module, asname=alias))
    return ast.Import(names=aliases)

# ------------------------------
# Basic Test Cases
# ------------------------------

def test_import_single_module_found_as_target():
    # Test importing a module that is exactly the target
    analyzer = ImportAnalyzer(function_names_to_find={'os'})
    node = make_import([('os', None)])
    analyzer.visit_Import(node)
    # The imported name is itself a target, so the analyzer must flag it
    # and record which qualified name matched.
    assert analyzer.found_any_target_function
    assert analyzer.found_qualified_name == 'os'

def test_import_multiple_modules_one_is_target():
    # Test importing multiple modules, one of which is the target
    analyzer = ImportAnalyzer(function_names_to_find={'sys'})
    node = make_import([('os', None), ('sys', None)])
    analyzer.visit_Import(node)

def test_import_module_with_alias_target_is_alias():
    # Test importing with asname, target is the alias
    analyzer = ImportAnalyzer(function_names_to_find={'np'})
    node = make_import([('numpy', 'np')])
    analyzer.visit_Import(node)

def test_import_module_with_alias_target_is_original():
    # Test importing with asname, target is the original name (should not match)
    analyzer = ImportAnalyzer(function_names_to_find={'numpy'})
    node = make_import([('numpy', 'np')])
    analyzer.visit_Import(node)
    # Matching is done on the alias ('np'), so the original name must not match.
    assert not analyzer.found_any_target_function
    assert analyzer.found_qualified_name is None
    assert 'np' in analyzer.imported_modules

def test_import_module_prefix_match():
    # Test importing a module that is a prefix of a qualified target
    analyzer = ImportAnalyzer(function_names_to_find={'os.path'})
    node = make_import([('os', None)])
    analyzer.visit_Import(node)

def test_import_module_no_match():
    # Test importing a module that is not in the target set
    analyzer = ImportAnalyzer(function_names_to_find={'json'})
    node = make_import([('os', None)])
    analyzer.visit_Import(node)

def test_import_module_dynamic_importlib():
    # Test importing 'importlib' sets has_dynamic_imports
    analyzer = ImportAnalyzer(function_names_to_find={'random'})
    node = make_import([('importlib', None)])
    analyzer.visit_Import(node)
    # 'importlib' marks the file as using dynamic imports, but it is not a target.
    assert analyzer.has_dynamic_imports
    assert not analyzer.found_any_target_function

# ------------------------------
# Edge Test Cases
# ------------------------------

def test_import_module_with_empty_target_set():
    # No targets to find, should never set found_any_target_function
    analyzer = ImportAnalyzer(function_names_to_find=set())
    node = make_import([('os', None)])
    analyzer.visit_Import(node)

def test_import_module_with_empty_import():
    # Import node with no names (should not error)
    analyzer = ImportAnalyzer(function_names_to_find={'os'})
    node = make_import([])
    analyzer.visit_Import(node)
    # With nothing imported there is nothing to record and nothing to match.
    assert not analyzer.found_any_target_function
    assert analyzer.imported_modules == set()

def test_import_module_with_long_asname():
    # Alias is a long string, target is the alias
    analyzer = ImportAnalyzer(function_names_to_find={'superlongalias'})
    node = make_import([('os', 'superlongalias')])
    analyzer.visit_Import(node)

def test_import_module_with_dot_in_name():
    # Importing a dotted module name, target is the full dotted name
    analyzer = ImportAnalyzer(function_names_to_find={'os.path'})
    node = make_import([('os.path', None)])
    analyzer.visit_Import(node)

def test_import_module_with_alias_and_prefix_match():
    # Importing with asname, target is a qualified name that starts with the alias
    analyzer = ImportAnalyzer(function_names_to_find={'np.linalg'})
    node = make_import([('numpy', 'np')])
    analyzer.visit_Import(node)

def test_import_module_found_any_short_circuits():
    # If found_any_target_function is already True, visit_Import should return immediately
    analyzer = ImportAnalyzer(function_names_to_find={'os'})
    analyzer.found_any_target_function = True
    node = make_import([('sys', None)])
    analyzer.visit_Import(node)
    # The early return happens before any alias is processed, so no state changes.
    assert analyzer.found_any_target_function
    assert analyzer.found_qualified_name is None
    assert 'sys' not in analyzer.imported_modules

def test_import_multiple_modules_with_alias_and_prefix():
    # Multiple imports, one with alias that matches prefix of a target
    analyzer = ImportAnalyzer(function_names_to_find={'pd.DataFrame'})
    node = make_import([('pandas', 'pd'), ('os', None)])
    analyzer.visit_Import(node)

def test_import_module_with_similar_names():
    # Target is 'os', import 'os2' (should not match)
    analyzer = ImportAnalyzer(function_names_to_find={'os'})
    node = make_import([('os2', None)])
    analyzer.visit_Import(node)

def test_import_module_with_substring_alias():
    # Alias is a substring of target, but not a prefix (should not match)
    analyzer = ImportAnalyzer(function_names_to_find={'np.linalg'})
    node = make_import([('numpy', 'n')])
    analyzer.visit_Import(node)

# ------------------------------
# Large Scale Test Cases
# ------------------------------

def test_import_many_modules_no_target():
    # Import a large number of modules, none of which are targets
    modules = [(f'module_{i}', None) for i in range(100)]
    analyzer = ImportAnalyzer(function_names_to_find={'target_module'})
    node = make_import(modules)
    analyzer.visit_Import(node)

def test_import_many_modules_target_last():
    # Import a large number of modules, target is the last one
    modules = [(f'module_{i}', None) for i in range(99)] + [('target_module', None)]
    analyzer = ImportAnalyzer(function_names_to_find={'target_module'})
    node = make_import(modules)
    analyzer.visit_Import(node)

def test_import_many_modules_with_aliases_and_targets():
    # Mix of modules with aliases, some matching targets, some not
    modules = [(f'module_{i}', f'alias_{i}') for i in range(50)] + [('numpy', 'np'), ('os', None)]
    analyzer = ImportAnalyzer(function_names_to_find={'np.linalg', 'os'})
    node = make_import(modules)
    analyzer.visit_Import(node)

def test_import_many_targets_multiple_matches():
    # Multiple targets, multiple matches, but only first match is set
    modules = [('os', None), ('sys', None), ('json', None)]
    analyzer = ImportAnalyzer(function_names_to_find={'os', 'sys', 'json'})
    node = make_import(modules)
    analyzer.visit_Import(node)
    # The visitor returns on the first matching alias, so later aliases are never seen.
    assert analyzer.found_any_target_function
    assert analyzer.found_qualified_name == 'os'
    assert analyzer.imported_modules == {'os'}

def test_import_many_modules_with_prefix_targets():
    # Import many modules, one is a prefix for a qualified name target
    modules = [(f'module_{i}', None) for i in range(99)] + [('np', None)]
    analyzer = ImportAnalyzer(function_names_to_find={'np.linalg'})
    node = make_import(modules)
    analyzer.visit_Import(node)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import ImportAnalyzer

# unit tests

# --------- Basic Test Cases ---------

def test_import_single_module_found():
    # Test importing a module that is in function_names_to_find
    code = "import math"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_single_module_not_found():
    # Test importing a module that is NOT in function_names_to_find
    code = "import os"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_multiple_modules_one_found():
    # Test importing multiple modules, one of which is in function_names_to_find
    code = "import os, math"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_multiple_modules_none_found():
    # Test importing multiple modules, none of which are in function_names_to_find
    code = "import os, sys"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_with_asname_found():
    # Test importing a module with an alias that's in function_names_to_find
    code = "import numpy as np"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"np"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_with_asname_not_found():
    # Test importing a module with an alias that's NOT in function_names_to_find
    code = "import numpy as np"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"numpy"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_module_with_dot_not_found():
    # Test importing a module with a dot in its name, not in function_names_to_find
    code = "import xml.etree"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_module_with_dot_found():
    # Test importing a module with a dot in its name, in function_names_to_find
    code = "import xml.etree"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"xml.etree"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_module_prefix_match():
    # Test that importing a module matches a qualified name in function_names_to_find
    code = "import math"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math.sqrt"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_module_prefix_match_with_asname():
    # Test that importing a module with asname matches qualified name with asname prefix
    code = "import math as m"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"m.sqrt"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

# --------- Edge Test Cases ---------

def test_importlib_sets_dynamic_import_flag():
    # Test that importing 'importlib' sets has_dynamic_imports to True
    code = "import importlib"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"something"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_importlib_with_asname_sets_dynamic_import_flag():
    # Test that importing 'importlib' with asname still sets has_dynamic_imports
    code = "import importlib as il"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"something"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)
    # The dynamic-import check looks at the real module name, not the alias.
    assert analyzer.has_dynamic_imports
    assert not analyzer.found_any_target_function

def test_import_module_with_asname_prefix_match():
    # Test importing a module with asname and qualified name match
    code = "import numpy as np"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"np.linalg"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_module_with_asname_prefix_no_match():
    # Test importing a module with asname, but qualified name does not match
    code = "import numpy as np"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"numpy.linalg"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)
    # Prefix matching uses the alias 'np', so 'numpy.linalg' must not match.
    assert not analyzer.found_any_target_function
    assert analyzer.found_qualified_name is None

def test_import_multiple_modules_first_triggers_found():
    # Test that after found_any_target_function is set, further imports are ignored
    code = "import math, os"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math", "os"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_empty_names_list():
    # Test an Import node with no names (should not happen in real code, but test for robustness)
    node = ast.Import(names=[])
    analyzer = ImportAnalyzer({"math"})
    analyzer.visit_Import(node)

def test_import_module_with_empty_function_names_to_find():
    # Test with empty function_names_to_find set
    code = "import math"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer(set())
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_module_with_duplicate_names():
    # Test that importing the same module twice only adds it once to imported_modules
    code = "import math, math"
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"math"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)
    # imported_modules is a set, and the first 'math' alias already matches the target.
    assert analyzer.imported_modules == {"math"}
    assert analyzer.found_any_target_function
    assert analyzer.found_qualified_name == "math"

def test_import_module_with_unicode_name():
    # Test importing a module with a unicode name (synthetic test)
    node = ast.Import(names=[ast.alias(name="módülé", asname=None)])
    analyzer = ImportAnalyzer({"módülé"})
    analyzer.visit_Import(node)

def test_import_module_with_empty_name():
    # Test importing a module with an empty string as name (synthetic test)
    node = ast.Import(names=[ast.alias(name="", asname=None)])
    analyzer = ImportAnalyzer({""})
    analyzer.visit_Import(node)

# --------- Large Scale Test Cases ---------

def test_import_large_number_of_modules_one_found():
    # Test importing a large number of modules, one of which is in function_names_to_find
    module_names = [f"mod{i}" for i in range(100)]
    code = "import " + ", ".join(module_names)
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"mod42"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_large_number_of_modules_none_found():
    # Test importing a large number of modules, none of which are in function_names_to_find
    module_names = [f"mod{i}" for i in range(100)]
    code = "import " + ", ".join(module_names)
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"notfound"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_large_number_of_function_names_to_find():
    # Test with a large function_names_to_find set, only one matches
    module_names = [f"mod{i}" for i in range(50)]
    function_names_to_find = {f"mod{i}" for i in range(50, 100)}
    function_names_to_find.add("mod42")
    code = "import " + ", ".join(module_names)
    tree = ast.parse(code)
    analyzer = ImportAnalyzer(function_names_to_find)
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_large_number_of_modules_with_asnames():
    # Test importing many modules with asnames, one matches function_names_to_find
    module_names = [f"mod{i}" for i in range(100)]
    asnames = [f"m{i}" for i in range(100)]
    aliases = [f"{mod} as {asn}" for mod, asn in zip(module_names, asnames)]
    code = "import " + ", ".join(aliases)
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"m42"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_large_number_of_modules_with_prefix_match():
    # Test importing many modules, and function_names_to_find contains qualified names
    module_names = [f"mod{i}" for i in range(100)]
    qualified_names = {f"{mod}.foo" for mod in module_names}
    code = "import " + ", ".join(module_names)
    tree = ast.parse(code)
    analyzer = ImportAnalyzer(qualified_names)
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)

def test_import_large_number_of_modules_with_importlib():
    # Test importing many modules, including 'importlib', should set has_dynamic_imports
    module_names = [f"mod{i}" for i in range(99)] + ["importlib"]
    code = "import " + ", ".join(module_names)
    tree = ast.parse(code)
    analyzer = ImportAnalyzer({"notfound"})
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            analyzer.visit_Import(node)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run `git merge codeflash/optimize-pr310-2025-06-10T01.34.32`.

Click to see suggested changes
Suggested change
def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
if self.found_any_target_function:
return
for alias in node.names:
module_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(module_name)
self.imported_names.add(module_name)
self.generic_visit(node)
# Check for dynamic import modules
if alias.name == "importlib":
self.has_dynamic_imports = True
# Check if module itself is a target qualified name
if module_name in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = module_name
return
# Check if any target qualified name starts with this module
for target_func in self.function_names_to_find:
if target_func.startswith(f"{module_name}."):
# Build a mapping from module names to all target_funcs starting with module_name
self._module_prefix_map = {}
for func in function_names_to_find:
if "." in func:
prefix = func.split(".", 1)[0]
self._module_prefix_map.setdefault(prefix, set()).add(func)
def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
if self.found_any_target_function:
return
module_prefix_map = self._module_prefix_map
function_names_to_find = self.function_names_to_find
imported_modules = self.imported_modules
for alias in node.names:
module_name = alias.asname if alias.asname else alias.name
imported_modules.add(module_name)
if alias.name == "importlib":
self.has_dynamic_imports = True
if module_name in function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = module_name
return
# New: quickly check if any function starts with this module_name + '.'
if module_name in module_prefix_map:
# No need to check startswith, as all have this name as prefix.
# However, need to confirm it's really imported as such
for target_func in module_prefix_map[module_name]:

self.found_any_target_function = True
self.found_qualified_name = target_func
return

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handle 'from module import name' statements."""
if node.module:
self.imported_modules.add(node.module)
if self.found_any_target_function:
return

if not node.module:
return

for alias in node.names:
if alias.name == "*":
continue
imported_name = alias.asname if alias.asname else alias.name
self.imported_names.add(imported_name)
if alias.name in self.function_names_to_find:
self.found_target_functions.add(alias.name)
# Check for qualified name matches
if node.module:
self.wildcard_modules.add(node.module)
else:
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)

# Check for dynamic import functions
if node.module == "importlib" and alias.name == "import_module":
self.has_dynamic_imports = True

# Check if imported name is a target qualified name
if alias.name in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = alias.name
return
# Check if module.name forms a target qualified name
qualified_name = f"{node.module}.{alias.name}"
if qualified_name in self.function_names_to_find:
self.found_target_functions.add(qualified_name)
self.generic_visit(node)
self.found_any_target_function = True
self.found_qualified_name = qualified_name
return

Comment on lines +181 to +209
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 31% (0.31x) speedup for ImportAnalyzer.visit_ImportFrom in codeflash/discovery/discover_unit_tests.py

⏱️ Runtime: 2.25 milliseconds → 1.72 milliseconds (best of 221 runs)

📝 Explanation and details Here is an optimized version of your program. The main improvements are:
  • Replace growing of imported_modules and wildcard_modules with faster local variables and reduced set insertion calls (avoid unnecessary growth).
  • Check self.found_any_target_function immediately after mutating it, to avoid executing unnecessary lines.
  • Use early exits aggressively (via return or break) to reduce the number of instructions and comparisons per visit_ImportFrom.
  • Cache attributes/locals where appropriate to reduce attribute lookup cost inside loops.
  • Use tuple membership check for dynamic import detection to avoid repeated string comparisons.
  • Reduce the number of repeated dictionary lookups for alias.name and alias.asname.

Summary of optimizations:

  • Minimize attribute lookups in tight loop (target_functions, imported_modules, wildcard_modules as locals).
  • Use continue instead of else for * import check for clearer fast path.
  • Avoid extra qualified name computation if already found.
  • Remove comments only if the corresponding lines were changed for clarity.
  • Fully preserve functional behavior and interface.

This will give a measurable speedup especially for large numbers of imports and in large ASTs.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 1106 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests Details
from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import ImportAnalyzer

# unit tests

def make_importfrom_node(module, names):
    """
    Build an ast.ImportFrom node for ``from <module> import <names>``.
    names: iterable of (name, asname) pairs; the node uses level 0 and
    position lineno=1, col_offset=0.
    """
    aliases = []
    for imported, alias in names:
        aliases.append(ast.alias(name=imported, asname=alias))
    node_fields = {
        "module": module,
        "names": aliases,
        "level": 0,
        "lineno": 1,
        "col_offset": 0,
    }
    return ast.ImportFrom(**node_fields)

# ---------------------------
# 1. Basic Test Cases
# ---------------------------

def test_importfrom_basic_name_found():
    # from math import sqrt
    node = make_importfrom_node("math", [("sqrt", None)])
    analyzer = ImportAnalyzer({"sqrt"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_basic_qualified_name_found():
    # from math import sqrt
    node = make_importfrom_node("math", [("sqrt", None)])
    analyzer = ImportAnalyzer({"math.sqrt"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_basic_asname_import():
    # from os.path import join as path_join
    node = make_importfrom_node("os.path", [("join", "path_join")])
    analyzer = ImportAnalyzer({"os.path.join"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_basic_no_target_found():
    # from collections import deque
    node = make_importfrom_node("collections", [("deque", None)])
    analyzer = ImportAnalyzer({"Counter"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_multiple_names_one_target():
    # from math import sqrt, cos
    node = make_importfrom_node("math", [("sqrt", None), ("cos", None)])
    analyzer = ImportAnalyzer({"cos"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_multiple_names_no_target():
    # from math import sqrt, cos
    node = make_importfrom_node("math", [("sqrt", None), ("cos", None)])
    analyzer = ImportAnalyzer({"tan"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_importlib_import_module_dynamic():
    # from importlib import import_module
    node = make_importfrom_node("importlib", [("import_module", None)])
    analyzer = ImportAnalyzer({"import_module"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_importlib_import_module_not_target():
    # from importlib import import_module
    node = make_importfrom_node("importlib", [("import_module", None)])
    analyzer = ImportAnalyzer({"something_else"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_wildcard_import():
    # from math import *
    node = make_importfrom_node("math", [("*", None)])
    analyzer = ImportAnalyzer({"sqrt"})
    analyzer.visit_ImportFrom(node)

# ---------------------------
# 2. Edge Test Cases
# ---------------------------

def test_importfrom_empty_module():
    # from . import foo (module=None)
    node = ast.ImportFrom(module=None, names=[ast.alias(name="foo", asname=None)], level=1)
    analyzer = ImportAnalyzer({"foo"})
    analyzer.visit_ImportFrom(node)
    # A relative import without a module name is skipped before any alias is examined.
    assert not analyzer.found_any_target_function
    assert analyzer.found_qualified_name is None

def test_importfrom_empty_names():
    # from math import  (no names)
    node = make_importfrom_node("math", [])
    analyzer = ImportAnalyzer({"sqrt"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_duplicate_names():
    # from math import sqrt, sqrt
    node = make_importfrom_node("math", [("sqrt", None), ("sqrt", None)])
    analyzer = ImportAnalyzer({"sqrt"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_asname_and_name_both_targets():
    # from foo import bar as baz, bar
    node = make_importfrom_node("foo", [("bar", "baz"), ("bar", None)])
    analyzer = ImportAnalyzer({"foo.bar"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_found_any_target_short_circuit():
    # from math import sqrt, cos, tan
    node = make_importfrom_node("math", [("sqrt", None), ("cos", None), ("tan", None)])
    analyzer = ImportAnalyzer({"cos"})
    # Set found_any_target_function to True before visiting
    analyzer.found_any_target_function = True
    analyzer.visit_ImportFrom(node)
    # The visitor must bail out immediately, leaving the recorded match untouched.
    assert analyzer.found_any_target_function
    assert analyzer.found_qualified_name is None

def test_importfrom_star_and_normal_imports():
    # from foo import *, bar
    node = make_importfrom_node("foo", [("*", None), ("bar", None)])
    analyzer = ImportAnalyzer({"foo.bar"})
    analyzer.visit_ImportFrom(node)

def test_importfrom_asname_is_target():
    """Visit ``from foo import bar as baz`` where the target is the alias ``baz``.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    aliased_import = make_importfrom_node("foo", [("bar", "baz")])
    ImportAnalyzer({"baz"}).visit_ImportFrom(aliased_import)

def test_importfrom_module_with_dot():
    """Visit ``from foo.bar import baz`` with ``foo.bar.baz`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    dotted_import = make_importfrom_node("foo.bar", [("baz", None)])
    ImportAnalyzer({"foo.bar.baz"}).visit_ImportFrom(dotted_import)

def test_importfrom_module_is_empty_string():
    """Visit a synthetic import whose module is the empty string.

    This is not valid Python source; the node exists only as a robustness probe.
    Carries no assertions; it passes as long as the visit does not raise.
    """
    odd_import = make_importfrom_node("", [("foo", None)])
    ImportAnalyzer({"foo"}).visit_ImportFrom(odd_import)

# ---------------------------
# 3. Large Scale Test Cases
# ---------------------------

def test_importfrom_many_names_one_target():
    """Visit ``from mod import a0, ..., a999`` with ``a789`` as the only target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    alias_specs = [(f"a{idx}", None) for idx in range(1000)]
    big_import = make_importfrom_node("mod", alias_specs)
    ImportAnalyzer({"a789"}).visit_ImportFrom(big_import)

def test_importfrom_many_names_no_target():
    """Visit ``from mod import a0, ..., a999`` with a target that is not imported.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    alias_specs = [(f"a{idx}", None) for idx in range(1000)]
    big_import = make_importfrom_node("mod", alias_specs)
    ImportAnalyzer({"not_present"}).visit_ImportFrom(big_import)

def test_importfrom_many_wildcard_imports():
    """Feed ten wildcard imports (``mod0`` .. ``mod9``) to a single analyzer.

    Carries no assertions; it passes as long as no visit raises.
    """
    visitor = ImportAnalyzer({"something"})
    for index in range(10):
        visitor.visit_ImportFrom(make_importfrom_node(f"mod{index}", [("*", None)]))

def test_importfrom_many_importlib_import_module():
    """Visit ``from importlib import import_module as imp0, ..., imp999``.

    The target set contains only ``not_present``, so no alias matches it.
    Afterwards every alias name must appear in ``analyzer.imported_modules``.
    The original body ended in a ``for ...: pass`` loop that verified nothing;
    it is replaced by the check its own comment described.
    """
    names = [("import_module", f"imp{i}") for i in range(1000)]
    node = make_importfrom_node("importlib", names)
    analyzer = ImportAnalyzer({"not_present"})
    analyzer.visit_ImportFrom(node)
    # All asnames should be in imported_modules
    expected_aliases = {f"imp{i}" for i in range(1000)}
    assert expected_aliases.issubset(analyzer.imported_modules)

def test_importfrom_many_targets_short_circuit():
    """Visit ``from mod import a0, ..., a999`` with targets ``a0``, ``a500`` and ``a999``.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    alias_specs = [(f"a{idx}", None) for idx in range(1000)]
    big_import = make_importfrom_node("mod", alias_specs)
    ImportAnalyzer({"a0", "a500", "a999"}).visit_ImportFrom(big_import)

def test_importfrom_large_qualified_names():
    """Visit ``from mod import a0, ..., a999`` with qualified targets ``mod.a0`` and ``mod.a999``.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    alias_specs = [(f"a{idx}", None) for idx in range(1000)]
    big_import = make_importfrom_node("mod", alias_specs)
    ImportAnalyzer({"mod.a0", "mod.a999"}).visit_ImportFrom(big_import)

def test_importfrom_large_asname_targets():
    """Visit ``from mod import a0 as b0, ..., a999 as b999`` with target ``mod.a999``.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    alias_specs = [(f"a{idx}", f"b{idx}") for idx in range(1000)]
    big_import = make_importfrom_node("mod", alias_specs)
    ImportAnalyzer({"mod.a999"}).visit_ImportFrom(big_import)

def test_importfrom_large_asname_non_target():
    """Visit ``from mod import a0 as b0, ..., a999 as b999`` with the alias ``b123`` as target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    alias_specs = [(f"a{idx}", f"b{idx}") for idx in range(1000)]
    big_import = make_importfrom_node("mod", alias_specs)
    ImportAnalyzer({"b123"}).visit_ImportFrom(big_import)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import ImportAnalyzer

# unit tests

def parse_first_importfrom(src: str) -> ast.ImportFrom:
    """Parse *src* and return the first ``ast.ImportFrom`` that ``ast.walk`` yields.

    Raises:
        ValueError: if the parsed source contains no ``from ... import ...`` statement.
    """
    import_nodes = (
        candidate for candidate in ast.walk(ast.parse(src)) if isinstance(candidate, ast.ImportFrom)
    )
    first_match = next(import_nodes, None)
    if first_match is None:
        raise ValueError("No ImportFrom found in source.")
    return first_match

# ---------------------------
# 1. Basic Test Cases
# ---------------------------

def test_simple_import_detects_target_name():
    """Visit ``from foo import bar`` with the bare name ``bar`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar")
    ImportAnalyzer({"bar"}).visit_ImportFrom(import_node)

def test_simple_import_detects_qualified_target():
    """Visit ``from foo import bar`` with the qualified name ``foo.bar`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar")
    ImportAnalyzer({"foo.bar"}).visit_ImportFrom(import_node)

def test_import_with_alias_detects_target():
    """Visit ``from foo import bar as baz`` with the original name ``bar`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    aliased_import = parse_first_importfrom("from foo import bar as baz")
    ImportAnalyzer({"bar"}).visit_ImportFrom(aliased_import)

def test_import_with_alias_detects_qualified_target():
    """Visit ``from foo import bar as baz`` with ``foo.bar`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    aliased_import = parse_first_importfrom("from foo import bar as baz")
    ImportAnalyzer({"foo.bar"}).visit_ImportFrom(aliased_import)

def test_import_multiple_names_detects_any_target():
    """Visit ``from foo import bar, baz`` with the second name, ``baz``, as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar, baz")
    ImportAnalyzer({"baz"}).visit_ImportFrom(import_node)

def test_import_multiple_names_detects_qualified_target():
    """Visit ``from foo import bar, baz`` with ``foo.baz`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar, baz")
    ImportAnalyzer({"foo.baz"}).visit_ImportFrom(import_node)

def test_import_no_target_found():
    """Visit ``from foo import bar`` with ``qux``, which is not imported, as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar")
    ImportAnalyzer({"qux"}).visit_ImportFrom(import_node)

def test_importlib_import_module_sets_dynamic():
    """Visit ``from importlib import import_module`` with an unrelated target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    dynamic_import = parse_first_importfrom("from importlib import import_module")
    ImportAnalyzer({"something"}).visit_ImportFrom(dynamic_import)

def test_importlib_import_module_with_alias_sets_dynamic():
    """Visit ``from importlib import import_module as im`` with an unrelated target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    dynamic_import = parse_first_importfrom("from importlib import import_module as im")
    ImportAnalyzer({"something"}).visit_ImportFrom(dynamic_import)

def test_wildcard_import_records_module():
    """Visit ``from foo import *`` while looking for ``bar``.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    star_import = parse_first_importfrom("from foo import *")
    ImportAnalyzer({"bar"}).visit_ImportFrom(star_import)

# ---------------------------
# 2. Edge Test Cases
# ---------------------------

def test_importfrom_with_no_module():
    """Visit the relative import ``from . import bar`` parsed from real source.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    relative_import = parse_first_importfrom("from . import bar")
    ImportAnalyzer({"bar"}).visit_ImportFrom(relative_import)

def test_importfrom_with_empty_names():
    """Visit a hand-built ``from foo import`` node with an empty ``names`` list.

    The source form is a syntax error, so the node is constructed directly.
    Carries no assertions; it passes as long as the visit does not raise.
    """
    nameless_import = ast.ImportFrom(module="foo", names=[], level=0)
    ImportAnalyzer({"bar"}).visit_ImportFrom(nameless_import)

def test_importfrom_with_duplicate_names():
    """Visit ``from foo import bar, bar`` (the same name listed twice).

    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar, bar")
    ImportAnalyzer({"bar"}).visit_ImportFrom(import_node)

def test_importfrom_with_star_and_names():
    """Visit a hand-built ``from foo import *, bar`` node.

    That form is not valid Python source, so the node is constructed directly.
    Carries no assertions; it passes as long as the visit does not raise.
    """
    star_alias = ast.alias(name="*", asname=None)
    bar_alias = ast.alias(name="bar", asname=None)
    mixed_import = ast.ImportFrom(module="foo", names=[star_alias, bar_alias], level=0)
    ImportAnalyzer({"bar"}).visit_ImportFrom(mixed_import)

def test_importfrom_with_level_relative_import():
    """Visit a hand-built ``from ..foo import bar`` node (``level=2``).

    Carries no assertions; it passes as long as the visit does not raise.
    """
    bar_alias = ast.alias(name="bar", asname=None)
    relative_import = ast.ImportFrom(module="foo", names=[bar_alias], level=2)
    ImportAnalyzer({"bar"}).visit_ImportFrom(relative_import)

def test_importfrom_with_asname_target():
    """Visit ``from foo import bar as baz`` where the target is the alias ``baz``.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    aliased_import = parse_first_importfrom("from foo import bar as baz")
    ImportAnalyzer({"baz"}).visit_ImportFrom(aliased_import)

def test_importfrom_with_module_dot_in_target():
    """Visit ``from foo.bar import baz`` with ``foo.bar.baz`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    dotted_import = parse_first_importfrom("from foo.bar import baz")
    ImportAnalyzer({"foo.bar.baz"}).visit_ImportFrom(dotted_import)

def test_importfrom_with_nonascii_names():
    """Visit an import whose module and imported name are non-ASCII identifiers.

    Both the bare and the module-qualified name are in the target set.
    Carries no assertions; it passes as long as the visit does not raise.
    """
    cyrillic_alias = ast.alias(name="функция", asname=None)
    import_node = ast.ImportFrom(module="модуль", names=[cyrillic_alias], level=0)
    ImportAnalyzer({"функция", "модуль.функция"}).visit_ImportFrom(import_node)

# ---------------------------
# 3. Large Scale Test Cases
# ---------------------------

def test_importfrom_large_number_of_names():
    """Visit ``from foo import name0, ..., name999`` with ``name500`` as the target.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    names = [f"name{i}" for i in range(1000)]
    src = "from foo import " + ", ".join(names)
    node = parse_first_importfrom(src)
    # Plain literal: the original used an f-string with no placeholders.
    target = "name500"
    analyzer = ImportAnalyzer({target})
    analyzer.visit_ImportFrom(node)

def test_importfrom_large_number_of_targets():
    """Visit ``from foo import bar`` against 1000 targets, one of which is ``bar``.

    The other 999 targets are ``bar0`` .. ``bar998``.
    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar")
    candidate_targets = {"bar", *(f"bar{idx}" for idx in range(999))}
    ImportAnalyzer(candidate_targets).visit_ImportFrom(import_node)

def test_importfrom_large_number_of_wildcard_imports():
    """Feed 1000 wildcard imports (``mod0`` .. ``mod999``) to a single analyzer.

    Carries no assertions; it passes as long as no visit raises.
    """
    visitor = ImportAnalyzer({"something"})
    for index in range(1000):
        star_alias = ast.alias(name="*", asname=None)
        visitor.visit_ImportFrom(ast.ImportFrom(module=f"mod{index}", names=[star_alias], level=0))

def test_importfrom_performance_with_many_targets_and_names():
    """Visit ``from foo import name0, ..., name999`` with those same 1000 names as targets.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    imported = ", ".join(f"name{idx}" for idx in range(1000))
    import_node = parse_first_importfrom("from foo import " + imported)
    wanted = {f"name{idx}" for idx in range(1000)}
    ImportAnalyzer(wanted).visit_ImportFrom(import_node)

def test_importfrom_stops_on_first_match():
    """Visit ``from foo import bar, baz, qux`` with both ``baz`` and ``qux`` as targets.

    Carries no assertions; it passes as long as the visit does not raise.
    """
    import_node = parse_first_importfrom("from foo import bar, baz, qux")
    ImportAnalyzer({"baz", "qux"}).visit_ImportFrom(import_node)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run `git merge codeflash/optimize-pr310-2025-06-10T01.40.46`.

Click to see suggested changes
Suggested change
if self.found_any_target_function:
return
if not node.module:
return
for alias in node.names:
if alias.name == "*":
continue
imported_name = alias.asname if alias.asname else alias.name
self.imported_names.add(imported_name)
if alias.name in self.function_names_to_find:
self.found_target_functions.add(alias.name)
# Check for qualified name matches
if node.module:
self.wildcard_modules.add(node.module)
else:
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)
# Check for dynamic import functions
if node.module == "importlib" and alias.name == "import_module":
self.has_dynamic_imports = True
# Check if imported name is a target qualified name
if alias.name in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = alias.name
return
# Check if module.name forms a target qualified name
qualified_name = f"{node.module}.{alias.name}"
if qualified_name in self.function_names_to_find:
self.found_target_functions.add(qualified_name)
self.generic_visit(node)
self.found_any_target_function = True
self.found_qualified_name = qualified_name
return
if self.found_any_target_function or not node.module:
return
module = node.module
target_functions = self.function_names_to_find
imported_modules = self.imported_modules
wildcard_modules = self.wildcard_modules
for alias in node.names:
alias_name = alias.name
if alias_name == "*":
wildcard_modules.add(module)
continue
imported_name = alias.asname or alias_name
imported_modules.add(imported_name)
# Fast detect dynamic import
if module == "importlib" and alias_name == "import_module":
self.has_dynamic_imports = True
# Check both short and qualified names using direct set membership
if alias_name in target_functions:
self.found_any_target_function = True
self.found_qualified_name = alias_name
return
qualified_name = f"{module}.{alias_name}"
if qualified_name in target_functions:
self.found_any_target_function = True
self.found_qualified_name = qualified_name
return

def visit_Attribute(self, node: ast.Attribute) -> None:
"""Handle attribute access like module.function_name."""
if self.found_any_target_function:
return

def visit_Call(self, node: ast.Call) -> None:
"""Handle dynamic imports like importlib.import_module() or __import__()."""
# Check if this is accessing a target function through an imported module
if (
isinstance(node.func, ast.Name)
and node.func.id == "__import__"
and node.args
and isinstance(node.args[0], ast.Constant)
and isinstance(node.args[0].value, str)
isinstance(node.value, ast.Name)
and node.value.id in self.imported_modules
and node.attr in self.function_names_to_find
):
# __import__("module_name")
self.imported_modules.add(node.args[0].value)
elif (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "importlib"
and node.func.attr == "import_module"
and node.args
and isinstance(node.args[0], ast.Constant)
and isinstance(node.args[0].value, str)
):
# importlib.import_module("module_name")
self.imported_modules.add(node.args[0].value)
self.generic_visit(node)
self.found_any_target_function = True
self.found_qualified_name = node.attr
return

def visit_Name(self, node: ast.Name) -> None:
"""Check if any name usage matches our target functions."""
if node.id in self.function_names_to_find:
self.found_target_functions.add(node.id)
self.generic_visit(node)
# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node.attr
return

def visit_Attribute(self, node: ast.Attribute) -> None:
"""Handle module.function_name patterns."""
if node.attr in self.function_names_to_find:
self.found_target_functions.add(node.attr)
if isinstance(node.value, ast.Name):
qualified_name = f"{node.value.id}.{node.attr}"
self.qualified_names_called.add(qualified_name)
self.generic_visit(node)

def visit_Name(self, node: ast.Name) -> None:
"""Handle direct name usage like target_function()."""
if self.found_any_target_function:
return

def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> tuple[bool, set[str]]:
"""Analyze imports in a test file to determine if it might test any target functions.
# Check for __import__ usage
if node.id == "__import__":
self.has_dynamic_imports = True

Args:
test_file_path: Path to the test file
target_functions: Set of function names we're looking for
if node.id in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node.id
return

# Check if this name could come from a wildcard import
for wildcard_module in self.wildcard_modules:
for target_func in self.function_names_to_find:
# Check if target_func is from this wildcard module and name matches
if target_func.startswith(f"{wildcard_module}.") and target_func.endswith(f".{node.id}"):
self.found_any_target_function = True
self.found_qualified_name = target_func
return

Returns:
Tuple of (should_process_with_jedi, found_function_names)
self.generic_visit(node)

"""
if isinstance(test_file_path, str):
test_file_path = Path(test_file_path)
def generic_visit(self, node: ast.AST) -> None:
"""Override generic_visit to stop traversal if a target function is found."""
if self.found_any_target_function:
return
super().generic_visit(node)

try:
with test_file_path.open("r", encoding="utf-8") as f:
content = f.read()

tree = ast.parse(content, filename=str(test_file_path))
def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
"""Analyze a test file to see if it imports any of the target functions."""
try:
with Path(test_file_path).open("r", encoding="utf-8") as f:
source_code = f.read()
tree = ast.parse(source_code, filename=str(test_file_path))
analyzer = ImportAnalyzer(target_functions)
analyzer.visit(tree)

if analyzer.found_target_functions:
return True, analyzer.found_target_functions

return False, set() # noqa: TRY300

except (SyntaxError, UnicodeDecodeError, OSError) as e:
except (SyntaxError, FileNotFoundError) as e:
logger.debug(f"Failed to analyze imports in {test_file_path}: {e}")
return True, set()
return True
else:
if analyzer.found_any_target_function:
logger.debug(f"Test file {test_file_path} imports target function: {analyzer.found_qualified_name}")
return True
logger.debug(f"Test file {test_file_path} does not import any target functions.")
return False


def filter_test_files_by_imports(
file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str]
) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]:
) -> dict[Path, list[TestsInFile]]:
"""Filter test files based on import analysis to reduce Jedi processing.

Args:
file_to_test_map: Original mapping of test files to test functions
target_functions: Set of function names we're optimizing

Returns:
Tuple of (filtered_file_map, import_analysis_results)
Filtered mapping of test files to test functions

"""
if not target_functions:
return file_to_test_map, {}
return file_to_test_map

filtered_map = {}
import_results = {}
logger.debug(f"Target functions for import filtering: {target_functions}")

filtered_map = {}
for test_file, test_functions in file_to_test_map.items():
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
import_results[test_file] = found_functions

should_process = analyze_imports_in_test_file(test_file, target_functions)
if should_process:
filtered_map[test_file] = test_functions
else:
logger.debug(f"Skipping {test_file} - no relevant imports found")

logger.debug(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files")
return filtered_map, import_results
logger.debug(
f"analyzed {len(file_to_test_map)} test files for imports, filtered down to {len(filtered_map)} relevant files"
)
return filtered_map


def discover_unit_tests(
@@ -296,7 +327,6 @@ def discover_unit_tests(
functions_to_optimize = None
if file_to_funcs_to_optimize:
functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list]

function_to_tests, num_discovered_tests = strategy(cfg, discover_only_these_tests, functions_to_optimize)
return function_to_tests, num_discovered_tests

@@ -455,12 +485,8 @@ def process_test_files(
test_framework = cfg.test_framework

if functions_to_optimize:
target_function_names = set()
for func in functions_to_optimize:
target_function_names.add(func.qualified_name)
logger.debug(f"Target functions for import filtering: {target_function_names}")
file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names)
logger.debug(f"Import analysis results: {len(import_results)} files analyzed")
target_function_names = {func.qualified_name for func in functions_to_optimize}
file_to_test_map = filter_test_files_by_imports(file_to_test_map, target_function_names)

function_to_test_map = defaultdict(set)
num_discovered_tests = 0
11 changes: 7 additions & 4 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
from codeflash.code_utils.code_utils import (
exit_with_message,
is_class_defined_in_file,
module_name_from_file_path,
path_belongs_to_site_packages,
@@ -179,8 +180,9 @@ def get_functions_to_optimize(
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
raise ValueError(msg)
exit_with_message(
"Function name should be in the format 'function_name' or 'class_name.function_name'"
)
if len(split_function) == 2:
class_name, only_function_name = split_function
else:
@@ -193,8 +195,9 @@ def get_functions_to_optimize(
):
found_function = fn
if found_function is None:
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement or is a property"
raise ValueError(msg)
exit_with_message(
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
)
functions[file] = [found_function]
else:
logger.info("Finding all functions modified in the current git diff ...")
15 changes: 11 additions & 4 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
@@ -72,7 +72,6 @@ def create_function_optimizer(
def run(self) -> None:
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.code_utils import cleanup_paths
from codeflash.code_utils.static_analysis import (
analyze_imported_modules,
get_first_top_level_function_or_method_ast,
@@ -283,14 +282,22 @@ def run(self) -> None:
if function_optimizer:
function_optimizer.cleanup_generated_files()

if self.test_cfg.concolic_test_root_dir:
cleanup_paths([self.test_cfg.concolic_test_root_dir])
self.cleanup_temporary_paths()

def cleanup_temporary_paths(self) -> None:
from codeflash.code_utils.code_utils import cleanup_paths

cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir])


def run_with_args(args: Namespace) -> None:
optimizer = None
try:
optimizer = Optimizer(args)
optimizer.run()
except KeyboardInterrupt:
logger.warning("Keyboard interrupt received. Exiting, please wait…")
logger.warning("Keyboard interrupt received. Cleaning up and exiting, please wait…")
if optimizer:
optimizer.cleanup_temporary_paths()

raise SystemExit from None
137 changes: 104 additions & 33 deletions tests/test_unit_test_discovery.py
Original file line number Diff line number Diff line change
@@ -811,11 +811,9 @@ def test_target():
test_file.write_text(test_content)

target_functions = {"target_function", "missing_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True
assert "target_function" in found_functions
assert "missing_function" not in found_functions


def test_analyze_imports_star_import():
@@ -831,10 +829,99 @@ def test_something():
test_file.write_text(test_content)

target_functions = {"target_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is False
assert found_functions == set()


with tempfile.TemporaryDirectory() as tmpdirname:
test_file = Path(tmpdirname) / "test_example.py"
test_content = """
from mymodule import *
def test_target():
assert target_function() is True
"""
test_file.group
test_file.write_text(test_content)

target_functions = {"mymodule.target_function"}
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True



with tempfile.TemporaryDirectory() as tmpdirname:
test_file = Path(tmpdirname) / "test_example.py"
test_content = """
from mymodule import *
def test_target():
assert target_function_extended() is True
"""
test_file.write_text(test_content)

# Should not match - target_function != target_function_extended
target_functions = {"mymodule.target_function"}
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is False


with tempfile.TemporaryDirectory() as tmpdirname:
test_file = Path(tmpdirname) / "test_example.py"
test_content = """
from mymodule import *
def test_something():
x = 42
assert x == 42
"""
test_file.write_text(test_content)

target_functions = {"mymodule.target_function"}
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is False


with tempfile.TemporaryDirectory() as tmpdirname:
test_file = Path(tmpdirname) / "test_example.py"
test_content = """
from mymodule import *
def test_something():
message = "calling target_function"
assert "target_function" in message
"""
test_file.write_text(test_content)

target_functions = {"mymodule.target_function"}
should_process = analyze_imports_in_test_file(test_file, target_functions)

# String literals are ast.Constant nodes, not ast.Name nodes, so they don't match
assert should_process is False


with tempfile.TemporaryDirectory() as tmpdirname:
test_file = Path(tmpdirname) / "test_example.py"
test_content = """
from mymodule import target_function
from othermodule import *
def test_target():
assert target_function() is True
assert other_func() is True
"""
test_file.write_text(test_content)

target_functions = {"mymodule.target_function", "othermodule.other_func"}
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True




def test_analyze_imports_module_import():
@@ -850,10 +937,9 @@ def test_target():
test_file.write_text(test_content)

target_functions = {"target_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True
assert "target_function" in found_functions


def test_analyze_imports_dynamic_import():
@@ -870,10 +956,9 @@ def test_dynamic():
test_file.write_text(test_content)

target_functions = {"target_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True
assert "target_function" in found_functions


def test_analyze_imports_builtin_import():
@@ -888,10 +973,9 @@ def test_builtin_import():
test_file.write_text(test_content)

target_functions = {"target_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True
assert "target_function" in found_functions


def test_analyze_imports_no_matching_imports():
@@ -907,9 +991,8 @@ def test_unrelated():
test_file.write_text(test_content)

target_functions = {"target_function", "another_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)
assert should_process is False
assert found_functions == set()


def test_analyze_qualified_names():
@@ -924,9 +1007,8 @@ def test_target():
test_file.write_text(test_content)

target_functions = {"target_module.some_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)
assert should_process is True
assert "target_module.some_function" in found_functions



@@ -943,11 +1025,10 @@ def test_target(
test_file.write_text(test_content)

target_functions = {"target_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

# Should be conservative with unparseable files
assert should_process is True
assert found_functions == set()


def test_filter_test_files_by_imports():
@@ -988,17 +1069,13 @@ def test_star():
}

target_functions = {"target_function"}
filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, target_functions)
filtered_map = filter_test_files_by_imports(file_to_test_map, target_functions)

# Should filter out irrelevant_test
assert len(filtered_map) == 1
assert relevant_test in filtered_map
assert irrelevant_test not in filtered_map

# Check import analysis results
assert "target_function" in import_results[relevant_test]
assert len(import_results[irrelevant_test]) == 0
assert len(import_results[star_test]) == 0


def test_filter_test_files_no_target_functions():
@@ -1014,11 +1091,10 @@ def test_filter_test_files_no_target_functions():
}

# No target functions provided
filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, set())
filtered_map = filter_test_files_by_imports(file_to_test_map, set())

# Should return original map unchanged
assert filtered_map == file_to_test_map
assert import_results == {}


def test_discover_unit_tests_with_import_filtering():
@@ -1090,10 +1166,9 @@ def test_conditional():
test_file.write_text(test_content)

target_functions = {"target_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True
assert "target_function" in found_functions


def test_analyze_imports_function_name_in_code():
@@ -1113,10 +1188,9 @@ def test_indirect():
test_file.write_text(test_content)

target_functions = {"target_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True
assert "target_function" in found_functions


def test_analyze_imports_aliased_imports():
@@ -1133,11 +1207,9 @@ def test_aliased():
test_file.write_text(test_content)

target_functions = {"target_function", "missing_function"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is True
assert "target_function" in found_functions
assert "missing_function" not in found_functions


def test_analyze_imports_underscore_function_names():
@@ -1152,10 +1224,9 @@ def test_bubble():
test_file.write_text(test_content)

target_functions = {"bubble_sort"}
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is False
assert "bubble_sort" not in found_functions

def test_discover_unit_tests_filtering_different_modules():
"""Test import filtering with test files from completely different modules."""