Skip to content

⚡️ Speed up method InitDecorator.visit_ClassDef by 91% in PR #363 (part-1-windows-fixes) #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: part-1-windows-fixes
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions codeflash/verification/instrument_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ def __init__(
self.tests_root = tests_root
self.inserted_decorator = False

# Prebuild decorator node to reuse
self._decorator = ast.Call(
func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=None)), # to be set per class
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())),
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
],
)

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
# Check if our import already exists
if node.module == "codeflash.verification.codeflash_capture" and any(
Expand All @@ -114,21 +126,22 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
if node.name not in self.target_classes:
return node

# Look for __init__ method
has_init = False
init_node = None

# Create the decorator
# Prepare a decorator customized to function_name argument
decorator = ast.Call(
func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
func=self._decorator.func,
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())),
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
self._decorator.keywords[1],
self._decorator.keywords[2],
self._decorator.keywords[3],
],
)

# Fast scan for __init__ in class body
for item in node.body:
if (
isinstance(item, ast.FunctionDef)
Expand All @@ -138,16 +151,19 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
and item.args.args[0].arg == "self"
):
has_init = True

# Add decorator at the start of the list if not already present
if not any(
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture"
for d in item.decorator_list
):
item.decorator_list.insert(0, decorator)
self.inserted_decorator = True

if not has_init:
init_node = item
break # __init__ found, no need to scan rest

if has_init:
# Add decorator at the start if not already present
# Use direct field access and generator for fast check
if not any(
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture"
for d in init_node.decorator_list
):
init_node.decorator_list.insert(0, decorator)
self.inserted_decorator = True
else:
# Create super().__init__(*args, **kwargs) call
super_call = ast.Expr(
value=ast.Call(
Expand Down
Loading