Skip to content

Commit 0b4fcb6

Browse files
committed
rank functions
1 parent 09bf156 commit 0b4fcb6

File tree

3 files changed

+205
-5
lines changed

3 files changed

+205
-5
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from codeflash.cli_cmds.console import logger
6+
from codeflash.tracing.profile_stats import ProfileStats
7+
8+
if TYPE_CHECKING:
9+
from pathlib import Path
10+
11+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
12+
13+
14+
class FunctionRanker:
    """Rank functions for optimization based on trace data using ttX scoring.

    ttX = own_time + (time_spent_in_callees x call_count)

    This prioritizes functions that:
    1. Take significant time themselves (own_time)
    2. Are called frequently and have expensive subcalls (time_spent_in_callees x call_count)

    Stats are loaded lazily from the trace file on first use and cached on the
    instance. Entries are keyed by ``"<filename>:<qualified_name>"``; the line
    number is not part of the key, so two profile entries sharing a file and
    qualified name overwrite each other (the last one iterated wins).
    """

    def __init__(self, trace_file_path: Path) -> None:
        """Remember the trace file location; nothing is read until stats are requested."""
        self.trace_file_path = trace_file_path
        # None means "not loaded yet"; an empty dict means "loaded, nothing usable".
        self._function_stats: dict[str, dict] | None = None

    @staticmethod
    def _split_function_name(function_name: str) -> tuple[str | None, str]:
        """Split a profiled name into ``(class_name, base_function_name)``.

        Names containing a dot are split at the first dot into class and method
        parts. Names without a dot, or names starting with ``<`` (for example
        ``<listcomp>``), are returned whole with ``class_name`` set to None.
        """
        if "." in function_name and not function_name.startswith("<"):
            class_name, _, method_name = function_name.partition(".")
            return class_name, method_name
        return None, function_name

    def load_function_stats(self) -> dict[str, dict]:
        """Load function timing statistics from the trace database using ProfileStats.

        Returns:
            Mapping of ``"<filename>:<qualified_name>"`` to a dict with the keys
            ``filename``, ``function_name``, ``qualified_name``, ``class_name``,
            ``line_number``, ``call_count``, ``own_time_ns``, ``cumulative_time_ns``,
            ``time_in_callees_ns`` and ``ttx_score``. The result is cached, so the
            trace is read at most once per instance. If loading fails for any
            reason a warning is logged and an empty mapping is returned (and cached).

        """
        if self._function_stats is not None:
            return self._function_stats

        loaded: dict[str, dict] = {}
        try:
            profile_stats = ProfileStats(self.trace_file_path.as_posix())

            # Access the stats dictionary directly from ProfileStats
            for (filename, line_number, function_name), (
                call_count,
                _num_callers,
                total_time_ns,
                cumulative_time_ns,
                _callers,
            ) in profile_stats.stats.items():
                if call_count <= 0:
                    continue

                class_name, base_function_name = self._split_function_name(function_name)

                # total_time_ns is used as the function's own time as-is; the
                # callee share is whatever cumulative time is left on top of it.
                # NOTE(review): assumes ProfileStats reports total time exclusive
                # of subcalls (pstats "tottime" semantics) - confirm.
                own_time_ns = total_time_ns
                time_in_callees_ns = cumulative_time_ns - total_time_ns

                # Calculate ttX score
                ttx_score = own_time_ns + (time_in_callees_ns * call_count)

                loaded[f"{filename}:{function_name}"] = {
                    "filename": filename,
                    "function_name": base_function_name,
                    "qualified_name": function_name,
                    "class_name": class_name,
                    "line_number": line_number,
                    "call_count": call_count,
                    "own_time_ns": own_time_ns,
                    "cumulative_time_ns": cumulative_time_ns,
                    "time_in_callees_ns": time_in_callees_ns,
                    "ttx_score": ttx_score,
                }

            logger.debug(f"Loaded timing stats for {len(loaded)} functions from trace using ProfileStats")

        except Exception as e:
            # Deliberate best-effort: ranking is optional, so a broken or unreadable
            # trace degrades to "no stats" instead of aborting the caller.
            logger.warning(f"Failed to load function stats from trace file {self.trace_file_path}: {e}")
            loaded = {}

        self._function_stats = loaded
        return loaded

    def _find_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        """Return the stats entry for a function, or None when the trace has none.

        The qualified name is tried first, then the bare function name, both
        prefixed with the function's ``file_path``.
        """
        stats = self.load_function_stats()
        candidate_names = (function_to_optimize.qualified_name, function_to_optimize.function_name)
        for name in candidate_names:
            key = f"{function_to_optimize.file_path}:{name}"
            if key in stats:
                return stats[key]
        return None

    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
        """Return the ttX score recorded for the function, or 0.0 if it is not in the trace."""
        entry = self._find_stats(function_to_optimize)
        if entry is None:
            return 0.0
        return entry["ttx_score"]

    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        """Return a new list of the functions ordered by ttX score, highest first.

        The sort is stable, so functions with equal scores (including all
        functions missing from the trace, which score 0.0) keep their input order.
        The input list is not modified.
        """
        ranked_functions = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
        logger.info(f"Ranked {len(ranked_functions)} functions by optimization priority")
        return ranked_functions

    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        """Return the full stats dict for the function, or None if it is not in the trace."""
        return self._find_stats(function_to_optimize)

