|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import ast
|
| 4 | +from _ast import AST |
4 | 5 | from collections import defaultdict
|
5 | 6 | from functools import lru_cache
|
6 | 7 | from typing import TYPE_CHECKING, Optional, TypeVar, Union
|
@@ -40,8 +41,8 @@ def _uses_benchmark_fixture(self, node: Union[ast.FunctionDef, ast.AsyncFunction
|
40 | 41 | if self._is_benchmark_marker(decorator):
|
41 | 42 | return True
|
42 | 43 |
|
43 |
| - # Check function body for benchmark usage |
44 |
| - return any(isinstance(stmt, ast.Call) and self._is_benchmark_call(stmt) for stmt in ast.walk(node)) |
| 44 | + # Optimized: Use a fast body scan to detect use of benchmark in function body |
| 45 | + return self._body_uses_benchmark_call(node.body) |
45 | 46 |
|
46 | 47 | def _is_benchmark_marker(self, decorator: ast.expr) -> bool:
|
47 | 48 | """Check if decorator is a benchmark-related pytest marker."""
|
@@ -113,6 +114,29 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
113 | 114 | node.body = new_body
|
114 | 115 | return node
|
115 | 116 |
|
| 117 | + def _body_uses_benchmark_call(self, stmts): |
| 118 | + """Efficiently check if 'benchmark' is called anywhere in the body (recursive, shallow, single function only).""" |
| 119 | + stack = list(stmts) |
| 120 | + while stack: |
| 121 | + stmt = stack.pop() |
| 122 | + # Check for a benchmark call at this node (stmt may be an expr, an Assign, etc.) |
| 123 | + if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call): |
| 124 | + if self._is_benchmark_call(stmt.value): |
| 125 | + return True |
| 126 | + elif isinstance(stmt, ast.Call): |
| 127 | + if self._is_benchmark_call(stmt): |
| 128 | + return True |
| 129 | + # Recursively check relevant AST containers for body calls |
| 130 | + for attr in ("body", "orelse", "finalbody"): |
| 131 | + if hasattr(stmt, attr): |
| 132 | + stack.extend(getattr(stmt, attr)) |
| 133 | + # Check except blocks for 'body' |
| 134 | + if hasattr(stmt, "handlers"): |
| 135 | + for handler in stmt.handlers: |
| 136 | + if hasattr(handler, "body"): |
| 137 | + stack.extend(handler.body) |
| 138 | + return False |
| 139 | + |
116 | 140 |
|
117 | 141 | def remove_benchmark_functions(tree: AST) -> AST:
|
118 | 142 | """Remove benchmark functions from Python source code.
|
|
0 commit comments