diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 6bcf8156..26697c3e 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -212,24 +212,24 @@ def visit_Attribute(self, node: ast.Attribute) -> None: if self.found_any_target_function: return - # Check if this is accessing a target function through an imported module - if ( - isinstance(node.value, ast.Name) - and node.value.id in self.imported_modules - and node.attr in self.function_names_to_find - ): + fnames = self.function_names_to_find + imported_mods = self.imported_modules + + # Fast path: imported module + target function + val = node.value + if isinstance(val, ast.Name) and val.id in imported_mods and node.attr in fnames: self.found_any_target_function = True self.found_qualified_name = node.attr return - # Check if this is accessing a target function through a dynamically imported module - # Only if we've detected dynamic imports are being used - if self.has_dynamic_imports and node.attr in self.function_names_to_find: + # Dynamic import fast path + if self.has_dynamic_imports and node.attr in fnames: self.found_any_target_function = True self.found_qualified_name = node.attr return - self.generic_visit(node) + # Still need to traverse for deeply nested attr + ast.NodeVisitor.generic_visit(self, node) def visit_Name(self, node: ast.Name) -> None: """Handle direct name usage like target_function().""" @@ -260,7 +260,7 @@ def generic_visit(self, node: ast.AST) -> None: """Override generic_visit to stop traversal if a target function is found.""" if self.found_any_target_function: return - super().generic_visit(node) + ast.NodeVisitor.generic_visit(self, node) def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool: