diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/codeflash.trace b/code_to_optimize/code_directories/simple_tracer_e2e/codeflash.trace new file mode 100644 index 00000000..6e3ea527 Binary files /dev/null and b/code_to_optimize/code_directories/simple_tracer_e2e/codeflash.trace differ diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py index db708a5c..3b207a7a 100644 --- a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py +++ b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py @@ -1,4 +1,5 @@ from concurrent.futures import ThreadPoolExecutor +from time import sleep def funcA(number): @@ -46,12 +47,20 @@ def _classify(self, features): class SimpleModel: @staticmethod def predict(data): - return [x * 2 for x in data] + result = [] + sleep(0.1) # can be optimized away + for i in range(500): + for x in data: + computation = 0 + computation += x * i ** 2 + result.append(computation) + return result @classmethod def create_default(cls): return cls() + def test_models(): model = AlexNet(num_classes=10) input_data = [1, 2, 3, 4, 5] diff --git a/codeflash/benchmarking/function_ranker.py b/codeflash/benchmarking/function_ranker.py new file mode 100644 index 00000000..9d1d8ec1 --- /dev/null +++ b/codeflash/benchmarking/function_ranker.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from codeflash.cli_cmds.console import console, logger +from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.tracing.profile_stats import ProfileStats + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + + +class FunctionRanker: + """Ranks and filters functions based on a ttX score derived from profiling data. + + The ttX score is calculated as: + ttX = own_time + (time_spent_in_callees / call_count) + + This score prioritizes functions that are computationally heavy themselves (high `own_time`) + or that make expensive calls to other functions (high average `time_spent_in_callees`). + + Functions are first filtered by an importance threshold based on their `own_time` as a + fraction of the total runtime. The remaining functions are then ranked by their ttX score + to identify the best candidates for optimization. + """ + + def __init__(self, trace_file_path: Path) -> None: + self.trace_file_path = trace_file_path + self._profile_stats = ProfileStats(trace_file_path.as_posix()) + self._function_stats: dict[str, dict] = {} + self.load_function_stats() + + def load_function_stats(self) -> None: + try: + for (filename, line_number, func_name), ( + call_count, + _num_callers, + total_time_ns, + cumulative_time_ns, + _callers, + ) in self._profile_stats.stats.items(): + if call_count <= 0: + continue + + # Parse function name to handle methods within classes + class_name, qualified_name, base_function_name = (None, func_name, func_name) + if "." 
in func_name and not func_name.startswith("<"): + parts = func_name.split(".", 1) + if len(parts) == 2: + class_name, base_function_name = parts + + # Own time is the profiler's total (internal) time, which already excludes time spent in subcalls + own_time_ns = total_time_ns + time_in_callees_ns = cumulative_time_ns - total_time_ns + + # Calculate ttX score + ttx_score = own_time_ns + (time_in_callees_ns / call_count) + + function_key = f"{filename}:{qualified_name}" + self._function_stats[function_key] = { + "filename": filename, + "function_name": base_function_name, + "qualified_name": qualified_name, + "class_name": class_name, + "line_number": line_number, + "call_count": call_count, + "own_time_ns": own_time_ns, + "cumulative_time_ns": cumulative_time_ns, + "time_in_callees_ns": time_in_callees_ns, + "ttx_score": ttx_score, + } + + logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats") + + except Exception as e: + logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}") + self._function_stats = {} + + def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None: + target_filename = function_to_optimize.file_path.name + for key, stats in self._function_stats.items(): + if stats.get("function_name") == function_to_optimize.function_name and ( + key.endswith(f"/{target_filename}") or target_filename in key + ): + return stats + + logger.debug( + f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}" + ) + return None + + def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float: + stats = self._get_function_stats(function_to_optimize) + return stats["ttx_score"] if stats else 0.0 + + def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True) + logger.debug( + f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}" + ) + return ranked + + def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None: + return self._get_function_stats(function_to_optimize) + + def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + """Ranks functions based on their ttX score. + + This method calculates the ttX score for each function and returns + the functions sorted in descending order of their ttX score. + """ + if not self._function_stats: + logger.warning("No function stats available to rank functions.") + return [] + + return self.rank_functions(functions_to_optimize) + + def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + """Reranks and filters functions based on their impact on total runtime. + + This method first calculates the total runtime of all profiled functions. + It then filters out functions whose own_time is less than a specified + percentage of the total runtime (importance_threshold). + + The remaining 'important' functions are then ranked by their ttX score. 
+ """ + stats_map = self._function_stats + if not stats_map: + return [] + + total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0) + + if total_program_time == 0: + logger.warning("Total program time is zero, cannot determine function importance.") + return self.rank_functions(functions_to_optimize) + + important_functions = [] + for func in functions_to_optimize: + func_stats = self._get_function_stats(func) + if func_stats and func_stats.get("own_time_ns", 0) > 0: + importance = func_stats["own_time_ns"] / total_program_time + if importance >= DEFAULT_IMPORTANCE_THRESHOLD: + important_functions.append(func) + else: + logger.debug( + f"Filtering out function {func.qualified_name} with importance " + f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})" + ) + + logger.info( + f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions" + ) + console.rule() + + return self.rank_functions(important_functions) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index c6aaebfe..265d3274 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -22,6 +22,36 @@ def parse_args() -> Namespace: init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow") init_actions_parser.set_defaults(func=install_github_actions) + + trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize a Python project.") + + from codeflash.tracer import main as tracer_main + + trace_optimize.set_defaults(func=tracer_main) + + trace_optimize.add_argument( + "--max-function-count", + type=int, + default=100, + help="The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.", + ) + trace_optimize.add_argument( + "--timeout", + type=int, + help="The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows, to not wait indefinitely.", + ) + trace_optimize.add_argument( + "--output", + type=str, + default="codeflash.trace", + help="The file to save the trace to. Default is codeflash.trace.", + ) + trace_optimize.add_argument( + "--config-file-path", + type=str, + help="The path to the pyproject.toml file which stores the Codeflash config. 
This is auto-discovered by default.", + ) + parser.add_argument("--file", help="Try to optimize only this file") parser.add_argument("--function", help="Try to optimize only this function within the given file path") parser.add_argument( @@ -64,7 +94,8 @@ def parse_args() -> Namespace: ) parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs") - args: Namespace = parser.parse_args() + args, unknown_args = parser.parse_known_args() + sys.argv[:] = [sys.argv[0], *unknown_args] return process_and_validate_cmd_args(args) @@ -102,6 +133,8 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace: if not Path(test_path).is_file(): exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True) args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test] + if env_utils.is_ci(): + args.no_pr = True return args diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 0b8f5420..50b4bce1 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -10,3 +10,4 @@ COVERAGE_THRESHOLD = 60.0 MIN_TESTCASE_PASSED_THRESHOLD = 6 REPEAT_OPTIMIZATION_PROBABILITY = 0.1 +DEFAULT_IMPORTANCE_THRESHOLD = 0.001 diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index f127a305..3def52e3 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -112,6 +112,27 @@ def get_cached_gh_event_data() -> dict[str, Any] | None: return json.load(f) # type: ignore # noqa +@lru_cache(maxsize=1) +def is_ci() -> bool: + """Check if running in a CI environment.""" + return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS")) + + @lru_cache(maxsize=1) def is_LSP_enabled() -> bool: return console.quiet + + +def is_pr_draft() -> bool: + """Check if the PR is draft. 
in the github action context.""" + try: + event_path = os.getenv("GITHUB_EVENT_PATH") + pr_number = get_pr_number() + if pr_number is not None and event_path: + with Path(event_path).open() as f: + event_data = json.load(f) + return bool(event_data["pull_request"]["draft"]) + return False # noqa + except Exception as e: + logger.warning(f"Error checking if PR is draft: {e}") + return False diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index b4b54815..357d6153 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -160,11 +160,12 @@ def get_functions_to_optimize( project_root: Path, module_root: Path, previous_checkpoint_functions: dict[str, dict[str, str]] | None = None, -) -> tuple[dict[Path, list[FunctionToOptimize]], int]: +) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, ( "Only one of optimize_all, replay_test, or file should be provided" ) functions: dict[str, list[FunctionToOptimize]] + trace_file_path: Path | None = None with warnings.catch_warnings(): warnings.simplefilter(action="ignore", category=SyntaxWarning) if optimize_all: @@ -172,7 +173,7 @@ def get_functions_to_optimize( console.rule() functions = get_all_files_and_functions(Path(optimize_all)) elif replay_test: - functions = get_all_replay_test_functions( + functions, trace_file_path = get_all_replay_test_functions( replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root ) elif file is not None: @@ -208,6 +209,7 @@ def get_functions_to_optimize( filtered_modified_functions, functions_count = filter_functions( functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions ) + logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize") if optimize_all: three_min_in_ns = int(1.8e11) @@ -216,7 +218,7 @@ def get_functions_to_optimize( f"It might take about {humanize_runtime(functions_count * three_min_in_ns)} to fully optimize this project. Codeflash " f"will keep opening pull requests as it finds optimizations." 
) - return filtered_modified_functions, functions_count + return filtered_modified_functions, functions_count, trace_file_path def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]: @@ -274,7 +276,34 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path -) -> dict[Path, list[FunctionToOptimize]]: +) -> tuple[dict[Path, list[FunctionToOptimize]], Path]: + trace_file_path: Path | None = None + for replay_test_file in replay_test: + try: + with replay_test_file.open("r", encoding="utf8") as f: + tree = ast.parse(f.read()) + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "trace_file_path" + and isinstance(node.value, ast.Constant) + and isinstance(node.value.value, str) + ): + trace_file_path = Path(node.value.value) + break + if trace_file_path: + break + if trace_file_path: + break + except Exception as e: + logger.warning(f"Error parsing replay test file {replay_test_file}: {e}") + + if not trace_file_path: + logger.error("Could not find trace_file_path in replay test files.") + exit_with_message("Could not find trace_file_path in replay test files.") + function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test) # Get the absolute file paths for each function, excluding class name if present filtered_valid_functions = defaultdict(list) @@ -319,7 +348,7 @@ def get_all_replay_test_functions( if filtered_list: filtered_valid_functions[file_path] = filtered_list - return filtered_valid_functions + return filtered_valid_functions, trace_file_path def is_git_repo(file_path: str) -> bool: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 819e6126..cd53cc56 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import json import os import tempfile import time @@ -13,7 +12,7 @@ from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file -from codeflash.code_utils.env_utils import get_pr_number +from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft from codeflash.either import is_successful from codeflash.models.models import ValidCode from codeflash.telemetry.posthog_cf import ph @@ -64,7 +63,6 @@ def run_benchmarks( from codeflash.benchmarking.replay_test import generate_replay_test from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table - from codeflash.code_utils.env_utils import get_pr_number with progress_bar( f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number()) @@ -109,7 +107,7 @@ def run_benchmarks( return function_benchmark_timings, total_benchmark_timings - def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: """Discover functions to optimize.""" from codeflash.discovery.functions_to_optimize import get_functions_to_optimize @@ -255,7 +253,7 @@ def run(self) -> None: 
cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root)) function_optimizer = None - file_to_funcs_to_optimize, num_optimizable_functions = self.get_optimizable_functions() + file_to_funcs_to_optimize, num_optimizable_functions, trace_file_path = self.get_optimizable_functions() function_benchmark_timings, total_benchmark_timings = self.run_benchmarks( file_to_funcs_to_optimize, num_optimizable_functions ) @@ -282,7 +280,21 @@ def run(self) -> None: validated_original_code, original_module_ast = module_prep_result - for function_to_optimize in file_to_funcs_to_optimize[original_module_path]: + functions_to_optimize = file_to_funcs_to_optimize[original_module_path] + if trace_file_path and trace_file_path.exists() and len(functions_to_optimize) > 1: + try: + from codeflash.benchmarking.function_ranker import FunctionRanker + + ranker = FunctionRanker(trace_file_path) + functions_to_optimize = ranker.rank_functions(functions_to_optimize) + logger.info( + f"Ranked {len(functions_to_optimize)} functions by performance impact in {original_module_path}" + ) + console.rule() + except Exception as e: + logger.debug(f"Could not rank functions in {original_module_path}: {e}") + + for function_to_optimize in functions_to_optimize: function_iterator_count += 1 logger.info( f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: " @@ -322,6 +334,8 @@ def run(self) -> None: ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found}) if self.functions_checkpoint: self.functions_checkpoint.cleanup() + if hasattr(self.args, "command") and self.args.command == "optimize": + self.cleanup_replay_tests() if optimizations_found == 0: logger.info("❌ No optimizations found.") elif self.args.all: @@ -352,6 +366,11 @@ def find_leftover_instrumented_test_files(test_root: Path) -> list[Path]: file_path for file_path in test_root.rglob("*") if file_path.is_file() and pattern.match(file_path.name) ] + def cleanup_replay_tests(self) -> None: + if self.replay_tests_dir and self.replay_tests_dir.exists(): + logger.debug(f"Cleaning up replay tests directory: {self.replay_tests_dir}") + cleanup_paths([self.replay_tests_dir]) + def cleanup_temporary_paths(self) -> None: if self.current_function_optimizer: self.current_function_optimizer.cleanup_generated_files() @@ -373,18 +392,3 @@ def run_with_args(args: Namespace) -> None: optimizer.cleanup_temporary_paths() raise SystemExit from None - - -def is_pr_draft() -> bool: - """Check if the PR is draft. 
in the github action context.""" - try: - event_path = os.getenv("GITHUB_EVENT_PATH") - pr_number = get_pr_number() - if pr_number is not None and event_path: - with Path(event_path).open() as f: - event_data = json.load(f) - return bool(event_data["pull_request"]["draft"]) - return False # noqa - except Exception as e: - logger.warning(f"Error checking if PR is draft: {e}") - return False diff --git a/codeflash/picklepatch/pickle_patcher.py b/codeflash/picklepatch/pickle_patcher.py index 3f3236f7..5c635c0d 100644 --- a/codeflash/picklepatch/pickle_patcher.py +++ b/codeflash/picklepatch/pickle_patcher.py @@ -79,8 +79,6 @@ def _create_placeholder(obj: object, error_msg: str, path: list[str]) -> PickleP except: # noqa: E722 obj_str = f"" - print(f"Creating placeholder for {obj_type.__name__} at path {'->'.join(path) or 'root'}: {error_msg}") - placeholder = PicklePlaceholder(obj_type.__name__, obj_str, error_msg, path) # Add this type to our known unpicklable types cache diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 650e74cd..85c6f3b6 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -46,6 +46,7 @@ from codeflash.verification.verification_utils import get_test_file_path if TYPE_CHECKING: + from argparse import Namespace from types import FrameType, TracebackType @@ -123,8 +124,8 @@ def __init__( self.function_count = defaultdict(int) self.current_file_path = Path(__file__).resolve() self.ignored_qualified_functions = { - f"{self.current_file_path}:Tracer:__exit__", - f"{self.current_file_path}:Tracer:__enter__", + f"{self.current_file_path}:Tracer.__exit__", + f"{self.current_file_path}:Tracer.__enter__", } self.max_function_count = max_function_count self.config, found_config_path = parse_config_file(config_file_path) @@ -133,6 +134,7 @@ def __init__( self.ignored_functions = {"", "", "", "", "", ""} self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001 + self.replay_test_file_path: Path | None = None assert timeout is None or timeout > 0, "Timeout should be greater than 0" self.timeout = timeout @@ -283,6 +285,7 @@ def __exit__( with Path(test_file_path).open("w", encoding="utf8") as file: file.write(replay_test) + self.replay_test_file_path = test_file_path console.print( f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}", @@ -344,11 +347,16 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 # someone can override the getattr method and raise an exception. I'm looking at you wrapt return + # Extract class name from co_qualname for static methods that lack self/cls + if class_name is None and "." in getattr(code, "co_qualname", ""): + qualname_parts = code.co_qualname.split(".") + if len(qualname_parts) >= 2: + class_name = qualname_parts[-2] + try: function_qualified_name = f"{file_name}:{code.co_qualname}" except AttributeError: - function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" - + function_qualified_name = f"{file_name}:{(class_name + '.' 
if class_name else '')}{code.co_name}" if function_qualified_name in self.ignored_qualified_functions: return if function_qualified_name not in self.function_count: @@ -701,7 +709,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None: border_style="blue", title="[bold]Function Profile[/bold] (ordered by internal time)", title_style="cyan", - caption=f"Showing top 25 of {len(self.stats)} functions", + caption=f"Showing top {min(25, len(self.stats))} of {len(self.stats)} functions", ) table.add_column("Calls", justify="right", style="green", width=10) @@ -791,9 +799,9 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An return self -def main() -> ArgumentParser: +def main(args: Namespace | None = None) -> ArgumentParser: parser = ArgumentParser(allow_abbrev=False) - parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", required=True) + parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", default="codeflash.trace") parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None) parser.add_argument( "--max-function-count", @@ -815,21 +823,39 @@ "with the codeflash config. Will be auto-discovered if not specified.", default=None, ) + parser.add_argument("--trace-only", action="store_true", help="Trace and create replay tests only, don't optimize") + + if args is not None: + parsed_args = args + parsed_args.outfile = getattr(args, "output", "codeflash.trace") + parsed_args.only_functions = getattr(args, "only_functions", None) + parsed_args.max_function_count = getattr(args, "max_function_count", 100) + parsed_args.tracer_timeout = getattr(args, "timeout", None) + parsed_args.codeflash_config = getattr(args, "config_file_path", None) + parsed_args.trace_only = getattr(args, "trace_only", False) + parsed_args.module = False + + if getattr(args, "disable", False): + console.rule("Codeflash: Tracer disabled by --disable option", style="bold red") + return parser + + unknown_args = sys.argv[1:] + else: + if not sys.argv[1:]: + parser.print_usage() + sys.exit(2) - if not sys.argv[1:]: - parser.print_usage() - sys.exit(2) - - args, unknown_args = parser.parse_known_args() - sys.argv[:] = unknown_args + parsed_args, unknown_args = parser.parse_known_args() + sys.argv[:] = unknown_args # The script that we're profiling may chdir, so capture the absolute path # to the output file at startup. 
- if args.outfile is not None: - args.outfile = Path(args.outfile).resolve() + if parsed_args.outfile is not None: + parsed_args.outfile = Path(parsed_args.outfile).resolve() + outfile = parsed_args.outfile if len(unknown_args) > 0: - if args.module: + if parsed_args.module: import runpy code = "run_module(modname, run_name='__main__')" @@ -848,14 +874,48 @@ def main() -> ArgumentParser: "__cached__": None, } try: - Tracer( - output=args.outfile, - functions=args.only_functions, - max_function_count=args.max_function_count, - timeout=args.tracer_timeout, - config_file_path=args.codeflash_config, + tracer = Tracer( + output=parsed_args.outfile, + functions=parsed_args.only_functions, + max_function_count=parsed_args.max_function_count, + timeout=parsed_args.tracer_timeout, + config_file_path=parsed_args.codeflash_config, command=" ".join(sys.argv), - ).runctx(code, globs, None) + ) + tracer.runctx(code, globs, None) + replay_test_path = tracer.replay_test_file_path + if not parsed_args.trace_only and replay_test_path is not None: + del tracer + + from codeflash.cli_cmds.cli import parse_args, process_pyproject_config + from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO + from codeflash.cli_cmds.console import paneled_text + from codeflash.telemetry import posthog_cf + from codeflash.telemetry.sentry import init_sentry + + sys.argv = ["codeflash", "--replay-test", str(replay_test_path)] + + args = parse_args() + paneled_text( + CODEFLASH_LOGO, + panel_args={"title": "https://codeflash.ai", "expand": False}, + text_args={"style": "bold gold3"}, + ) + + args = process_pyproject_config(args) + args.previous_checkpoint_functions = None + init_sentry(not args.disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(not args.disable_telemetry) + + from codeflash.optimization import optimizer + + optimizer.run_with_args(args) + + # Delete the trace file and the replay test file if they exist + if outfile: + outfile.unlink(missing_ok=True) + if replay_test_path: + replay_test_path.unlink(missing_ok=True) except BrokenPipeError as exc: # Prevent "Exception ignored" during interpreter shutdown. 
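Reviewer note on the `tracer_logic` change above: the `co_qualname` fallback is what lets the tracer recover a class name for staticmethods, which receive no `self`/`cls` argument. A minimal sketch of the idea (plain Python, not codeflash code; `co_qualname` requires Python 3.11+):

```python
# Sketch: recover the owning class of a staticmethod from its code object's
# qualified name, mirroring the fallback added to tracer_logic above.
class SimpleModel:
    @staticmethod
    def predict(data):
        return data


code = SimpleModel.predict.__code__
qualname = getattr(code, "co_qualname", "")  # "SimpleModel.predict" on Python 3.11+
class_name = None
if "." in qualname:
    parts = qualname.split(".")
    if len(parts) >= 2:
        class_name = parts[-2]

print(class_name)  # -> "SimpleModel" (None on interpreters without co_qualname)
```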
diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index 8e2fc5e2..e810c940 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -27,6 +27,7 @@ def create_stats(self) -> None: filename, line_number, function, + class_name, call_count_nonrecursive, num_callers, total_time_ns, @@ -34,8 +35,19 @@ def create_stats(self) -> None: callers, ) in pdata: loaded_callers = json.loads(callers) - unmapped_callers = {caller["key"]: caller["value"] for caller in loaded_callers} - self.stats[(filename, line_number, function)] = ( + unmapped_callers = {} + for caller in loaded_callers: + caller_key = caller["key"] + if isinstance(caller_key, list): + caller_key = tuple(caller_key) + elif not isinstance(caller_key, tuple): + caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key) + unmapped_callers[caller_key] = caller["value"] + + # Create function key with class name if present (matching tracer.py format) + function_name = f"{class_name}.{function}" if class_name else function + + self.stats[(filename, line_number, function_name)] = ( call_count_nonrecursive, num_callers, total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns, diff --git a/pyproject.toml b/pyproject.toml index bcac6575..1862bed9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -289,6 +289,14 @@ formatter-cmds = [ "uvx ruff format $file", ] +[tool.pytest.ini_options] +filterwarnings = [ + "ignore::pytest.PytestCollectionWarning", +] +markers = [ + "ci_skip: mark test to skip in CI environment", +] + [build-system] requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py index 7c1e820c..118eb1a9 100644 --- a/tests/scripts/end_to_end_test_tracer_replay.py +++ b/tests/scripts/end_to_end_test_tracer_replay.py @@ -8,9 +8,9 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( trace_mode=True, min_improvement_x=0.1, - expected_unit_tests=7, + expected_unit_tests=8, coverage_expectations=[ - CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 10, 13]) + CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[6, 7, 8, 9, 11, 14]) ], ) cwd = ( diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index d0c70097..6dc1aba4 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -185,10 +185,9 @@ def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) -> def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool: - # First command: Run the tracer test_root = cwd / "tests" / (config.test_framework or "") clear_directory(test_root) - command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", "workload.py"] + command = ["python", "-m", "codeflash.main", "optimize", "workload.py"] process = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() ) @@ -202,33 +201,20 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p stdout = "".join(output) if return_code != 0: - logging.error(f"Tracer command returned exit code {return_code}") + logging.error(f"Tracer with optimization command returned exit code {return_code}") 
return False - functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout) - if not functions_traced or int(functions_traced.group(1)) != 13: - logging.error("Expected 13 traced functions") + functions_traced = re.search(r"Traced (\d+) function calls successfully", stdout) + logging.info(functions_traced.groups() if functions_traced else "No functions traced") + if not functions_traced: + logging.error("Failed to find traced functions in output") return False - - replay_test_path = pathlib.Path(functions_traced.group(2)) - if not replay_test_path.exists(): - logging.error(f"Replay test file missing at {replay_test_path}") + if int(functions_traced.group(1)) != 13: + logging.error(functions_traced.groups()) + logging.error("Expected 13 traced functions") return False - # Second command: Run optimization - command = ["python", "../../../codeflash/main.py", "--replay-test", str(replay_test_path), "--no-pr"] - process = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() - ) - - output = [] - for line in process.stdout: - logging.info(line.strip()) - output.append(line) - - return_code = process.wait() - stdout = "".join(output) - + # Validate optimization results (from optimization phase) return validate_output(stdout, return_code, expected_improvement_pct, config) diff --git a/tests/test_function_discovery.py b/tests/test_function_discovery.py index 38a616cd..291b4270 100644 --- a/tests/test_function_discovery.py +++ b/tests/test_function_discovery.py @@ -129,7 +129,7 @@ def functionA(): tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) path_obj_name = Path(f.name) - functions, functions_count = get_functions_to_optimize( + functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, file=path_obj_name, @@ -145,7 +145,7 @@ def functionA(): assert functions[file][0].function_name == "functionA" assert functions[file][0].top_level_parent_name == "A" - functions, functions_count = get_functions_to_optimize( + functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, file=path_obj_name, @@ -161,7 +161,7 @@ def functionA(): assert functions[file][0].function_name == "functionA" assert functions[file][0].top_level_parent_name == "X" - functions, functions_count = get_functions_to_optimize( + functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, file=path_obj_name, @@ -229,7 +229,7 @@ def traverse(node_id): tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) path_obj_name = Path(f.name) - functions, functions_count = get_functions_to_optimize( + functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, file=path_obj_name, @@ -258,7 +258,7 @@ def inner_function(): tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) path_obj_name = Path(f.name) - functions, functions_count = get_functions_to_optimize( + functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, file=path_obj_name, @@ -289,7 +289,7 @@ def another_inner_function(): tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) path_obj_name = Path(f.name) - functions, functions_count = get_functions_to_optimize( + functions, functions_count, _ = 
get_functions_to_optimize( optimize_all=None, replay_test=None, file=path_obj_name, diff --git a/tests/test_function_ranker.py b/tests/test_function_ranker.py new file mode 100644 index 00000000..0cb1bb77 --- /dev/null +++ b/tests/test_function_ranker.py @@ -0,0 +1,172 @@ +import pytest +from pathlib import Path +from unittest.mock import patch + +from codeflash.benchmarking.function_ranker import FunctionRanker +from codeflash.discovery.functions_to_optimize import FunctionToOptimize, find_all_functions_in_file +from codeflash.models.models import FunctionParent + + +@pytest.fixture +def trace_file(): + return Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/codeflash.trace" + + +@pytest.fixture +def workload_functions(): + workloads_file = Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/workload.py" + functions_dict = find_all_functions_in_file(workloads_file) + all_functions = [] + for functions_list in functions_dict.values(): + all_functions.extend(functions_list) + return all_functions + + +@pytest.fixture +def function_ranker(trace_file): + return FunctionRanker(trace_file) + + +def test_function_ranker_initialization(trace_file): + ranker = FunctionRanker(trace_file) + assert ranker.trace_file_path == trace_file + assert ranker._profile_stats is not None + assert isinstance(ranker._function_stats, dict) + + +def test_load_function_stats(function_ranker): + assert len(function_ranker._function_stats) > 0 + + # Check that funcA is loaded with expected structure + func_a_key = None + for key, stats in function_ranker._function_stats.items(): + if stats["function_name"] == "funcA": + func_a_key = key + break + + assert func_a_key is not None + func_a_stats = function_ranker._function_stats[func_a_key] + + # Verify funcA stats structure + expected_keys = { + "filename", "function_name", "qualified_name", "class_name", + "line_number", "call_count", "own_time_ns", "cumulative_time_ns", + "time_in_callees_ns", "ttx_score" + } + assert set(func_a_stats.keys()) == expected_keys + + # Verify funcA specific values + assert func_a_stats["function_name"] == "funcA" + assert func_a_stats["call_count"] == 1 + assert func_a_stats["own_time_ns"] == 63000 + assert func_a_stats["cumulative_time_ns"] == 5443000 + + +def test_get_function_ttx_score(function_ranker, workload_functions): + func_a = None + for func in workload_functions: + if func.function_name == "funcA": + func_a = func + break + + assert func_a is not None + ttx_score = function_ranker.get_function_ttx_score(func_a) + + # Expected ttX score: own_time + (time_in_callees / call_count) + # = 63000 + ((5443000 - 63000) / 1) = 5443000 + assert ttx_score == 5443000 + + +def test_rank_functions(function_ranker, workload_functions): + ranked_functions = function_ranker.rank_functions(workload_functions) + + assert len(ranked_functions) == len(workload_functions) + + # Verify functions are sorted by ttX score in descending order + for i in range(len(ranked_functions) - 1): + current_score = function_ranker.get_function_ttx_score(ranked_functions[i]) + next_score = function_ranker.get_function_ttx_score(ranked_functions[i + 1]) + assert current_score >= next_score + + +def test_rerank_and_filter_functions(function_ranker, workload_functions): + filtered_ranked = function_ranker.rerank_and_filter_functions(workload_functions) + + # Should filter out functions below importance threshold + assert len(filtered_ranked) <= len(workload_functions) + + # funcA should pass the 
importance threshold (0.57% > 0.1%) + func_a_in_results = any(f.function_name == "funcA" for f in filtered_ranked) + assert func_a_in_results + + +def test_get_function_stats_summary(function_ranker, workload_functions): + func_a = None + for func in workload_functions: + if func.function_name == "funcA": + func_a = func + break + + assert func_a is not None + stats = function_ranker.get_function_stats_summary(func_a) + + assert stats is not None + assert stats["function_name"] == "funcA" + assert stats["own_time_ns"] == 63000 + assert stats["cumulative_time_ns"] == 5443000 + assert stats["ttx_score"] == 5443000 + + + + +def test_importance_calculation(function_ranker): + total_program_time = sum( + s["own_time_ns"] for s in function_ranker._function_stats.values() + if s.get("own_time_ns", 0) > 0 + ) + + func_a_stats = None + for stats in function_ranker._function_stats.values(): + if stats["function_name"] == "funcA": + func_a_stats = stats + break + + assert func_a_stats is not None + importance = func_a_stats["own_time_ns"] / total_program_time + + # funcA importance should be approximately 0.57% (63000/10968000) + assert abs(importance - 0.0057) < 0.001 + + +def test_simple_model_predict_stats(function_ranker, workload_functions): + # Find SimpleModel::predict function + predict_func = None + for func in workload_functions: + if func.function_name == "predict": + predict_func = func + break + + assert predict_func is not None + + stats = function_ranker.get_function_stats_summary(predict_func) + assert stats is not None + assert stats["function_name"] == "predict" + assert stats["call_count"] == 1 + assert stats["own_time_ns"] == 2289000 + assert stats["cumulative_time_ns"] == 4017000 + assert stats["ttx_score"] == 4017000 + + # Test ttX score calculation + ttx_score = function_ranker.get_function_ttx_score(predict_func) + # Expected ttX score: own_time + (time_in_callees / call_count) + # = 2289000 + ((4017000 - 2289000) / 1) = 4017000 + assert ttx_score == 4017000 + + # Test importance calculation for predict function + total_program_time = sum( + s["own_time_ns"] for s in function_ranker._function_stats.values() + if s.get("own_time_ns", 0) > 0 + ) + importance = stats["own_time_ns"] / total_program_time + # predict importance should be approximately 20.9% (2289000/10968000) + assert abs(importance - 0.209) < 0.01
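Reviewer note: a minimal, self-contained sketch of the scoring and filtering that `FunctionRanker` performs, using the trace numbers asserted in `tests/test_function_ranker.py`. The standalone helper names below (`ttx_score`, `is_important`) are illustrative only, not part of the codeflash API:

```python
# ttX = own_time + (time_spent_in_callees / call_count), per the FunctionRanker
# docstring; the importance filter mirrors DEFAULT_IMPORTANCE_THRESHOLD (0.1%).
DEFAULT_IMPORTANCE_THRESHOLD = 0.001


def ttx_score(own_time_ns: int, cumulative_time_ns: int, call_count: int) -> float:
    # Time spent in callees is cumulative time minus the function's own time.
    time_in_callees_ns = cumulative_time_ns - own_time_ns
    return own_time_ns + (time_in_callees_ns / call_count)


def is_important(own_time_ns: int, total_program_time_ns: int) -> bool:
    # Keep a function only if its own time is at least 0.1% of total own time.
    return (own_time_ns / total_program_time_ns) >= DEFAULT_IMPORTANCE_THRESHOLD


# Numbers from the test assertions (funcA and SimpleModel.predict, one call each):
print(ttx_score(63_000, 5_443_000, 1))      # 5443000.0 -> funcA
print(ttx_score(2_289_000, 4_017_000, 1))   # 4017000.0 -> SimpleModel.predict
print(is_important(63_000, 10_968_000))     # True: ~0.57% of total own time >= 0.1%
```

Functions that pass the importance filter are then sorted by ttX score in descending order, which is the order `rank_functions` returns.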