Commit 75810a3

Merge pull request #384 from codeflash-ai/trace-and-optimize
introduce a new integrated "codeflash optimize" command
2 parents a54be51 + 87f44a2

File tree: 16 files changed, +578 -87 lines

Binary file not shown.

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 10 additions & 1 deletion
@@ -1,4 +1,5 @@
 from concurrent.futures import ThreadPoolExecutor
+from time import sleep


 def funcA(number):
@@ -46,12 +47,20 @@ def _classify(self, features):
 class SimpleModel:
     @staticmethod
     def predict(data):
-        return [x * 2 for x in data]
+        result = []
+        sleep(0.1)  # can be optimized away
+        for i in range(500):
+            for x in data:
+                computation = 0
+                computation += x * i ** 2
+                result.append(computation)
+        return result

     @classmethod
     def create_default(cls):
         return cls()

+
 def test_models():
     model = AlexNet(num_classes=10)
     input_data = [1, 2, 3, 4, 5]
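
The rewritten `predict` is intentionally wasteful so the end-to-end tracer test has something to find: the `sleep(0.1)` is pure dead time, and because `computation` is reset to 0 before each append, every element is just `x * i ** 2`. A minimal sketch of the equivalent rewrite an optimizer could be expected to produce (hypothetical output, not part of this commit):

```python
class SimpleModel:
    @staticmethod
    def predict(data):
        # Same output as the traced version: computation is reset on each
        # iteration, so each element is just x * i ** 2, and the sleep
        # contributes nothing to the result.
        return [x * i ** 2 for i in range(500) for x in data]
```
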
Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.tracing.profile_stats import ProfileStats

if TYPE_CHECKING:
    from pathlib import Path

    from codeflash.discovery.functions_to_optimize import FunctionToOptimize


class FunctionRanker:
    """Ranks and filters functions based on a ttX score derived from profiling data.

    The ttX score is calculated as:
        ttX = own_time + (time_spent_in_callees / call_count)

    This score prioritizes functions that are computationally heavy themselves (high `own_time`)
    or that make expensive calls to other functions (high average `time_spent_in_callees`).

    Functions are first filtered by an importance threshold based on their `own_time` as a
    fraction of the total runtime. The remaining functions are then ranked by their ttX score
    to identify the best candidates for optimization.
    """

    def __init__(self, trace_file_path: Path) -> None:
        self.trace_file_path = trace_file_path
        self._profile_stats = ProfileStats(trace_file_path.as_posix())
        self._function_stats: dict[str, dict] = {}
        self.load_function_stats()

    def load_function_stats(self) -> None:
        try:
            for (filename, line_number, func_name), (
                call_count,
                _num_callers,
                total_time_ns,
                cumulative_time_ns,
                _callers,
            ) in self._profile_stats.stats.items():
                if call_count <= 0:
                    continue

                # Parse function name to handle methods within classes
                class_name, qualified_name, base_function_name = (None, func_name, func_name)
                if "." in func_name and not func_name.startswith("<"):
                    parts = func_name.split(".", 1)
                    if len(parts) == 2:
                        class_name, base_function_name = parts

                # Calculate own time (total time - time spent in subcalls)
                own_time_ns = total_time_ns
                time_in_callees_ns = cumulative_time_ns - total_time_ns

                # Calculate ttX score
                ttx_score = own_time_ns + (time_in_callees_ns / call_count)

                function_key = f"{filename}:{qualified_name}"
                self._function_stats[function_key] = {
                    "filename": filename,
                    "function_name": base_function_name,
                    "qualified_name": qualified_name,
                    "class_name": class_name,
                    "line_number": line_number,
                    "call_count": call_count,
                    "own_time_ns": own_time_ns,
                    "cumulative_time_ns": cumulative_time_ns,
                    "time_in_callees_ns": time_in_callees_ns,
                    "ttx_score": ttx_score,
                }

            logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats")

        except Exception as e:
            logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
            self._function_stats = {}

    def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        target_filename = function_to_optimize.file_path.name
        for key, stats in self._function_stats.items():
            if stats.get("function_name") == function_to_optimize.function_name and (
                key.endswith(f"/{target_filename}") or target_filename in key
            ):
                return stats

        logger.debug(
            f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
        )
        return None

    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
        stats = self._get_function_stats(function_to_optimize)
        return stats["ttx_score"] if stats else 0.0

    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
        logger.debug(
            f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}"
        )
        return ranked

    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        return self._get_function_stats(function_to_optimize)

    def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        """Ranks functions based on their ttX score.

        This method calculates the ttX score for each function and returns
        the functions sorted in descending order of their ttX score.
        """
        if not self._function_stats:
            logger.warning("No function stats available to rank functions.")
            return []

        return self.rank_functions(functions_to_optimize)

    def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        """Reranks and filters functions based on their impact on total runtime.

        This method first calculates the total runtime of all profiled functions.
        It then filters out functions whose own_time is less than a specified
        percentage of the total runtime (importance_threshold).

        The remaining 'important' functions are then ranked by their ttX score.
        """
        stats_map = self._function_stats
        if not stats_map:
            return []

        total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)

        if total_program_time == 0:
            logger.warning("Total program time is zero, cannot determine function importance.")
            return self.rank_functions(functions_to_optimize)

        important_functions = []
        for func in functions_to_optimize:
            func_stats = self._get_function_stats(func)
            if func_stats and func_stats.get("own_time_ns", 0) > 0:
                importance = func_stats["own_time_ns"] / total_program_time
                if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
                    important_functions.append(func)
                else:
                    logger.debug(
                        f"Filtering out function {func.qualified_name} with importance "
                        f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
                    )

        logger.info(
            f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
        )
        console.rule()

        return self.rank_functions(important_functions)
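
To make the ttX formula concrete, here is a small worked example with made-up timings (illustrative only, not taken from a real trace):

```python
# ttX = own_time + (time_in_callees / call_count), all in nanoseconds.
own_time_ns = 4_000_000          # 4 ms spent in the function body itself
cumulative_time_ns = 10_000_000  # 10 ms including everything it calls
call_count = 3

time_in_callees_ns = cumulative_time_ns - own_time_ns      # 6_000_000
ttx_score = own_time_ns + time_in_callees_ns / call_count  # 6_000_000.0
# A function with 5 ms own time and no callees would score 5_000_000,
# so this one ranks higher despite its smaller own_time.
```
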

codeflash/cli_cmds/cli.py

Lines changed: 34 additions & 1 deletion
@@ -22,6 +22,36 @@ def parse_args() -> Namespace:

     init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
     init_actions_parser.set_defaults(func=install_github_actions)
+
+    trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize a Python project.")
+
+    from codeflash.tracer import main as tracer_main
+
+    trace_optimize.set_defaults(func=tracer_main)
+
+    trace_optimize.add_argument(
+        "--max-function-count",
+        type=int,
+        default=100,
+        help="The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.",
+    )
+    trace_optimize.add_argument(
+        "--timeout",
+        type=int,
+        help="The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows, to not wait indefinitely.",
+    )
+    trace_optimize.add_argument(
+        "--output",
+        type=str,
+        default="codeflash.trace",
+        help="The file to save the trace to. Default is codeflash.trace.",
+    )
+    trace_optimize.add_argument(
+        "--config-file-path",
+        type=str,
+        help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
+    )
+
     parser.add_argument("--file", help="Try to optimize only this file")
     parser.add_argument("--function", help="Try to optimize only this function within the given file path")
     parser.add_argument(
@@ -64,7 +94,8 @@ def parse_args() -> Namespace:
     )
     parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")

-    args: Namespace = parser.parse_args()
+    args, unknown_args = parser.parse_known_args()
+    sys.argv[:] = [sys.argv[0], *unknown_args]
     return process_and_validate_cmd_args(args)


@@ -102,6 +133,8 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
         if not Path(test_path).is_file():
             exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True)
         args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test]
+    if env_utils.is_ci():
+        args.no_pr = True

     return args
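
The switch from `parse_args()` to `parse_known_args()` is what lets `codeflash optimize` accept a target script plus that script's own arguments: everything argparse does not recognize is pushed back into `sys.argv` for the tracer's `main()` to consume. A standalone sketch of the pattern (a toy parser, not codeflash's actual CLI):

```python
import argparse
import sys

parser = argparse.ArgumentParser(prog="codeflash")
subparsers = parser.add_subparsers(dest="command")
optimize = subparsers.add_parser("optimize")
optimize.add_argument("--output", default="codeflash.trace")

argv = ["optimize", "--output", "my.trace", "script.py", "--script-flag"]
args, unknown = parser.parse_known_args(argv)
print(args.output)  # my.trace
print(unknown)      # ['script.py', '--script-flag']

# The CLI then hands the leftovers to the tracer via sys.argv:
sys.argv[:] = [sys.argv[0], *unknown]
```
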

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@
 COVERAGE_THRESHOLD = 60.0
 MIN_TESTCASE_PASSED_THRESHOLD = 6
 REPEAT_OPTIMIZATION_PROBABILITY = 0.1
+DEFAULT_IMPORTANCE_THRESHOLD = 0.001
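
`DEFAULT_IMPORTANCE_THRESHOLD = 0.001` means a function must account for at least 0.1% of the summed own-time across all profiled functions to survive FunctionRanker's filter. Illustrative arithmetic with made-up numbers:

```python
DEFAULT_IMPORTANCE_THRESHOLD = 0.001    # 0.1% of total own-time
total_program_time_ns = 50_000_000_000  # 50 s of summed own_time_ns
own_time_ns = 40_000_000                # this function ran for 40 ms total

importance = own_time_ns / total_program_time_ns   # 0.0008
print(importance >= DEFAULT_IMPORTANCE_THRESHOLD)  # False -> filtered out
```
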

codeflash/code_utils/env_utils.py

Lines changed: 21 additions & 0 deletions
@@ -112,6 +112,27 @@ def get_cached_gh_event_data() -> dict[str, Any] | None:
         return json.load(f)  # type: ignore # noqa


+@lru_cache(maxsize=1)
+def is_ci() -> bool:
+    """Check if running in a CI environment."""
+    return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))
+
+
 @lru_cache(maxsize=1)
 def is_LSP_enabled() -> bool:
     return console.quiet
+
+
+def is_pr_draft() -> bool:
+    """Check if the PR is a draft, in the GitHub Actions context."""
+    try:
+        event_path = os.getenv("GITHUB_EVENT_PATH")
+        pr_number = get_pr_number()
+        if pr_number is not None and event_path:
+            with Path(event_path).open() as f:
+                event_data = json.load(f)
+                return bool(event_data["pull_request"]["draft"])
+        return False  # noqa
+    except Exception as e:
+        logger.warning(f"Error checking if PR is draft: {e}")
+        return False
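
`is_pr_draft()` reads the GitHub Actions event payload that `GITHUB_EVENT_PATH` points to; for `pull_request` events that JSON carries a `draft` flag. A trimmed example of the shape it relies on (real payloads contain many more fields):

```python
import json

# Trimmed pull_request event payload, as is_pr_draft() would read it.
event_data = json.loads("""
{
  "pull_request": {
    "number": 384,
    "draft": true
  }
}
""")
print(bool(event_data["pull_request"]["draft"]))  # True
```
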

codeflash/discovery/functions_to_optimize.py

Lines changed: 34 additions & 5 deletions
@@ -160,19 +160,20 @@ def get_functions_to_optimize(
     project_root: Path,
     module_root: Path,
     previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
-) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
+) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
     assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
         "Only one of optimize_all, replay_test, or file should be provided"
     )
     functions: dict[str, list[FunctionToOptimize]]
+    trace_file_path: Path | None = None
     with warnings.catch_warnings():
         warnings.simplefilter(action="ignore", category=SyntaxWarning)
         if optimize_all:
             logger.info("Finding all functions in the module '%s'…", optimize_all)
             console.rule()
             functions = get_all_files_and_functions(Path(optimize_all))
         elif replay_test:
-            functions = get_all_replay_test_functions(
+            functions, trace_file_path = get_all_replay_test_functions(
                 replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
             )
         elif file is not None:
@@ -208,6 +209,7 @@ def get_functions_to_optimize(
     filtered_modified_functions, functions_count = filter_functions(
         functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
     )
+
     logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
     if optimize_all:
         three_min_in_ns = int(1.8e11)
@@ -216,7 +218,7 @@ def get_functions_to_optimize(
             f"It might take about {humanize_runtime(functions_count * three_min_in_ns)} to fully optimize this project. Codeflash "
             f"will keep opening pull requests as it finds optimizations."
         )
-    return filtered_modified_functions, functions_count
+    return filtered_modified_functions, functions_count, trace_file_path


 def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
@@ -274,7 +276,34 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt

 def get_all_replay_test_functions(
     replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
-) -> dict[Path, list[FunctionToOptimize]]:
+) -> tuple[dict[Path, list[FunctionToOptimize]], Path]:
+    trace_file_path: Path | None = None
+    for replay_test_file in replay_test:
+        try:
+            with replay_test_file.open("r", encoding="utf8") as f:
+                tree = ast.parse(f.read())
+            for node in ast.walk(tree):
+                if isinstance(node, ast.Assign):
+                    for target in node.targets:
+                        if (
+                            isinstance(target, ast.Name)
+                            and target.id == "trace_file_path"
+                            and isinstance(node.value, ast.Constant)
+                            and isinstance(node.value.value, str)
+                        ):
+                            trace_file_path = Path(node.value.value)
+                            break
+                if trace_file_path:
+                    break
+            if trace_file_path:
+                break
+        except Exception as e:
+            logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
+
+    if not trace_file_path:
+        logger.error("Could not find trace_file_path in replay test files.")
+        exit_with_message("Could not find trace_file_path in replay test files.")
+
     function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test)
     # Get the absolute file paths for each function, excluding class name if present
     filtered_valid_functions = defaultdict(list)
@@ -319,7 +348,7 @@ def get_all_replay_test_functions(
         if filtered_list:
             filtered_valid_functions[file_path] = filtered_list

-    return filtered_valid_functions
+    return filtered_valid_functions, trace_file_path
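
The AST walk in `get_all_replay_test_functions` only matches a module-level assignment of a string literal to `trace_file_path`, which generated replay tests are expected to contain. A hypothetical replay-test header that would satisfy it (the path is a placeholder):

```python
# Top of a generated replay test (hypothetical; values are placeholders).
trace_file_path = "/tmp/codeflash.trace"  # string literal found by the AST walk

# ...generated test functions that replay the traced calls follow here...
```
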
