@@ -141,109 +141,132 @@ def close(self) -> None:
141
141
142
142
143
143
class ImportAnalyzer (ast .NodeVisitor ):
144
- """AST-based analyzer to find all imports in a test file."""
144
+ """AST-based analyzer to check if any qualified names from function_names_to_find are imported or used in a test file."""
145
145
146
146
def __init__ (self , function_names_to_find : set [str ]) -> None :
147
147
self .function_names_to_find = function_names_to_find
148
- self .imported_names : set [ str ] = set ()
149
- self .imported_modules : set [str ] = set ()
150
- self .found_target_functions : set [ str ] = set ()
151
- self .qualified_names_called : set [str ] = set ()
148
+ self .found_any_target_function : bool = False
149
+ self .imported_modules : set [str ] = set () # Track imported modules for usage analysis
150
+ self .has_dynamic_imports : bool = False
151
+ self .wildcard_modules : set [str ] = set ()
152
152
153
153
def visit_Import (self , node : ast .Import ) -> None :
154
154
"""Handle 'import module' statements."""
155
+ if self .found_any_target_function :
156
+ return
157
+
155
158
for alias in node .names :
156
159
module_name = alias .asname if alias .asname else alias .name
157
160
self .imported_modules .add (module_name )
158
- self .imported_names .add (module_name )
159
- self .generic_visit (node )
161
+
162
+ # Check for dynamic import modules
163
+ if alias .name == "importlib" :
164
+ self .has_dynamic_imports = True
165
+
166
+ # Check if module itself is a target qualified name
167
+ if module_name in self .function_names_to_find :
168
+ self .found_any_target_function = True
169
+ return
170
+ # Check if any target qualified name starts with this module
171
+ for target_func in self .function_names_to_find :
172
+ if target_func .startswith (f"{ module_name } ." ):
173
+ self .found_any_target_function = True
174
+ return
160
175
161
176
def visit_ImportFrom (self , node : ast .ImportFrom ) -> None :
162
177
"""Handle 'from module import name' statements."""
163
- if node .module :
164
- self .imported_modules .add (node .module )
178
+ if self .found_any_target_function :
179
+ return
180
+
181
+ if not node .module :
182
+ return
165
183
166
184
for alias in node .names :
167
185
if alias .name == "*" :
168
- continue
169
- imported_name = alias .asname if alias .asname else alias .name
170
- self .imported_names .add (imported_name )
171
- if alias .name in self .function_names_to_find :
172
- self .found_target_functions .add (alias .name )
173
- # Check for qualified name matches
174
- if node .module :
186
+ self .wildcard_modules .add (node .module )
187
+ else :
188
+ imported_name = alias .asname if alias .asname else alias .name
189
+ self .imported_modules .add (imported_name )
190
+
191
+ # Check for dynamic import functions
192
+ if node .module == "importlib" and alias .name == "import_module" :
193
+ self .has_dynamic_imports = True
194
+
195
+ # Check if imported name is a target qualified name
196
+ if alias .name in self .function_names_to_find :
197
+ self .found_any_target_function = True
198
+ return
199
+ # Check if module.name forms a target qualified name
175
200
qualified_name = f"{ node .module } .{ alias .name } "
176
201
if qualified_name in self .function_names_to_find :
177
- self .found_target_functions . add ( qualified_name )
178
- self . generic_visit ( node )
202
+ self .found_any_target_function = True
203
+ return
179
204
180
- def visit_Call (self , node : ast .Call ) -> None :
181
- """Handle dynamic imports like importlib.import_module() or __import__()."""
205
+ def visit_Attribute (self , node : ast .Attribute ) -> None :
206
+ """Handle attribute access like module.function_name."""
207
+ if self .found_any_target_function :
208
+ return
209
+
210
+ # Check if this is accessing a target function through an imported module
182
211
if (
183
- isinstance (node .func , ast .Name )
184
- and node .func .id == "__import__"
185
- and node .args
186
- and isinstance (node .args [0 ], ast .Constant )
187
- and isinstance (node .args [0 ].value , str )
212
+ isinstance (node .value , ast .Name )
213
+ and node .value .id in self .imported_modules
214
+ and node .attr in self .function_names_to_find
188
215
):
189
- # __import__("module_name")
190
- self .imported_modules .add (node .args [0 ].value )
191
- elif (
192
- isinstance (node .func , ast .Attribute )
193
- and isinstance (node .func .value , ast .Name )
194
- and node .func .value .id == "importlib"
195
- and node .func .attr == "import_module"
196
- and node .args
197
- and isinstance (node .args [0 ], ast .Constant )
198
- and isinstance (node .args [0 ].value , str )
199
- ):
200
- # importlib.import_module("module_name")
201
- self .imported_modules .add (node .args [0 ].value )
216
+ self .found_any_target_function = True
217
+ return
218
+
219
+ # Check if this is accessing a target function through a dynamically imported module
220
+ # Only if we've detected dynamic imports are being used
221
+ if self .has_dynamic_imports and node .attr in self .function_names_to_find :
222
+ self .found_any_target_function = True
223
+ return
224
+
202
225
self .generic_visit (node )
203
226
204
227
def visit_Name (self , node : ast .Name ) -> None :
205
- """Check if any name usage matches our target functions."""
206
- if node .id in self .function_names_to_find :
207
- self .found_target_functions .add (node .id )
208
- self .generic_visit (node )
228
+ """Handle direct name usage like target_function()."""
229
+ if self .found_any_target_function :
230
+ return
209
231
210
- def visit_Attribute (self , node : ast .Attribute ) -> None :
211
- """Handle module.function_name patterns."""
212
- if node .attr in self .function_names_to_find :
213
- self .found_target_functions .add (node .attr )
214
- if isinstance (node .value , ast .Name ):
215
- qualified_name = f"{ node .value .id } .{ node .attr } "
216
- self .qualified_names_called .add (qualified_name )
217
- self .generic_visit (node )
232
+ # Check for __import__ usage
233
+ if node .id == "__import__" :
234
+ self .has_dynamic_imports = True
218
235
236
+ if node .id in self .function_names_to_find :
237
+ self .found_any_target_function = True
238
+ return
219
239
220
- def analyze_imports_in_test_file (test_file_path : Path | str , target_functions : set [str ]) -> bool :
221
- """Analyze imports in a test file to determine if it might test any target functions.
240
+ # Check if this name could come from a wildcard import
241
+ for wildcard_module in self .wildcard_modules :
242
+ for target_func in self .function_names_to_find :
243
+ # Check if target_func is from this wildcard module and name matches
244
+ if target_func .startswith (f"{ wildcard_module } ." ) and target_func .endswith (f".{ node .id } " ):
245
+ self .found_any_target_function = True
246
+ return
222
247
223
- Args:
224
- test_file_path: Path to the test file
225
- target_functions: Set of function names we're looking for
248
+ self .generic_visit (node )
226
249
227
- Returns:
228
- bool: True if the test file should be processed (contains relevant imports), False otherwise
250
+ def generic_visit (self , node : ast .AST ) -> None :
251
+ """Override generic_visit to stop traversal if a target function is found."""
252
+ if self .found_any_target_function :
253
+ return
254
+ super ().generic_visit (node )
229
255
230
- """
231
- if isinstance (test_file_path , str ):
232
- test_file_path = Path (test_file_path )
233
256
257
+ def analyze_imports_in_test_file (test_file_path : Path | str , target_functions : set [str ]) -> bool :
258
+ """Analyze a test file to see if it imports any of the target functions."""
234
259
try :
235
- with test_file_path .open ("r" , encoding = "utf-8" ) as f :
236
- content = f .read ()
237
-
238
- tree = ast .parse (content , filename = str (test_file_path ))
260
+ with Path (test_file_path ).open ("r" , encoding = "utf-8" ) as f :
261
+ source_code = f .read ()
262
+ tree = ast .parse (source_code , filename = str (test_file_path ))
239
263
analyzer = ImportAnalyzer (target_functions )
240
264
analyzer .visit (tree )
241
-
242
- return bool (analyzer .found_target_functions )
243
-
244
- except (SyntaxError , UnicodeDecodeError , OSError ) as e :
265
+ except (SyntaxError , FileNotFoundError ) as e :
245
266
logger .debug (f"Failed to analyze imports in { test_file_path } : { e } " )
246
267
return True
268
+ else :
269
+ return analyzer .found_any_target_function
247
270
248
271
249
272
def filter_test_files_by_imports (
0 commit comments