
introduce a new integrated "codeflash optimize" command #384


Merged: 34 commits from trace-and-optimize into main on Jul 4, 2025
Changes from 27 commits

Commits (34)
4debe7e
introduce a new integrated "codeflash optimize" command
misrasaurabh1 Jun 26, 2025
535a9b1
Merge branch 'main' into trace-and-optimize
KRRT7 Jun 26, 2025
09bf156
Merge branch 'main' into trace-and-optimize
KRRT7 Jun 27, 2025
0b4fcb6
rank functions
KRRT7 Jun 28, 2025
059b4dc
Merge branch 'main' into trace-and-optimize
KRRT7 Jun 30, 2025
7f9a609
implement reranker
KRRT7 Jun 30, 2025
eb9e0c6
allow predict to be included
KRRT7 Jul 1, 2025
ce68cad
fix tracer for static methods
KRRT7 Jul 1, 2025
b7258a9
Merge branch 'main' into trace-and-optimize
KRRT7 Jul 1, 2025
72b51c1
⚡️ Speed up method `FunctionRanker._get_function_stats` by 51% in PR …
codeflash-ai[bot] Jul 1, 2025
67bd717
Merge pull request #466 from codeflash-ai/codeflash/optimize-pr384-20…
misrasaurabh1 Jul 1, 2025
ea16342
update tests
KRRT7 Jul 1, 2025
947ab07
don't let the AI replicate
KRRT7 Jul 2, 2025
4823ee5
Merge branch 'main' into trace-and-optimize
KRRT7 Jul 2, 2025
faebe9b
ruff
KRRT7 Jul 2, 2025
a0e57ba
mypy-ruff
KRRT7 Jul 2, 2025
fd1e492
silence test collection warnings
KRRT7 Jul 2, 2025
f7c8a6b
Update function_ranker.py
KRRT7 Jul 2, 2025
35059a9
Update workload.py
KRRT7 Jul 2, 2025
f74d947
update CI
KRRT7 Jul 2, 2025
9addd95
update cov numbers
KRRT7 Jul 2, 2025
70cecaf
rank only, change formula
KRRT7 Jul 3, 2025
96acfc7
per module ranking
KRRT7 Jul 3, 2025
e5e1ff0
update tests
KRRT7 Jul 3, 2025
eba8cb8
move to env utils, pre-commit
KRRT7 Jul 3, 2025
9955081
Merge branch 'main' of https://github.com/codeflash-ai/codeflash into…
KRRT7 Jul 3, 2025
692f46e
Merge branch 'main' into trace-and-optimize
KRRT7 Jul 3, 2025
e2e6803
add markers
KRRT7 Jul 3, 2025
4560b8b
Merge branch 'main' into trace-and-optimize
KRRT7 Jul 3, 2025
39e0859
Update cli.py
KRRT7 Jul 3, 2025
c09f32e
Revert "Update cli.py"
KRRT7 Jul 3, 2025
60922b8
allow args for the optimize command too
KRRT7 Jul 3, 2025
bf6313f
fix parsing
KRRT7 Jul 4, 2025
87f44a2
fix parsing
KRRT7 Jul 4, 2025
Binary file not shown.
11 changes: 10 additions & 1 deletion code_to_optimize/code_directories/simple_tracer_e2e/workload.py
@@ -1,4 +1,5 @@
from concurrent.futures import ThreadPoolExecutor
+from time import sleep


def funcA(number):
@@ -46,12 +47,20 @@ def _classify(self, features):
class SimpleModel:
    @staticmethod
    def predict(data):
-        return [x * 2 for x in data]
+        result = []
+        sleep(0.1) # can be optimized away
+        for i in range(500):
+            for x in data:
+                computation = 0
+                computation += x * i ** 2
+                result.append(computation)
+        return result

    @classmethod
    def create_default(cls):
        return cls()


def test_models():
    model = AlexNet(num_classes=10)
    input_data = [1, 2, 3, 4, 5]
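For context, the new predict body above is intentionally wasteful so the tracer and optimizer have something to find: the sleep(0.1) is pure dead time, and the nested loops recompute x * i ** 2 element by element. A minimal sketch of the kind of rewrite the optimizer is expected to discover (illustrative only, not part of this PR) could be:

# Hypothetical optimized equivalent of SimpleModel.predict (not from this PR).
# Dropping the sleep() and collapsing the nested loops yields the same values
# in the same order, without the artificial delay.
@staticmethod
def predict(data):
    return [x * i ** 2 for i in range(500) for x in data]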
158 changes: 158 additions & 0 deletions codeflash/benchmarking/function_ranker.py
@@ -0,0 +1,158 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.tracing.profile_stats import ProfileStats

if TYPE_CHECKING:
    from pathlib import Path

    from codeflash.discovery.functions_to_optimize import FunctionToOptimize


class FunctionRanker:
    """Ranks and filters functions based on a ttX score derived from profiling data.

    The ttX score is calculated as:
    ttX = own_time + (time_spent_in_callees / call_count)

    This score prioritizes functions that are computationally heavy themselves (high `own_time`)
    or that make expensive calls to other functions (high average `time_spent_in_callees`).

    Functions are first filtered by an importance threshold based on their `own_time` as a
    fraction of the total runtime. The remaining functions are then ranked by their ttX score
    to identify the best candidates for optimization.
    """

    def __init__(self, trace_file_path: Path) -> None:
        self.trace_file_path = trace_file_path
        self._profile_stats = ProfileStats(trace_file_path.as_posix())
        self._function_stats: dict[str, dict] = {}
        self.load_function_stats()

    def load_function_stats(self) -> None:
        try:
            for (filename, line_number, func_name), (
                call_count,
                _num_callers,
                total_time_ns,
                cumulative_time_ns,
                _callers,
            ) in self._profile_stats.stats.items():
                if call_count <= 0:
                    continue

                # Parse function name to handle methods within classes
                class_name, qualified_name, base_function_name = (None, func_name, func_name)
                if "." in func_name and not func_name.startswith("<"):
                    parts = func_name.split(".", 1)
                    if len(parts) == 2:
                        class_name, base_function_name = parts

                # Calculate own time (total time - time spent in subcalls)
                own_time_ns = total_time_ns
                time_in_callees_ns = cumulative_time_ns - total_time_ns

                # Calculate ttX score
                ttx_score = own_time_ns + (time_in_callees_ns / call_count)

                function_key = f"{filename}:{qualified_name}"
                self._function_stats[function_key] = {
                    "filename": filename,
                    "function_name": base_function_name,
                    "qualified_name": qualified_name,
                    "class_name": class_name,
                    "line_number": line_number,
                    "call_count": call_count,
                    "own_time_ns": own_time_ns,
                    "cumulative_time_ns": cumulative_time_ns,
                    "time_in_callees_ns": time_in_callees_ns,
                    "ttx_score": ttx_score,
                }

            logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats")

        except Exception as e:
            logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
            self._function_stats = {}

    def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        target_filename = function_to_optimize.file_path.name
        for key, stats in self._function_stats.items():
            if stats.get("function_name") == function_to_optimize.function_name and (
                key.endswith(f"/{target_filename}") or target_filename in key
            ):
                return stats

        logger.debug(
            f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
        )
        return None

    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
        stats = self._get_function_stats(function_to_optimize)
        return stats["ttx_score"] if stats else 0.0

    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
        logger.debug(
            f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}"
        )
        return ranked

    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        return self._get_function_stats(function_to_optimize)

    def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        """Ranks functions based on their ttX score.

        This method calculates the ttX score for each function and returns
        the functions sorted in descending order of their ttX score.
        """
        if not self._function_stats:
            logger.warning("No function stats available to rank functions.")
            return []

        return self.rank_functions(functions_to_optimize)

    def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        """Reranks and filters functions based on their impact on total runtime.

        This method first calculates the total runtime of all profiled functions.
        It then filters out functions whose own_time is less than a specified
        percentage of the total runtime (importance_threshold).

        The remaining 'important' functions are then ranked by their ttX score.
        """
        stats_map = self._function_stats
        if not stats_map:
            return []

        total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)

        if total_program_time == 0:
            logger.warning("Total program time is zero, cannot determine function importance.")
            return self.rank_functions(functions_to_optimize)

        important_functions = []
        for func in functions_to_optimize:
            func_stats = self._get_function_stats(func)
            if func_stats and func_stats.get("own_time_ns", 0) > 0:
                importance = func_stats["own_time_ns"] / total_program_time
                if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
                    important_functions.append(func)
                else:
                    logger.debug(
                        f"Filtering out function {func.qualified_name} with importance "
                        f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
                    )

        logger.info(
            f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
        )
        console.rule()

        return self.rank_functions(important_functions)
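To make the ranking and filtering above concrete, here is a small self-contained sketch (illustrative only, with made-up profile numbers; it does not import the FunctionRanker class) that applies the same ttX formula and the 0.1% importance cut-off. For example, a function with 50 ms of own time that spends 200 ms across 4 calls into its callees scores ttX = 50 ms + 200 ms / 4 = 100 ms.

# Illustrative recomputation of the ttX score and importance filter described above.
# All numbers are invented; the real implementation reads them from a trace via ProfileStats.
DEFAULT_IMPORTANCE_THRESHOLD = 0.001  # same value added to config_consts.py in this PR

def ttx(own_time_ns: float, time_in_callees_ns: float, call_count: int) -> float:
    # ttX = own_time + (time_spent_in_callees / call_count)
    return own_time_ns + time_in_callees_ns / call_count

# Made-up per-function stats: (own_time_ns, time_in_callees_ns, call_count)
profile = {
    "workload.py:SimpleModel.predict": (50_000_000, 200_000_000, 4),
    "workload.py:funcA": (30_000_000, 0, 100),
    "workload.py:tiny_helper": (10_000, 0, 1),
}

total_own_time = sum(own for own, _, _ in profile.values())

# Keep only functions whose own time is at least 0.1% of total runtime, then rank by ttX,
# mirroring rerank_and_filter_functions() above.
important = {
    name: ttx(own, callees, calls)
    for name, (own, callees, calls) in profile.items()
    if own / total_own_time >= DEFAULT_IMPORTANCE_THRESHOLD
}
for name, score in sorted(important.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: ttX = {score:,.0f} ns")
# tiny_helper is filtered out; SimpleModel.predict ranks above funcA.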
11 changes: 10 additions & 1 deletion codeflash/cli_cmds/cli.py
@@ -22,6 +22,12 @@ def parse_args() -> Namespace:

    init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
    init_actions_parser.set_defaults(func=install_github_actions)
+
+    trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize a Python project.")
+    from codeflash.tracer import main as tracer_main
+
+    trace_optimize.set_defaults(func=tracer_main)
+
    parser.add_argument("--file", help="Try to optimize only this file")
    parser.add_argument("--function", help="Try to optimize only this function within the given file path")
    parser.add_argument(
@@ -64,7 +70,8 @@ def parse_args() -> Namespace:
    )
    parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")

-    args: Namespace = parser.parse_args()
+    args, unknown_args = parser.parse_known_args()
+    sys.argv[:] = [sys.argv[0], *unknown_args]
    return process_and_validate_cmd_args(args)


@@ -102,6 +109,8 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
            if not Path(test_path).is_file():
                exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True)
        args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test]
+    if env_utils.is_ci():
+        args.no_pr = True

    return args

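The switch to parse_known_args is what makes the new optimize subcommand composable: Codeflash keeps the flags it recognizes and pushes everything else back onto sys.argv for the tracer entry point (codeflash.tracer.main) to forward to the traced script. A simplified sketch of that behavior with a stand-in parser (not the real Codeflash CLI, and the argument values are invented):

# Stand-in demonstration of the parse_known_args forwarding shown above.
import sys
from argparse import ArgumentParser

parser = ArgumentParser(prog="codeflash")
parser.add_argument("--file", help="Try to optimize only this file")

# As if invoked with: codeflash optimize --file app.py my_script.py --epochs 3
# (the "optimize" token itself is consumed by the subparser in the real CLI).
args, unknown_args = parser.parse_known_args(["--file", "app.py", "my_script.py", "--epochs", "3"])

# Known flags stay on args; everything unknown is handed back through sys.argv
# so the tracer can pass it to the workload being traced.
sys.argv[:] = [sys.argv[0], *unknown_args]
print(args.file)      # app.py
print(sys.argv[1:])   # ['my_script.py', '--epochs', '3']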
1 change: 1 addition & 0 deletions codeflash/code_utils/config_consts.py
@@ -10,3 +10,4 @@
COVERAGE_THRESHOLD = 60.0
MIN_TESTCASE_PASSED_THRESHOLD = 6
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
+DEFAULT_IMPORTANCE_THRESHOLD = 0.001
21 changes: 21 additions & 0 deletions codeflash/code_utils/env_utils.py
@@ -112,6 +112,27 @@ def get_cached_gh_event_data() -> dict[str, Any] | None:
        return json.load(f) # type: ignore # noqa


@lru_cache(maxsize=1)
def is_ci() -> bool:
    """Check if running in a CI environment."""
    return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))


@lru_cache(maxsize=1)
def is_LSP_enabled() -> bool:
    return console.quiet


def is_pr_draft() -> bool:
    """Check if the PR is draft. in the github action context."""
    try:
        event_path = os.getenv("GITHUB_EVENT_PATH")
        pr_number = get_pr_number()
        if pr_number is not None and event_path:
            with Path(event_path).open() as f:
                event_data = json.load(f)
                return bool(event_data["pull_request"]["draft"])
        return False # noqa
    except Exception as e:
        logger.warning(f"Error checking if PR is draft: {e}")
        return False
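The new is_ci helper is what drives the args.no_pr default in process_and_validate_cmd_args above: when either CI or GITHUB_ACTIONS is set, Codeflash skips opening a pull request. A tiny sketch of the detection (using an uncached copy of the helper, since the real one is lru_cached):

# Illustrative check of the CI detection above; is_ci_uncached mirrors env_utils.is_ci
# without the lru_cache so it can be called repeatedly here.
import os

def is_ci_uncached() -> bool:
    return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))

os.environ.pop("CI", None)
os.environ.pop("GITHUB_ACTIONS", None)
assert is_ci_uncached() is False       # local run: a PR may be opened

os.environ["GITHUB_ACTIONS"] = "true"  # set automatically on GitHub Actions runners
assert is_ci_uncached() is True        # CI run: cli.py now forces args.no_pr = True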
39 changes: 34 additions & 5 deletions codeflash/discovery/functions_to_optimize.py
@@ -158,19 +158,20 @@ def get_functions_to_optimize(
    project_root: Path,
    module_root: Path,
    previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
-) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
+) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
    assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
        "Only one of optimize_all, replay_test, or file should be provided"
    )
    functions: dict[str, list[FunctionToOptimize]]
+    trace_file_path: Path | None = None
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=SyntaxWarning)
        if optimize_all:
            logger.info("Finding all functions in the module '%s'…", optimize_all)
            console.rule()
            functions = get_all_files_and_functions(Path(optimize_all))
        elif replay_test:
-            functions = get_all_replay_test_functions(
+            functions, trace_file_path = get_all_replay_test_functions(
                replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
            )
        elif file is not None:
@@ -206,6 +207,7 @@
    filtered_modified_functions, functions_count = filter_functions(
        functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
    )
+
    logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
    if optimize_all:
        three_min_in_ns = int(1.8e11)
@@ -214,7 +216,7 @@
            f"It might take about {humanize_runtime(functions_count * three_min_in_ns)} to fully optimize this project. Codeflash "
            f"will keep opening pull requests as it finds optimizations."
        )
-    return filtered_modified_functions, functions_count
+    return filtered_modified_functions, functions_count, trace_file_path


def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
@@ -272,7 +274,34 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:

def get_all_replay_test_functions(
    replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
-) -> dict[Path, list[FunctionToOptimize]]:
+) -> tuple[dict[Path, list[FunctionToOptimize]], Path]:
+    trace_file_path: Path | None = None
+    for replay_test_file in replay_test:
+        try:
+            with replay_test_file.open("r", encoding="utf8") as f:
+                tree = ast.parse(f.read())
+            for node in ast.walk(tree):
+                if isinstance(node, ast.Assign):
+                    for target in node.targets:
+                        if (
+                            isinstance(target, ast.Name)
+                            and target.id == "trace_file_path"
+                            and isinstance(node.value, ast.Constant)
+                            and isinstance(node.value.value, str)
+                        ):
+                            trace_file_path = Path(node.value.value)
+                            break
+                if trace_file_path:
+                    break
+            if trace_file_path:
+                break
+        except Exception as e:
+            logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
+
+    if not trace_file_path:
+        logger.error("Could not find trace_file_path in replay test files.")
+        exit_with_message("Could not find trace_file_path in replay test files.")
+
    function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test)
    # Get the absolute file paths for each function, excluding class name if present
    filtered_valid_functions = defaultdict(list)
@@ -317,7 +346,7 @@ def get_all_replay_test_functions(
        if filtered_list:
            filtered_valid_functions[file_path] = filtered_list

-    return filtered_valid_functions
+    return filtered_valid_functions, trace_file_path


def is_git_repo(file_path: str) -> bool:
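The AST walk above expects each generated replay test to carry a module-level trace_file_path = "<path>" string assignment pointing back at the trace database. A minimal sketch of that lookup against an in-memory replay test (the test's contents here are invented to show the expected shape):

# Illustrative only: a stripped-down version of the trace_file_path discovery above,
# run against an in-memory replay test instead of files on disk.
import ast
from pathlib import Path

replay_test_source = '''
trace_file_path = "/tmp/codeflash_trace.sqlite"  # assumed header written by the tracer

def test_funcA_replay():
    ...
'''

trace_file_path = None
for node in ast.walk(ast.parse(replay_test_source)):
    if isinstance(node, ast.Assign):
        for target in node.targets:
            if (
                isinstance(target, ast.Name)
                and target.id == "trace_file_path"
                and isinstance(node.value, ast.Constant)
                and isinstance(node.value.value, str)
            ):
                trace_file_path = Path(node.value.value)

print(trace_file_path)  # /tmp/codeflash_trace.sqlite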