diff --git a/codeflash/verification/instrument_codeflash_capture.py b/codeflash/verification/instrument_codeflash_capture.py index d1f9816d..c568f2da 100644 --- a/codeflash/verification/instrument_codeflash_capture.py +++ b/codeflash/verification/instrument_codeflash_capture.py @@ -92,6 +92,18 @@ def __init__( self.tests_root = tests_root self.inserted_decorator = False + # Prebuild decorator node to reuse + self._decorator = ast.Call( + func=ast.Name(id="codeflash_capture", ctx=ast.Load()), + args=[], + keywords=[ + ast.keyword(arg="function_name", value=ast.Constant(value=None)), # to be set per class + ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)), + ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())), + ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)), + ], + ) + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: # Check if our import already exists if node.module == "codeflash.verification.codeflash_capture" and any( @@ -114,21 +126,22 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: if node.name not in self.target_classes: return node - # Look for __init__ method has_init = False + init_node = None - # Create the decorator + # Prepare a decorator customized to function_name argument decorator = ast.Call( - func=ast.Name(id="codeflash_capture", ctx=ast.Load()), + func=self._decorator.func, args=[], keywords=[ ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")), - ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)), - ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())), - ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)), + self._decorator.keywords[1], + self._decorator.keywords[2], + self._decorator.keywords[3], ], ) + # Fast scan for __init__ in class body for item in node.body: if ( isinstance(item, ast.FunctionDef) @@ -138,16 +151,19 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: and item.args.args[0].arg == "self" ): has_init = True - - # Add decorator at the start of the list if not already present - if not any( - isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture" - for d in item.decorator_list - ): - item.decorator_list.insert(0, decorator) - self.inserted_decorator = True - - if not has_init: + init_node = item + break # __init__ found, no need to scan rest + + if has_init: + # Add decorator at the start if not already present + # Use direct field access and generator for fast check + if not any( + isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture" + for d in init_node.decorator_list + ): + init_node.decorator_list.insert(0, decorator) + self.inserted_decorator = True + else: # Create super().__init__(*args, **kwargs) call super_call = ast.Expr( value=ast.Call(