Skip to content

Commit 36f320d

Browse files
committed
refactor to early exit
1 parent fee08a6 commit 36f320d

File tree

2 files changed

+181
-68
lines changed

2 files changed

+181
-68
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 91 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -141,109 +141,132 @@ def close(self) -> None:
141141

142142

143143
class ImportAnalyzer(ast.NodeVisitor):
144-
"""AST-based analyzer to find all imports in a test file."""
144+
"""AST-based analyzer to check if any qualified names from function_names_to_find are imported or used in a test file."""
145145

146146
def __init__(self, function_names_to_find: set[str]) -> None:
147147
self.function_names_to_find = function_names_to_find
148-
self.imported_names: set[str] = set()
149-
self.imported_modules: set[str] = set()
150-
self.found_target_functions: set[str] = set()
151-
self.qualified_names_called: set[str] = set()
148+
self.found_any_target_function: bool = False
149+
self.imported_modules: set[str] = set() # Track imported modules for usage analysis
150+
self.has_dynamic_imports: bool = False
151+
self.wildcard_modules: set[str] = set()
152152

153153
def visit_Import(self, node: ast.Import) -> None:
154154
"""Handle 'import module' statements."""
155+
if self.found_any_target_function:
156+
return
157+
155158
for alias in node.names:
156159
module_name = alias.asname if alias.asname else alias.name
157160
self.imported_modules.add(module_name)
158-
self.imported_names.add(module_name)
159-
self.generic_visit(node)
161+
162+
# Check for dynamic import modules
163+
if alias.name == "importlib":
164+
self.has_dynamic_imports = True
165+
166+
# Check if module itself is a target qualified name
167+
if module_name in self.function_names_to_find:
168+
self.found_any_target_function = True
169+
return
170+
# Check if any target qualified name starts with this module
171+
for target_func in self.function_names_to_find:
172+
if target_func.startswith(f"{module_name}."):
173+
self.found_any_target_function = True
174+
return
160175

161176
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
162177
"""Handle 'from module import name' statements."""
163-
if node.module:
164-
self.imported_modules.add(node.module)
178+
if self.found_any_target_function:
179+
return
180+
181+
if not node.module:
182+
return
165183

166184
for alias in node.names:
167185
if alias.name == "*":
168-
continue
169-
imported_name = alias.asname if alias.asname else alias.name
170-
self.imported_names.add(imported_name)
171-
if alias.name in self.function_names_to_find:
172-
self.found_target_functions.add(alias.name)
173-
# Check for qualified name matches
174-
if node.module:
186+
self.wildcard_modules.add(node.module)
187+
else:
188+
imported_name = alias.asname if alias.asname else alias.name
189+
self.imported_modules.add(imported_name)
190+
191+
# Check for dynamic import functions
192+
if node.module == "importlib" and alias.name == "import_module":
193+
self.has_dynamic_imports = True
194+
195+
# Check if imported name is a target qualified name
196+
if alias.name in self.function_names_to_find:
197+
self.found_any_target_function = True
198+
return
199+
# Check if module.name forms a target qualified name
175200
qualified_name = f"{node.module}.{alias.name}"
176201
if qualified_name in self.function_names_to_find:
177-
self.found_target_functions.add(qualified_name)
178-
self.generic_visit(node)
202+
self.found_any_target_function = True
203+
return
179204

180-
def visit_Call(self, node: ast.Call) -> None:
181-
"""Handle dynamic imports like importlib.import_module() or __import__()."""
205+
def visit_Attribute(self, node: ast.Attribute) -> None:
206+
"""Handle attribute access like module.function_name."""
207+
if self.found_any_target_function:
208+
return
209+
210+
# Check if this is accessing a target function through an imported module
182211
if (
183-
isinstance(node.func, ast.Name)
184-
and node.func.id == "__import__"
185-
and node.args
186-
and isinstance(node.args[0], ast.Constant)
187-
and isinstance(node.args[0].value, str)
212+
isinstance(node.value, ast.Name)
213+
and node.value.id in self.imported_modules
214+
and node.attr in self.function_names_to_find
188215
):
189-
# __import__("module_name")
190-
self.imported_modules.add(node.args[0].value)
191-
elif (
192-
isinstance(node.func, ast.Attribute)
193-
and isinstance(node.func.value, ast.Name)
194-
and node.func.value.id == "importlib"
195-
and node.func.attr == "import_module"
196-
and node.args
197-
and isinstance(node.args[0], ast.Constant)
198-
and isinstance(node.args[0].value, str)
199-
):
200-
# importlib.import_module("module_name")
201-
self.imported_modules.add(node.args[0].value)
216+
self.found_any_target_function = True
217+
return
218+
219+
# Check if this is accessing a target function through a dynamically imported module
220+
# Only if we've detected dynamic imports are being used
221+
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
222+
self.found_any_target_function = True
223+
return
224+
202225
self.generic_visit(node)
203226

204227
def visit_Name(self, node: ast.Name) -> None:
205-
"""Check if any name usage matches our target functions."""
206-
if node.id in self.function_names_to_find:
207-
self.found_target_functions.add(node.id)
208-
self.generic_visit(node)
228+
"""Handle direct name usage like target_function()."""
229+
if self.found_any_target_function:
230+
return
209231

