|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import ast
|
| 4 | +from collections.abc import Iterable |
4 | 5 | from pathlib import Path
|
5 | 6 | from typing import TYPE_CHECKING
|
6 | 7 |
|
|
9 | 10 | from codeflash.cli_cmds.console import logger
|
10 | 11 | from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
11 | 12 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
12 |
| -from codeflash.models.models import FunctionParent, TestingMode, VerificationType |
| 13 | +from codeflash.models.models import CodePosition, FunctionParent, TestingMode, VerificationType |
13 | 14 |
|
14 | 15 | if TYPE_CHECKING:
|
15 | 16 | from collections.abc import Iterable
|
@@ -64,62 +65,99 @@ def __init__(
|
64 | 65 | self.module_path = module_path
|
65 | 66 | self.test_framework = test_framework
|
66 | 67 | self.call_positions = call_positions
|
| 68 | + # Pre-cache node wrappers often instantiated |
| 69 | + self.ast_codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load()) |
| 70 | + self.ast_codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load()) |
| 71 | + self.ast_codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) |
67 | 72 | if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
|
68 | 73 | self.class_name = function.top_level_parent_name
|
69 | 74 |
|
70 | 75 | def find_and_update_line_node(
|
71 | 76 | self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
|
72 | 77 | ) -> Iterable[ast.stmt] | None:
|
| 78 | + # Optimize: Inline self._in_call_position and cache .func once |
73 | 79 | call_node = None
|
| 80 | + behavior_mode = self.mode == TestingMode.BEHAVIOR |
| 81 | + function_object_name = self.function_object.function_name |
| 82 | + function_qualified_name = self.function_object.qualified_name |
| 83 | + module_path_const = ast.Constant(value=self.module_path) |
| 84 | + test_class_const = ast.Constant(value=test_class_name or None) |
| 85 | + node_name_const = ast.Constant(value=node_name) |
| 86 | + qualified_name_const = ast.Constant(value=function_qualified_name) |
| 87 | + index_const = ast.Constant(value=index) |
| 88 | + args_behavior = [self.ast_codeflash_cur, self.ast_codeflash_con] if behavior_mode else [] |
| 89 | + |
74 | 90 | for node in ast.walk(test_node):
|
75 |
| - if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions): |
76 |
| - call_node = node |
77 |
| - if isinstance(node.func, ast.Name): |
78 |
| - function_name = node.func.id |
| 91 | + # Fast path: check for Call nodes only |
| 92 | + if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")): |
| 93 | + continue |
| 94 | + # Inline node_in_call_position logic (from profiler hotspot) |
| 95 | + node_lineno = getattr(node, "lineno", None) |
| 96 | + node_col_offset = getattr(node, "col_offset", None) |
| 97 | + node_end_lineno = getattr(node, "end_lineno", None) |
| 98 | + node_end_col_offset = getattr(node, "end_col_offset", None) |
| 99 | + found = False |
| 100 | + for pos in self.call_positions: |
| 101 | + pos_line = pos.line_no |
| 102 | + if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno: |
| 103 | + if pos_line == node_lineno and node_col_offset <= pos.col_no: |
| 104 | + found = True |
| 105 | + break |
| 106 | + if ( |
| 107 | + pos_line == node_end_lineno |
| 108 | + and node_end_col_offset is not None |
| 109 | + and node_end_col_offset >= pos.col_no |
| 110 | + ): |
| 111 | + found = True |
| 112 | + break |
| 113 | + if node_lineno < pos_line < node_end_lineno: |
| 114 | + found = True |
| 115 | + break |
| 116 | + if not found: |
| 117 | + continue |
| 118 | + |
| 119 | + call_node = node |
| 120 | + func = node.func |
| 121 | + # Handle ast.Name fast path |
| 122 | + if isinstance(func, ast.Name): |
| 123 | + function_name = func.id |
| 124 | + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
| 125 | + # Build ast.Name fields for use in args |
| 126 | + codeflash_func_arg = ast.Name(id=function_name, ctx=ast.Load()) |
| 127 | + # Compose argument tuple directly, for speed |
| 128 | + node.args = [ |
| 129 | + codeflash_func_arg, |
| 130 | + module_path_const, |
| 131 | + test_class_const, |
| 132 | + node_name_const, |
| 133 | + qualified_name_const, |
| 134 | + index_const, |
| 135 | + self.ast_codeflash_loop_index, |
| 136 | + *args_behavior, |
| 137 | + *call_node.args, |
| 138 | + ] |
| 139 | + node.keywords = call_node.keywords |
| 140 | + break |
| 141 | + if isinstance(func, ast.Attribute): |
| 142 | + # This path is almost never hit (profile), but handle it |
| 143 | + function_to_test = func.attr |
| 144 | + if function_to_test == function_object_name: |
| 145 | + # NOTE: ast.unparse is very slow; only call if necessary |
| 146 | + function_name = ast.unparse(func) |
79 | 147 | node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
|
80 | 148 | node.args = [
|
81 | 149 | ast.Name(id=function_name, ctx=ast.Load()),
|
82 |
| - ast.Constant(value=self.module_path), |
83 |
| - ast.Constant(value=test_class_name or None), |
84 |
| - ast.Constant(value=node_name), |
85 |
| - ast.Constant(value=self.function_object.qualified_name), |
86 |
| - ast.Constant(value=index), |
87 |
| - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), |
88 |
| - *( |
89 |
| - [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())] |
90 |
| - if self.mode == TestingMode.BEHAVIOR |
91 |
| - else [] |
92 |
| - ), |
| 150 | + module_path_const, |
| 151 | + test_class_const, |
| 152 | + node_name_const, |
| 153 | + qualified_name_const, |
| 154 | + index_const, |
| 155 | + self.ast_codeflash_loop_index, |
| 156 | + *args_behavior, |
93 | 157 | *call_node.args,
|
94 | 158 | ]
|
95 | 159 | node.keywords = call_node.keywords
|
96 | 160 | break
|
97 |
| - if isinstance(node.func, ast.Attribute): |
98 |
| - function_to_test = node.func.attr |
99 |
| - if function_to_test == self.function_object.function_name: |
100 |
| - function_name = ast.unparse(node.func) |
101 |
| - node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
102 |
| - node.args = [ |
103 |
| - ast.Name(id=function_name, ctx=ast.Load()), |
104 |
| - ast.Constant(value=self.module_path), |
105 |
| - ast.Constant(value=test_class_name or None), |
106 |
| - ast.Constant(value=node_name), |
107 |
| - ast.Constant(value=self.function_object.qualified_name), |
108 |
| - ast.Constant(value=index), |
109 |
| - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), |
110 |
| - *( |
111 |
| - [ |
112 |
| - ast.Name(id="codeflash_cur", ctx=ast.Load()), |
113 |
| - ast.Name(id="codeflash_con", ctx=ast.Load()), |
114 |
| - ] |
115 |
| - if self.mode == TestingMode.BEHAVIOR |
116 |
| - else [] |
117 |
| - ), |
118 |
| - *call_node.args, |
119 |
| - ] |
120 |
| - node.keywords = call_node.keywords |
121 |
| - break |
122 |
| - |
123 | 161 | if call_node is None:
|
124 | 162 | return None
|
125 | 163 | return [test_node]
|
@@ -153,6 +191,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
|
153 | 191 | while j >= 0:
|
154 | 192 | compound_line_node: ast.stmt = line_node.body[j]
|
155 | 193 | internal_node: ast.AST
|
| 194 | + # No significant hotspot here; ast.walk used on small subtrees |
156 | 195 | for internal_node in ast.walk(compound_line_node):
|
157 | 196 | if isinstance(internal_node, (ast.stmt, ast.Assign)):
|
158 | 197 | updated_node = self.find_and_update_line_node(
|
@@ -284,6 +323,29 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
|
284 | 323 | ]
|
285 | 324 | return node
|
286 | 325 |
|
| 326 | + def _in_call_position(self, node: ast.AST) -> bool: |
| 327 | + # Inline node_in_call_position for performance |
| 328 | + if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")): |
| 329 | + return False |
| 330 | + node_lineno = getattr(node, "lineno", None) |
| 331 | + node_col_offset = getattr(node, "col_offset", None) |
| 332 | + node_end_lineno = getattr(node, "end_lineno", None) |
| 333 | + node_end_col_offset = getattr(node, "end_col_offset", None) |
| 334 | + for pos in self.call_positions: |
| 335 | + pos_line = pos.line_no |
| 336 | + if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno: |
| 337 | + if pos_line == node_lineno and node_col_offset <= pos.col_no: |
| 338 | + return True |
| 339 | + if ( |
| 340 | + pos_line == node_end_lineno |
| 341 | + and node_end_col_offset is not None |
| 342 | + and node_end_col_offset >= pos.col_no |
| 343 | + ): |
| 344 | + return True |
| 345 | + if node_lineno < pos_line < node_end_lineno: |
| 346 | + return True |
| 347 | + return False |
| 348 | + |
287 | 349 |
|
288 | 350 | class FunctionImportedAsVisitor(ast.NodeVisitor):
|
289 | 351 | """Checks if a function has been imported as an alias. We only care about the alias then.
|
|
0 commit comments