codeflash/discovery/functions_to_optimize.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,15 @@ def get_functions_to_optimize(
163163
"Only one of optimize_all, replay_test, or file should be provided"
164164
)
165165
functions: dict[str, list[FunctionToOptimize]]
166+
trace_file_path: Path | None = None
166167
with warnings.catch_warnings():
167168
warnings.simplefilter(action="ignore", category=SyntaxWarning)
168169
if optimize_all:
169170
logger.info("Finding all functions in the module '%s'…", optimize_all)
170171
console.rule()
171172
functions = get_all_files_and_functions(Path(optimize_all))
172173
elif replay_test:
173-
functions = get_all_replay_test_functions(
174+
functions, trace_file_path = get_all_replay_test_functions(
174175
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
175176
)
176177
elif file is not None:
@@ -206,6 +207,28 @@ def get_functions_to_optimize(
206207
filtered_modified_functions, functions_count = filter_functions(
207208
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
208209
)
210+
211+
if trace_file_path and trace_file_path.exists():
212+
from codeflash.benchmarking.function_ranker import FunctionRanker
213+
214+
ranker = FunctionRanker(trace_file_path)
215+
216+
all_functions = []
217+
for file_functions in filtered_modified_functions.values():
218+
all_functions.extend(file_functions)
219+
220+
if all_functions:
221+
ranked_functions = ranker.rank_functions(all_functions)
222+
223+
ranked_dict = {}
224+
for func in ranked_functions:
225+
if func.file_path not in ranked_dict:
226+
ranked_dict[func.file_path] = []
227+
ranked_dict[func.file_path].append(func)
228+
229+
filtered_modified_functions = ranked_dict
230+
logger.info(f"Ranked {len(all_functions)} functions by optimization priority using trace data")
231+
209232
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
210233
if optimize_all:
211234
three_min_in_ns = int(1.8e11)
@@ -272,7 +295,34 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
272295

273296
def get_all_replay_test_functions(
274297
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
275-
) -> dict[Path, list[FunctionToOptimize]]:
298+
) -> tuple[dict[Path, list[FunctionToOptimize]], Path]:
299+
trace_file_path: Path | None = None
300+
for replay_test_file in replay_test:
301+
try:
302+
with replay_test_file.open("r", encoding="utf8") as f:
303+
tree = ast.parse(f.read())
304+
for node in ast.walk(tree):
305+
if isinstance(node, ast.Assign):
306+
for target in node.targets:
307+
if (
308+
isinstance(target, ast.Name)
309+
and target.id == "trace_file_path"
310+
and isinstance(node.value, ast.Constant)
311+
and isinstance(node.value.value, str)
312+
):
313+
trace_file_path = Path(node.value.value)
314+
break
315+
if trace_file_path:
316+
break
317+
if trace_file_path:
318+
break
319+
except Exception as e:
320+
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
321+
322+
if not trace_file_path:
323+
logger.error("Could not find trace_file_path in replay test files.")
324+
exit_with_message("Could not find trace_file_path in replay test files.")
325+
276326
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test)
277327
# Get the absolute file paths for each function, excluding class name if present
278328
filtered_valid_functions = defaultdict(list)
@@ -317,7 +367,7 @@ def get_all_replay_test_functions(
317367
if filtered_list:
318368
filtered_valid_functions[file_path] = filtered_list
319369

320-
return filtered_valid_functions
370+
return filtered_valid_functions, trace_file_path
321371

322372

323373
def is_git_repo(file_path: str) -> bool:

codeflash/tracing/profile_stats.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,27 @@ def create_stats(self) -> None:
2727
filename,
2828
line_number,
2929
function,
30+
class_name,
3031
call_count_nonrecursive,
3132
num_callers,
3233
total_time_ns,
3334
cumulative_time_ns,
3435
callers,
3536
) in pdata:
3637
loaded_callers = json.loads(callers)
37-
unmapped_callers = {caller["key"]: caller["value"] for caller in loaded_callers}
38-
self.stats[(filename, line_number, function)] = (
38+
unmapped_callers = {}
39+
for caller in loaded_callers:
40+
caller_key = caller["key"]
41+
if isinstance(caller_key, list):
42+
caller_key = tuple(caller_key)
43+
elif not isinstance(caller_key, tuple):
44+
caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key)
45+
unmapped_callers[caller_key] = caller["value"]
46+
47+
# Create function key with class name if present (matching tracer.py format)
48+
function_name = f"{class_name}.{function}" if class_name else function
49+
50+
self.stats[(filename, line_number, function_name)] = (
3951
call_count_nonrecursive,
4052
num_callers,
4153
total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns,

0 commit comments

Comments
 (0)