210-
def visit_Attribute(self, node: ast.Attribute) -> None:
211-
"""Handle module.function_name patterns."""
212-
if node.attr in self.function_names_to_find:
213-
self.found_target_functions.add(node.attr)
214-
if isinstance(node.value, ast.Name):
215-
qualified_name = f"{node.value.id}.{node.attr}"
216-
self.qualified_names_called.add(qualified_name)
217-
self.generic_visit(node)
232+
# Check for __import__ usage
233+
if node.id == "__import__":
234+
self.has_dynamic_imports = True
218235

236+
if node.id in self.function_names_to_find:
237+
self.found_any_target_function = True
238+
return
219239

220-
def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
221-
"""Analyze imports in a test file to determine if it might test any target functions.
240+
# Check if this name could come from a wildcard import
241+
for wildcard_module in self.wildcard_modules:
242+
for target_func in self.function_names_to_find:
243+
# Check if target_func is from this wildcard module and name matches
244+
if target_func.startswith(f"{wildcard_module}.") and target_func.endswith(f".{node.id}"):
245+
self.found_any_target_function = True
246+
return
222247

223-
Args:
224-
test_file_path: Path to the test file
225-
target_functions: Set of function names we're looking for
248+
self.generic_visit(node)
226249

227-
Returns:
228-
bool: True if the test file should be processed (contains relevant imports), False otherwise
250+
def generic_visit(self, node: ast.AST) -> None:
251+
"""Override generic_visit to stop traversal if a target function is found."""
252+
if self.found_any_target_function:
253+
return
254+
super().generic_visit(node)
229255

230-
"""
231-
if isinstance(test_file_path, str):
232-
test_file_path = Path(test_file_path)
233256

257+
def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
258+
"""Analyze a test file to see if it imports any of the target functions."""
234259
try:
235-
with test_file_path.open("r", encoding="utf-8") as f:
236-
content = f.read()
237-
238-
tree = ast.parse(content, filename=str(test_file_path))
260+
with Path(test_file_path).open("r", encoding="utf-8") as f:
261+
source_code = f.read()
262+
tree = ast.parse(source_code, filename=str(test_file_path))
239263
analyzer = ImportAnalyzer(target_functions)
240264
analyzer.visit(tree)
241-
242-
return bool(analyzer.found_target_functions)
243-
244-
except (SyntaxError, UnicodeDecodeError, OSError) as e:
265+
except (SyntaxError, FileNotFoundError) as e:
245266
logger.debug(f"Failed to analyze imports in {test_file_path}: {e}")
246267
return True
268+
else:
269+
return analyzer.found_any_target_function
247270

248271

249272
def filter_test_files_by_imports(

tests/test_unit_test_discovery.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,96 @@ def test_something():
834834
assert should_process is False
835835

836836

837+
with tempfile.TemporaryDirectory() as tmpdirname:
838+
test_file = Path(tmpdirname) / "test_example.py"
839+
test_content = """
840+
from mymodule import *
841+
842+
def test_target():
843+
assert target_function() is True
844+
"""
845+
test_file.group
846+
test_file.write_text(test_content)
847+
848+
target_functions = {"mymodule.target_function"}
849+
should_process = analyze_imports_in_test_file(test_file, target_functions)
850+
851+
assert should_process is True
852+
853+
854+
855+
with tempfile.TemporaryDirectory() as tmpdirname:
856+
test_file = Path(tmpdirname) / "test_example.py"
857+
test_content = """
858+
from mymodule import *
859+
860+
def test_target():
861+
assert target_function_extended() is True
862+
"""
863+
test_file.write_text(test_content)
864+
865+
# Should not match - target_function != target_function_extended
866+
target_functions = {"mymodule.target_function"}
867+
should_process = analyze_imports_in_test_file(test_file, target_functions)
868+
869+
assert should_process is False
870+
871+
872+
with tempfile.TemporaryDirectory() as tmpdirname:
873+
test_file = Path(tmpdirname) / "test_example.py"
874+
test_content = """
875+
from mymodule import *
876+
877+
def test_something():
878+
x = 42
879+
assert x == 42
880+
"""
881+
test_file.write_text(test_content)
882+
883+
target_functions = {"mymodule.target_function"}
884+
should_process = analyze_imports_in_test_file(test_file, target_functions)
885+
886+
assert should_process is False
887+
888+
889+
with tempfile.TemporaryDirectory() as tmpdirname:
890+
test_file = Path(tmpdirname) / "test_example.py"
891+
test_content = """
892+
from mymodule import *
893+
894+
def test_something():
895+
message = "calling target_function"
896+
assert "target_function" in message
897+
"""
898+
test_file.write_text(test_content)
899+
900+
target_functions = {"mymodule.target_function"}
901+
should_process = analyze_imports_in_test_file(test_file, target_functions)
902+
903+
# String literals are ast.Constant nodes, not ast.Name nodes, so they don't match
904+
assert should_process is False
905+
906+
907+
with tempfile.TemporaryDirectory() as tmpdirname:
908+
test_file = Path(tmpdirname) / "test_example.py"
909+
test_content = """
910+
from mymodule import target_function
911+
from othermodule import *
912+
913+
def test_target():
914+
assert target_function() is True
915+
assert other_func() is True
916+
"""
917+
test_file.write_text(test_content)
918+
919+
target_functions = {"mymodule.target_function", "othermodule.other_func"}
920+
should_process = analyze_imports_in_test_file(test_file, target_functions)
921+
922+
assert should_process is True
923+
924+
925+
926+
837927
def test_analyze_imports_module_import():
838928
"""Test module imports with function access patterns."""
839929
with tempfile.TemporaryDirectory() as tmpdirname:

0 commit comments

Comments
 (0)