Skip to content

Commit 08464c4

Browse files
committed
rank functions
1 parent 09bf156 commit 08464c4

File tree

2 files changed

+190
-3
lines changed

2 files changed

+190
-3
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from __future__ import annotations
2+
3+
import sqlite3
4+
from typing import TYPE_CHECKING
5+
6+
from codeflash.cli_cmds.console import logger
7+
8+
if TYPE_CHECKING:
9+
from pathlib import Path
10+
11+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
12+
13+
14+
class FunctionRanker:
    """Ranks functions for optimization based on trace data using ttX scoring.

    ttX = own_time + (time_spent_in_callees x call_count)

    This prioritizes functions that:
    1. Take significant time themselves (own_time)
    2. Are called frequently and have expensive subcalls (time_spent_in_callees x call_count)

    Timing rows are read lazily from the ``pstats`` table of the SQLite trace
    file the first time a score or summary is requested, and are then cached
    on the instance.
    """

    # Only rows with at least one non-recursive call are loaded.
    _PSTATS_QUERY = """
        SELECT
            filename,
            line_number,
            function,
            class_name,
            call_count_nonrecursive,
            total_time_ns,
            cumulative_time_ns
        FROM pstats
        WHERE call_count_nonrecursive > 0
    """

    def __init__(self, trace_file_path: Path) -> None:
        """Remember the trace database location; nothing is read until first use."""
        self.trace_file_path = trace_file_path
        # None means "not loaded yet"; a dict (possibly empty) means "already loaded".
        self._function_stats: dict[str, dict] | None = None

    def _load_function_stats(self) -> dict[str, dict]:
        """Load function timing statistics from trace database.

        Returns:
            A mapping from ``"<filename>:<qualified_name>"`` to a dict holding
            the raw timing columns plus the derived ``time_in_callees_ns`` and
            ``ttx_score``. If the trace cannot be read or a row cannot be
            processed, a warning is logged and an empty mapping is returned
            (and cached), so ranking degrades to "everything scores 0".

        """
        if self._function_stats is not None:
            return self._function_stats

        stats: dict[str, dict] = {}

        try:
            # sqlite3's connection context manager only commits/rolls back a
            # transaction; it does not close the connection. Close explicitly
            # so the trace file handle is released as soon as rows are fetched.
            conn = sqlite3.connect(self.trace_file_path)
            try:
                rows = conn.execute(self._PSTATS_QUERY).fetchall()
            finally:
                conn.close()

            for (
                filename,
                line_number,
                function_name,
                class_name,
                call_count,
                total_time_ns,
                cumulative_time_ns,
            ) in rows:
                # A missing or blank class name means a plain function.
                if class_name and class_name.strip():
                    qualified_name = f"{class_name}.{function_name}"
                else:
                    qualified_name = function_name

                # total_time_ns is used as-is for the function's own time.
                # NOTE(review): presumably it already excludes subcalls (like
                # cProfile's tottime) -- confirm against the tracer's schema.
                own_time_ns = total_time_ns
                time_in_callees_ns = cumulative_time_ns - total_time_ns

                # Calculate ttX score
                ttx_score = own_time_ns + (time_in_callees_ns * call_count)

                # Rows that share filename and qualified name overwrite each
                # other; the last row returned by the query wins.
                stats[f"{filename}:{qualified_name}"] = {
                    "filename": filename,
                    "function_name": function_name,
                    "qualified_name": qualified_name,
                    "class_name": class_name,
                    "line_number": line_number,
                    "call_count": call_count,
                    "own_time_ns": own_time_ns,
                    "cumulative_time_ns": cumulative_time_ns,
                    "time_in_callees_ns": time_in_callees_ns,
                    "ttx_score": ttx_score,
                }

            logger.debug(f"Loaded timing stats for {len(stats)} functions from trace")

        except Exception as e:
            # Deliberately best-effort: ranking is an optimization hint, so a
            # broken trace must not abort the caller. Discard partial results.
            logger.warning(f"Failed to load function stats from trace file {self.trace_file_path}: {e}")
            stats = {}

        self._function_stats = stats
        return stats

    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
        """Return the ttX score recorded for a function.

        Args:
            function_to_optimize: Object exposing ``file_path``,
                ``qualified_name`` and ``function_name``.

        Returns:
            The stored ``ttx_score``, or ``0.0`` when the function does not
            appear in the trace data (so it is ranked last).

        """
        summary = self.get_function_stats_summary(function_to_optimize)
        if summary is None:
            # If not found in trace data, return 0 (will be ranked last)
            return 0.0
        return summary["ttx_score"]

    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        """Return the given functions ordered by ttX score, highest first.

        The sort is stable, so functions with equal scores (including all
        functions missing from the trace, which score 0) keep their relative
        input order. The top ten entries are logged at info level.
        """
        # Score each function exactly once, then sort on the stored score.
        function_scores = [(func, self.get_function_ttx_score(func)) for func in functions_to_optimize]

        # Sort by ttX score descending (highest impact first)
        function_scores.sort(key=lambda pair: pair[1], reverse=True)

        logger.info("Function ranking by ttX score:")
        for position, (func, score) in enumerate(function_scores[:10], start=1):  # Top 10
            logger.info(f"  {position}. {func.qualified_name} (ttX: {score:.0f}ns)")

        ranked_functions = [func for func, _ in function_scores]
        logger.info(f"Ranked {len(ranked_functions)} functions by optimization priority")

        return ranked_functions

    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        """Return the cached stats dict for a function, or ``None`` if untraced.

        The lookup tries ``"<file_path>:<qualified_name>"`` first and falls
        back to ``"<file_path>:<function_name>"``.
        """
        stats = self._load_function_stats()

        possible_keys = (
            f"{function_to_optimize.file_path}:{function_to_optimize.qualified_name}",
            f"{function_to_optimize.file_path}:{function_to_optimize.function_name}",
        )

        for key in possible_keys:
            if key in stats:
                return stats[key]

        return None

