@@ -141,144 +141,175 @@ def close(self) -> None:
141
141
142
142
143
143
class ImportAnalyzer (ast .NodeVisitor ):
144
- """AST-based analyzer to find all imports in a test file."""
144
+ """AST-based analyzer to check if any qualified names from function_names_to_find are imported or used in a test file."""
145
145
146
146
def __init__ (self , function_names_to_find : set [str ]) -> None :
147
147
self .function_names_to_find = function_names_to_find
148
- self .imported_names : set [str ] = set ()
148
+ self .found_any_target_function : bool = False
149
+ self .found_qualified_name = None
149
150
self .imported_modules : set [str ] = set ()
150
- self .found_target_functions : set [ str ] = set ()
151
- self .qualified_names_called : set [str ] = set ()
151
+ self .has_dynamic_imports : bool = False
152
+ self .wildcard_modules : set [str ] = set ()
152
153
153
154
def visit_Import (self , node : ast .Import ) -> None :
154
155
"""Handle 'import module' statements."""
156
+ if self .found_any_target_function :
157
+ return
158
+
155
159
for alias in node .names :
156
160
module_name = alias .asname if alias .asname else alias .name
157
161
self .imported_modules .add (module_name )
158
- self .imported_names .add (module_name )
159
- self .generic_visit (node )
162
+
163
+ # Check for dynamic import modules
164
+ if alias .name == "importlib" :
165
+ self .has_dynamic_imports = True
166
+
167
+ # Check if module itself is a target qualified name
168
+ if module_name in self .function_names_to_find :
169
+ self .found_any_target_function = True
170
+ self .found_qualified_name = module_name
171
+ return
172
+ # Check if any target qualified name starts with this module
173
+ for target_func in self .function_names_to_find :
174
+ if target_func .startswith (f"{ module_name } ." ):
175
+ self .found_any_target_function = True
176
+ self .found_qualified_name = target_func
177
+ return
160
178
161
179
def visit_ImportFrom (self , node : ast .ImportFrom ) -> None :
162
180
"""Handle 'from module import name' statements."""
163
- if node .module :
164
- self .imported_modules .add (node .module )
181
+ if self .found_any_target_function :
182
+ return
183
+
184
+ if not node .module :
185
+ return
165
186
166
187
for alias in node .names :
167
188
if alias .name == "*" :
168
- continue
169
- imported_name = alias .asname if alias .asname else alias .name
170
- self .imported_names .add (imported_name )
171
- if alias .name in self .function_names_to_find :
172
- self .found_target_functions .add (alias .name )
173
- # Check for qualified name matches
174
- if node .module :
189
+ self .wildcard_modules .add (node .module )
190
+ else :
191
+ imported_name = alias .asname if alias .asname else alias .name
192
+ self .imported_modules .add (imported_name )
193
+
194
+ # Check for dynamic import functions
195
+ if node .module == "importlib" and alias .name == "import_module" :
196
+ self .has_dynamic_imports = True
197
+
198
+ # Check if imported name is a target qualified name
199
+ if alias .name in self .function_names_to_find :
200
+ self .found_any_target_function = True
201
+ self .found_qualified_name = alias .name
202
+ return
203
+ # Check if module.name forms a target qualified name
175
204
qualified_name = f"{ node .module } .{ alias .name } "
176
205
if qualified_name in self .function_names_to_find :
177
- self .found_target_functions .add (qualified_name )
178
- self .generic_visit (node )
206
+ self .found_any_target_function = True
207
+ self .found_qualified_name = qualified_name
208
+ return
209
+
210
+ def visit_Attribute (self , node : ast .Attribute ) -> None :
211
+ """Handle attribute access like module.function_name."""
212
+ if self .found_any_target_function :
213
+ return
179
214
180
- def visit_Call (self , node : ast .Call ) -> None :
181
- """Handle dynamic imports like importlib.import_module() or __import__()."""
215
+ # Check if this is accessing a target function through an imported module
182
216
if (
183
- isinstance (node .func , ast .Name )
184
- and node .func .id == "__import__"
185
- and node .args
186
- and isinstance (node .args [0 ], ast .Constant )
187
- and isinstance (node .args [0 ].value , str )
217
+ isinstance (node .value , ast .Name )
218
+ and node .value .id in self .imported_modules
219
+ and node .attr in self .function_names_to_find
188
220
):
189
- # __import__("module_name")
190
- self .imported_modules .add (node .args [0 ].value )
191
- elif (
192
- isinstance (node .func , ast .Attribute )
193
- and isinstance (node .func .value , ast .Name )
194
- and node .func .value .id == "importlib"
195
- and node .func .attr == "import_module"
196
- and node .args
197
- and isinstance (node .args [0 ], ast .Constant )
198
- and isinstance (node .args [0 ].value , str )
199
- ):
200
- # importlib.import_module("module_name")
201
- self .imported_modules .add (node .args [0 ].value )
202
- self .generic_visit (node )
221
+ self .found_any_target_function = True
222
+ self .found_qualified_name = node .attr
223
+ return
203
224
204
- def visit_Name (self , node : ast .Name ) -> None :
205
- """Check if any name usage matches our target functions."""
206
- if node .id in self .function_names_to_find :
207
- self .found_target_functions .add (node .id )
208
- self .generic_visit (node )
225
+ # Check if this is accessing a target function through a dynamically imported module
226
+ # Only if we've detected dynamic imports are being used
227
+ if self .has_dynamic_imports and node .attr in self .function_names_to_find :
228
+ self .found_any_target_function = True
229
+ self .found_qualified_name = node .attr
230
+ return
209
231
210
- def visit_Attribute (self , node : ast .Attribute ) -> None :
211
- """Handle module.function_name patterns."""
212
- if node .attr in self .function_names_to_find :
213
- self .found_target_functions .add (node .attr )
214
- if isinstance (node .value , ast .Name ):
215
- qualified_name = f"{ node .value .id } .{ node .attr } "
216
- self .qualified_names_called .add (qualified_name )
217
232
self .generic_visit (node )
218
233
234
+ def visit_Name (self , node : ast .Name ) -> None :
235
+ """Handle direct name usage like target_function()."""
236
+ if self .found_any_target_function :
237
+ return
219
238
220
- def analyze_imports_in_test_file (test_file_path : Path | str , target_functions : set [str ]) -> tuple [bool , set [str ]]:
221
- """Analyze imports in a test file to determine if it might test any target functions.
239
+ # Check for __import__ usage
240
+ if node .id == "__import__" :
241
+ self .has_dynamic_imports = True
222
242
223
- Args:
224
- test_file_path: Path to the test file
225
- target_functions: Set of function names we're looking for
243
+ if node .id in self .function_names_to_find :
244
+ self .found_any_target_function = True
245
+ self .found_qualified_name = node .id
246
+ return
247
+
248
+ # Check if this name could come from a wildcard import
249
+ for wildcard_module in self .wildcard_modules :
250
+ for target_func in self .function_names_to_find :
251
+ # Check if target_func is from this wildcard module and name matches
252
+ if target_func .startswith (f"{ wildcard_module } ." ) and target_func .endswith (f".{ node .id } " ):
253
+ self .found_any_target_function = True
254
+ self .found_qualified_name = target_func
255
+ return
226
256
227
- Returns:
228
- Tuple of (should_process_with_jedi, found_function_names)
257
+ self .generic_visit (node )
229
258
230
- """
231
- if isinstance (test_file_path , str ):
232
- test_file_path = Path (test_file_path )
259
+ def generic_visit (self , node : ast .AST ) -> None :
260
+ """Override generic_visit to stop traversal if a target function is found."""
261
+ if self .found_any_target_function :
262
+ return
263
+ super ().generic_visit (node )
233
264
234
- try :
235
- with test_file_path .open ("r" , encoding = "utf-8" ) as f :
236
- content = f .read ()
237
265
238
- tree = ast .parse (content , filename = str (test_file_path ))
266
+ def analyze_imports_in_test_file (test_file_path : Path | str , target_functions : set [str ]) -> bool :
267
+ """Analyze a test file to see if it imports any of the target functions."""
268
+ try :
269
+ with Path (test_file_path ).open ("r" , encoding = "utf-8" ) as f :
270
+ source_code = f .read ()
271
+ tree = ast .parse (source_code , filename = str (test_file_path ))
239
272
analyzer = ImportAnalyzer (target_functions )
240
273
analyzer .visit (tree )
241
-
242
- if analyzer .found_target_functions :
243
- return True , analyzer .found_target_functions
244
-
245
- return False , set () # noqa: TRY300
246
-
247
- except (SyntaxError , UnicodeDecodeError , OSError ) as e :
274
+ except (SyntaxError , FileNotFoundError ) as e :
248
275
logger .debug (f"Failed to analyze imports in { test_file_path } : { e } " )
249
- return True , set ()
276
+ return True
277
+ else :
278
+ if analyzer .found_any_target_function :
279
+ logger .debug (f"Test file { test_file_path } imports target function: { analyzer .found_qualified_name } " )
280
+ return True
281
+ logger .debug (f"Test file { test_file_path } does not import any target functions." )
282
+ return False
250
283
251
284
252
285
def filter_test_files_by_imports (
253
286
file_to_test_map : dict [Path , list [TestsInFile ]], target_functions : set [str ]
254
- ) -> tuple [ dict [Path , list [TestsInFile ]], dict [ Path , set [ str ] ]]:
287
+ ) -> dict [Path , list [TestsInFile ]]:
255
288
"""Filter test files based on import analysis to reduce Jedi processing.
256
289
257
290
Args:
258
291
file_to_test_map: Original mapping of test files to test functions
259
292
target_functions: Set of function names we're optimizing
260
293
261
294
Returns:
262
- Tuple of (filtered_file_map, import_analysis_results)
295
+ Filtered mapping of test files to test functions
263
296
264
297
"""
265
298
if not target_functions :
266
- return file_to_test_map , {}
299
+ return file_to_test_map
267
300
268
- filtered_map = {}
269
- import_results = {}
301
+ logger .debug (f"Target functions for import filtering: { target_functions } " )
270
302
303
+ filtered_map = {}
271
304
for test_file , test_functions in file_to_test_map .items ():
272
- should_process , found_functions = analyze_imports_in_test_file (test_file , target_functions )
273
- import_results [test_file ] = found_functions
274
-
305
+ should_process = analyze_imports_in_test_file (test_file , target_functions )
275
306
if should_process :
276
307
filtered_map [test_file ] = test_functions
277
- else :
278
- logger .debug (f"Skipping { test_file } - no relevant imports found" )
279
308
280
- logger .debug (f"Import filter: Processing { len (filtered_map )} /{ len (file_to_test_map )} test files" )
281
- return filtered_map , import_results
309
+ logger .debug (
310
+ f"analyzed { len (file_to_test_map )} test files for imports, filtered down to { len (filtered_map )} relevant files"
311
+ )
312
+ return filtered_map
282
313
283
314
284
315
def discover_unit_tests (
@@ -296,7 +327,6 @@ def discover_unit_tests(
296
327
functions_to_optimize = None
297
328
if file_to_funcs_to_optimize :
298
329
functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize .values () for func in funcs_list ]
299
-
300
330
function_to_tests , num_discovered_tests = strategy (cfg , discover_only_these_tests , functions_to_optimize )
301
331
return function_to_tests , num_discovered_tests
302
332
@@ -455,12 +485,8 @@ def process_test_files(
455
485
test_framework = cfg .test_framework
456
486
457
487
if functions_to_optimize :
458
- target_function_names = set ()
459
- for func in functions_to_optimize :
460
- target_function_names .add (func .qualified_name )
461
- logger .debug (f"Target functions for import filtering: { target_function_names } " )
462
- file_to_test_map , import_results = filter_test_files_by_imports (file_to_test_map , target_function_names )
463
- logger .debug (f"Import analysis results: { len (import_results )} files analyzed" )
488
+ target_function_names = {func .qualified_name for func in functions_to_optimize }
489
+ file_to_test_map = filter_test_files_by_imports (file_to_test_map , target_function_names )
464
490
465
491
function_to_test_map = defaultdict (set )
466
492
num_discovered_tests = 0
0 commit comments