Skip to content

⚡️ Speed up function create_trace_replay_test_code by 32% in PR #363 (part-1-windows-fixes) #364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: part-1-windows-fixes
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 128 additions & 110 deletions codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import sqlite3
import textwrap
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -43,6 +42,7 @@ def get_next_arg_and_return(


def get_function_alias(module: str, function_name: str) -> str:
# This is already pretty optimal.
return "_".join(module.split(".")) + "_" + function_name


Expand All @@ -66,152 +66,144 @@ def create_trace_replay_test_code(
A string containing the test code

"""
assert test_framework in ["pytest", "unittest"]
assert test_framework in ("pytest", "unittest")

# Create Imports
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
{"import unittest" if test_framework == "unittest" else ""}
from codeflash.benchmarking.replay_test import get_next_arg_and_return
"""
# Precompute aliases and filepaths
func_aliases, class_aliases, classfunc_aliases, file_paths = _get_aliases_and_paths(functions_data)

# Build function imports in one pass
function_imports = []
for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
class_name = func.get("class_name", "")
if class_name:
function_imports.append(
f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}"
)
cname_alias = class_aliases[class_name]
function_imports.append(f"from {module_name} import {class_name} as {cname_alias}")
else:
function_imports.append(
f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}"
)

imports += "\n".join(function_imports)

functions_to_optimize = sorted(
{func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"}
alias = func_aliases[(module_name, function_name)]
function_imports.append(f"from {module_name} import {function_name} as {alias}")
imports = (
"from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle\n"
f"{'import unittest' if test_framework == 'unittest' else ''}\n"
"from codeflash.benchmarking.replay_test import get_next_arg_and_return\n" + "\n".join(function_imports)
)

# Precompute functions_to_optimize efficiently using set and list since sorted(set(...))
functions_set = {func["function_name"] for func in functions_data if func["function_name"] != "__init__"}
functions_to_optimize = sorted(functions_set)
metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
"""
# Templates for different types of tests
test_function_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = {function_name}(*args, **kwargs)
"""
)

test_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
function_name = "{orig_function_name}"
if not args:
raise ValueError("No arguments provided for the method.")
if function_name == "__init__":
ret = {class_name_alias}(*args[1:], **kwargs)
else:
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
# Prepare templates only once
test_function_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
'file_path=r"{file_path}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl)\n"
" ret = {function_name}(*args, **kwargs)\n"
)

test_class_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
if not args:
raise ValueError("No arguments provided for the method.")
ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
"""
test_method_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl){filter_variables}\n"
' function_name = "{orig_function_name}"\n'
" if not args:\n"
' raise ValueError("No arguments provided for the method.")\n'
' if function_name == "__init__":\n'
" ret = {class_name_alias}(*args[1:], **kwargs)\n"
" else:\n"
" ret = {class_name_alias}{method_name}(*args, **kwargs)\n"
)
test_static_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
test_class_method_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl){filter_variables}\n"
" if not args:\n"
' raise ValueError("No arguments provided for the method.")\n'
" ret = {class_name_alias}{method_name}(*args[1:], **kwargs)\n"
)

# Create main body

test_static_method_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl){filter_variables}\n"
" ret = {class_name_alias}{method_name}(*args, **kwargs)\n"
)
test_bodies = {
"function": test_function_body,
"method": test_method_body,
"classmethod": test_class_method_body,
"staticmethod": test_static_method_body,
}

# Precompute the format values up-front for all functions
if test_framework == "unittest":
self = "self"
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
self_str = "self"
test_template_list = ["\nclass TestTracedFunctions(unittest.TestCase):\n"]
indent_level = " "
def_line = " "
else:
test_template = ""
self = ""
self_str = ""
test_template_list = []
indent_level = " "
def_line = ""

for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
module_name = func["module_name"]
function_name = func["function_name"]
class_name = func.get("class_name")
file_path = func.get("file_path")
benchmark_function_name = func.get("benchmark_function_name")
function_properties = func.get("function_properties")
file_path = func["file_path"]
file_path_posix = file_paths[file_path]
benchmark_function_name = func["benchmark_function_name"]
function_properties = func["function_properties"]
if not class_name:
alias = get_function_alias(module_name, function_name)
test_body = test_function_body.format(
alias = func_aliases[(module_name, function_name)]
template = test_bodies["function"]
test_body_filled = template.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
function_name=alias,
file_path=Path(file_path).as_posix(),
file_path=file_path_posix,
max_run_count=max_run_count,
)
else:
class_name_alias = get_function_alias(module_name, class_name)
alias = get_function_alias(module_name, class_name + "_" + function_name)

class_name_alias = class_aliases[class_name]
alias = classfunc_aliases[(module_name, class_name, function_name)]
filter_variables = ""
# filter_variables = '\n args.pop("cls", None)'
method_name = "." + function_name if function_name != "__init__" else ""
if function_properties.is_classmethod:
test_body = test_class_method_body.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=Path(file_path).as_posix(),
class_name_alias=class_name_alias,
class_name=class_name,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
template = test_bodies["classmethod"]
elif function_properties.is_staticmethod:
test_body = test_static_method_body.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=Path(file_path).as_posix(),
class_name_alias=class_name_alias,
class_name=class_name,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
template = test_bodies["staticmethod"]
else:
test_body = test_method_body.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=Path(file_path).as_posix(),
class_name_alias=class_name_alias,
class_name=class_name,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
template = test_bodies["method"]
test_body_filled = template.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=file_path_posix,
class_name_alias=class_name_alias,
class_name=class_name,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)

formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")
# No repeated indent/dedent. Do indent directly, as we know where to indent.
formatted_test_body = "".join(
indent_level + line if line.strip() else line for line in test_body_filled.splitlines(True)
)

test_template += " " if test_framework == "unittest" else ""
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
test_template_list.append(f"{def_line}def test_{alias}({self_str}):\n{formatted_test_body}\n")

return imports + "\n" + metadata + "\n" + test_template
return imports + "\n" + metadata + "\n" + "".join(test_template_list)


def generate_replay_test(
Expand Down Expand Up @@ -294,3 +286,29 @@ def generate_replay_test(
logger.info(f"Error generating replay tests: {e}")

return count


def _get_aliases_and_paths(functions_data):
# Precompute all needed aliases and file posix paths up front in a single pass
func_aliases = {}
class_aliases = {}
classfunc_aliases = {}
file_paths = {}
for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
class_name = func.get("class_name", "")
file_path = func.get("file_path")
# Precompute Path(file_path).as_posix() once per unique file_path
if file_path not in file_paths:
file_paths[file_path] = Path(file_path).as_posix()
if class_name:
# avoid re-calculating class alias if already done
if class_name not in class_aliases:
class_aliases[class_name] = get_function_alias(module_name, class_name)
classfunc_key = (module_name, class_name, function_name)
classfunc_aliases[classfunc_key] = get_function_alias(module_name, class_name + "_" + function_name)
else:
# alias for global function
func_aliases[(module_name, function_name)] = get_function_alias(module_name, function_name)
return func_aliases, class_aliases, classfunc_aliases, file_paths
Loading