codeflash/discovery/functions_to_optimize.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,15 @@ def get_functions_to_optimize(
163163
"Only one of optimize_all, replay_test, or file should be provided"
164164
)
165165
functions: dict[str, list[FunctionToOptimize]]
166+
trace_file_path: Path | None = None
166167
with warnings.catch_warnings():
167168
warnings.simplefilter(action="ignore", category=SyntaxWarning)
168169
if optimize_all:
169170
logger.info("Finding all functions in the module '%s'…", optimize_all)
170171
console.rule()
171172
functions = get_all_files_and_functions(Path(optimize_all))
172173
elif replay_test:
173-
functions = get_all_replay_test_functions(
174+
functions, trace_file_path = get_all_replay_test_functions(
174175
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
175176
)
176177
elif file is not None:
@@ -206,6 +207,28 @@ def get_functions_to_optimize(
206207
filtered_modified_functions, functions_count = filter_functions(
207208
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
208209
)
210+
211+
if trace_file_path and trace_file_path.exists():
212+
from codeflash.benchmarking.function_ranker import FunctionRanker
213+
214+
ranker = FunctionRanker(trace_file_path)
215+
216+
all_functions = []
217+
for file_functions in filtered_modified_functions.values():
218+
all_functions.extend(file_functions)
219+
220+
if all_functions:
221+
ranked_functions = ranker.rank_functions(all_functions)
222+
223+
ranked_dict = {}
224+
for func in ranked_functions:
225+
if func.file_path not in ranked_dict:
226+
ranked_dict[func.file_path] = []
227+
ranked_dict[func.file_path].append(func)
228+
229+
filtered_modified_functions = ranked_dict
230+
logger.info(f"Ranked {len(all_functions)} functions by optimization priority using trace data")
231+
209232
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
210233
if optimize_all:
211234
three_min_in_ns = int(1.8e11)
@@ -272,7 +295,34 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
272295

273296
def get_all_replay_test_functions(
274297
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
275-
) -> dict[Path, list[FunctionToOptimize]]:
298+
) -> tuple[dict[Path, list[FunctionToOptimize]], Path]:
299+
trace_file_path: Path | None = None
300+
for replay_test_file in replay_test:
301+
try:
302+
with replay_test_file.open("r", encoding="utf8") as f:
303+
tree = ast.parse(f.read())
304+
for node in ast.walk(tree):
305+
if isinstance(node, ast.Assign):
306+
for target in node.targets:
307+
if (
308+
isinstance(target, ast.Name)
309+
and target.id == "trace_file_path"
310+
and isinstance(node.value, ast.Constant)
311+
and isinstance(node.value.value, str)
312+
):
313+
trace_file_path = Path(node.value.value)
314+
break
315+
if trace_file_path:
316+
break
317+
if trace_file_path:
318+
break
319+
except Exception as e:
320+
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
321+
322+
if not trace_file_path:
323+
logger.error("Could not find trace_file_path in replay test files.")
324+
exit_with_message("Could not find trace_file_path in replay test files.")
325+
276326
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test)
277327
# Get the absolute file paths for each function, excluding class name if present
278328
filtered_valid_functions = defaultdict(list)
@@ -317,7 +367,7 @@ def get_all_replay_test_functions(
317367
if filtered_list:
318368
filtered_valid_functions[file_path] = filtered_list
319369

320-
return filtered_valid_functions
370+
return filtered_valid_functions, trace_file_path
321371

322372

323373
def is_git_repo(file_path: str) -> bool:

0 commit comments

Comments
 (0)