diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 9a737298..198899fd 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +from collections.abc import Iterable from pathlib import Path from typing import TYPE_CHECKING @@ -9,7 +10,7 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent, TestingMode, VerificationType +from codeflash.models.models import CodePosition, FunctionParent, TestingMode, VerificationType if TYPE_CHECKING: from collections.abc import Iterable @@ -64,62 +65,99 @@ def __init__( self.module_path = module_path self.test_framework = test_framework self.call_positions = call_positions + # Pre-cache node wrappers often instantiated + self.ast_codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load()) + self.ast_codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load()) + self.ast_codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) if len(function.parents) == 1 and function.parents[0].type == "ClassDef": self.class_name = function.top_level_parent_name def find_and_update_line_node( self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None ) -> Iterable[ast.stmt] | None: + # Optimize: Inline self._in_call_position and cache .func once call_node = None + behavior_mode = self.mode == TestingMode.BEHAVIOR + function_object_name = self.function_object.function_name + function_qualified_name = self.function_object.qualified_name + module_path_const = ast.Constant(value=self.module_path) + test_class_const = ast.Constant(value=test_class_name or None) + node_name_const = ast.Constant(value=node_name) + qualified_name_const = ast.Constant(value=function_qualified_name) + index_const = ast.Constant(value=index) + args_behavior = [self.ast_codeflash_cur, self.ast_codeflash_con] if behavior_mode else [] + for node in ast.walk(test_node): - if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions): - call_node = node - if isinstance(node.func, ast.Name): - function_name = node.func.id + # Fast path: check for Call nodes only + if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")): + continue + # Inline node_in_call_position logic (from profiler hotspot) + node_lineno = getattr(node, "lineno", None) + node_col_offset = getattr(node, "col_offset", None) + node_end_lineno = getattr(node, "end_lineno", None) + node_end_col_offset = getattr(node, "end_col_offset", None) + found = False + for pos in self.call_positions: + pos_line = pos.line_no + if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno: + if pos_line == node_lineno and node_col_offset <= pos.col_no: + found = True + break + if ( + pos_line == node_end_lineno + and node_end_col_offset is not None + and node_end_col_offset >= pos.col_no + ): + found = True + break + if node_lineno < pos_line < node_end_lineno: + found = True + break + if not found: + continue + + call_node = node + func = node.func + # Handle ast.Name fast path + if isinstance(func, ast.Name): + function_name = func.id + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) + # Build ast.Name fields for use in args + codeflash_func_arg = ast.Name(id=function_name, ctx=ast.Load()) + # Compose argument tuple directly, for speed + node.args = [ + codeflash_func_arg, + module_path_const, + test_class_const, + node_name_const, + qualified_name_const, + index_const, + self.ast_codeflash_loop_index, + *args_behavior, + *call_node.args, + ] + node.keywords = call_node.keywords + break + if isinstance(func, ast.Attribute): + # This path is almost never hit (profile), but handle it + function_to_test = func.attr + if function_to_test == function_object_name: + # NOTE: ast.unparse is very slow; only call if necessary + function_name = ast.unparse(func) node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) node.args = [ ast.Name(id=function_name, ctx=ast.Load()), - ast.Constant(value=self.module_path), - ast.Constant(value=test_class_name or None), - ast.Constant(value=node_name), - ast.Constant(value=self.function_object.qualified_name), - ast.Constant(value=index), - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), - *( - [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())] - if self.mode == TestingMode.BEHAVIOR - else [] - ), + module_path_const, + test_class_const, + node_name_const, + qualified_name_const, + index_const, + self.ast_codeflash_loop_index, + *args_behavior, *call_node.args, ] node.keywords = call_node.keywords break - if isinstance(node.func, ast.Attribute): - function_to_test = node.func.attr - if function_to_test == self.function_object.function_name: - function_name = ast.unparse(node.func) - node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) - node.args = [ - ast.Name(id=function_name, ctx=ast.Load()), - ast.Constant(value=self.module_path), - ast.Constant(value=test_class_name or None), - ast.Constant(value=node_name), - ast.Constant(value=self.function_object.qualified_name), - ast.Constant(value=index), - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), - *( - [ - ast.Name(id="codeflash_cur", ctx=ast.Load()), - ast.Name(id="codeflash_con", ctx=ast.Load()), - ] - if self.mode == TestingMode.BEHAVIOR - else [] - ), - *call_node.args, - ] - node.keywords = call_node.keywords - break - if call_node is None: return None return [test_node] @@ -153,6 +191,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = while j >= 0: compound_line_node: ast.stmt = line_node.body[j] internal_node: ast.AST + # No significant hotspot here; ast.walk used on small subtrees for internal_node in ast.walk(compound_line_node): if isinstance(internal_node, (ast.stmt, ast.Assign)): updated_node = self.find_and_update_line_node( @@ -284,6 +323,29 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = ] return node + def _in_call_position(self, node: ast.AST) -> bool: + # Inline node_in_call_position for performance + if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")): + return False + node_lineno = getattr(node, "lineno", None) + node_col_offset = getattr(node, "col_offset", None) + node_end_lineno = getattr(node, "end_lineno", None) + node_end_col_offset = getattr(node, "end_col_offset", None) + for pos in self.call_positions: + pos_line = pos.line_no + if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno: + if pos_line == node_lineno and node_col_offset <= pos.col_no: + return True + if ( + pos_line == node_end_lineno + and node_end_col_offset is not None + and node_end_col_offset >= pos.col_no + ): + return True + if node_lineno < pos_line < node_end_lineno: + return True + return False + class FunctionImportedAsVisitor(ast.NodeVisitor): """Checks if a function has been imported as an alias. We only care about the alias then.