Skip to content

Commit a2e78e1

Browse files
Merge pull request #275 from codeflash-ai/dont-optimize-repeatedly-gh-actions
Don't repeatedly optimize gh actions
2 parents 702bee2 + 226acd7 commit a2e78e1

File tree

12 files changed

+964
-47
lines changed

12 files changed

+964
-47
lines changed

codeflash/api/aiservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def optimize_python_code( # noqa: D417
118118

119119
if response.status_code == 200:
120120
optimizations_json = response.json()["optimizations"]
121-
logger.info(f"Generated {len(optimizations_json)} candidates.")
121+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
122122
console.rule()
123123
end_time = time.perf_counter()
124124
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
@@ -189,7 +189,7 @@ def optimize_python_code_line_profiler( # noqa: D417
189189

190190
if response.status_code == 200:
191191
optimizations_json = response.json()["optimizations"]
192-
logger.info(f"Generated {len(optimizations_json)} candidates.")
192+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
193193
console.rule()
194194
return [
195195
OptimizedCandidate(

codeflash/api/cfapi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from typing import TYPE_CHECKING, Any, Optional
99

10+
import git
1011
import requests
1112
import sentry_sdk
1213
from pydantic.json import pydantic_encoder
@@ -191,3 +192,35 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
191192
return {}
192193

193194
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
195+
196+
197+
def is_function_being_optimized_again(
198+
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
199+
) -> Any: # noqa: ANN401
200+
"""Check if the function being optimized is being optimized again."""
201+
response = make_cfapi_request(
202+
"/is-already-optimized",
203+
"POST",
204+
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_contexts": code_contexts},
205+
)
206+
response.raise_for_status()
207+
return response.json()
208+
209+
210+
def add_code_context_hash(code_context_hash: str) -> None:
211+
"""Add code context to the DB cache."""
212+
pr_number = get_pr_number()
213+
if pr_number is None:
214+
return
215+
try:
216+
owner, repo = get_repo_owner_and_name()
217+
pr_number = get_pr_number()
218+
except git.exc.InvalidGitRepositoryError:
219+
return
220+
221+
if owner and repo and pr_number is not None:
222+
make_cfapi_request(
223+
"/add-code-hash",
224+
"POST",
225+
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
226+
)

codeflash/cli_cmds/console.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,34 @@ def code_print(code_str: str) -> None:
6666

6767

6868
@contextmanager
69-
def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]:
70-
"""Display a progress bar with a spinner and elapsed time."""
71-
progress = Progress(
72-
SpinnerColumn(next(spinners)),
73-
*Progress.get_default_columns(),
74-
TimeElapsedColumn(),
75-
console=console,
76-
transient=transient,
77-
)
78-
task = progress.add_task(message, total=None)
79-
with progress:
80-
yield task
69+
def progress_bar(
70+
message: str, *, transient: bool = False, revert_to_print: bool = False
71+
) -> Generator[TaskID, None, None]:
72+
"""Display a progress bar with a spinner and elapsed time.
73+
74+
If revert_to_print is True, falls back to printing a single logger.info message
75+
instead of showing a progress bar.
76+
"""
77+
if revert_to_print:
78+
logger.info(message)
79+
80+
# Create a fake task ID since we still need to yield something
81+
class DummyTask:
82+
def __init__(self) -> None:
83+
self.id = 0
84+
85+
yield DummyTask().id
86+
else:
87+
progress = Progress(
88+
SpinnerColumn(next(spinners)),
89+
*Progress.get_default_columns(),
90+
TimeElapsedColumn(),
91+
console=console,
92+
transient=transient,
93+
)
94+
task = progress.add_task(message, total=None)
95+
with progress:
96+
yield task
8197

8298

8399
@contextmanager

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
1010
COVERAGE_THRESHOLD = 60.0
1111
MIN_TESTCASE_PASSED_THRESHOLD = 6
12+
REPEAT_OPTIMIZATION_PROBABILITY = 0.1

codeflash/code_utils/git_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import tempfile
77
import time
8+
from functools import cache
89
from io import StringIO
910
from pathlib import Path
1011
from typing import TYPE_CHECKING
@@ -79,6 +80,7 @@ def get_git_remotes(repo: Repo) -> list[str]:
7980
return [remote.name for remote in repository.remotes]
8081

8182

83+
@cache
8284
def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "origin") -> tuple[str, str]:
8385
remote_url = get_remote_url(repo, git_remote) # call only once
8486
remote_url = remote_url.removesuffix(".git") if remote_url.endswith(".git") else remote_url

codeflash/context/code_context_extractor.py

Lines changed: 111 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
import hashlib
34
import os
45
from collections import defaultdict
56
from itertools import chain
6-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, cast
78

89
import libcst as cst
910

@@ -31,8 +32,8 @@
3132
def get_code_optimization_context(
3233
function_to_optimize: FunctionToOptimize,
3334
project_root_path: Path,
34-
optim_token_limit: int = 8000,
35-
testgen_token_limit: int = 8000,
35+
optim_token_limit: int = 16000,
36+
testgen_token_limit: int = 16000,
3637
) -> CodeOptimizationContext:
3738
# Get FunctionSource representation of helpers of FTO
3839
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
@@ -73,6 +74,13 @@ def get_code_optimization_context(
7374
remove_docstrings=False,
7475
code_context_type=CodeContextType.READ_ONLY,
7576
)
77+
hashing_code_context = extract_code_markdown_context_from_files(
78+
helpers_of_fto_dict,
79+
helpers_of_helpers_dict,
80+
project_root_path,
81+
remove_docstrings=True,
82+
code_context_type=CodeContextType.HASHING,
83+
)
7684

7785
# Handle token limits
7886
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
@@ -125,11 +133,15 @@ def get_code_optimization_context(
125133
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
126134
if testgen_context_code_tokens > testgen_token_limit:
127135
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
136+
code_hash_context = hashing_code_context.markdown
137+
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
128138

129139
return CodeOptimizationContext(
130140
testgen_context_code=testgen_context_code,
131141
read_writable_code=final_read_writable_code,
132142
read_only_context_code=read_only_context_code,
143+
hashing_code_context=code_hash_context,
144+
hashing_code_context_hash=code_hash,
133145
helper_functions=helpers_of_fto_list,
134146
preexisting_objects=preexisting_objects,
135147
)
@@ -309,8 +321,8 @@ def extract_code_markdown_context_from_files(
309321
logger.debug(f"Error while getting read-only code: {e}")
310322
continue
311323
if code_context.strip():
312-
code_context_with_imports = CodeString(
313-
code=add_needed_imports_from_module(
324+
if code_context_type != CodeContextType.HASHING:
325+
code_context = add_needed_imports_from_module(
314326
src_module_code=original_code,
315327
dst_module_code=code_context,
316328
src_path=file_path,
@@ -319,10 +331,9 @@ def extract_code_markdown_context_from_files(
319331
helper_functions=list(
320332
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
321333
),
322-
),
323-
file_path=file_path.relative_to(project_root_path),
324-
)
325-
code_context_markdown.code_strings.append(code_context_with_imports)
334+
)
335+
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
336+
code_context_markdown.code_strings.append(code_string_context)
326337
# Extract code from file paths containing helpers of helpers
327338
for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
328339
try:
@@ -343,18 +354,17 @@ def extract_code_markdown_context_from_files(
343354
continue
344355

345356
if code_context.strip():
346-
code_context_with_imports = CodeString(
347-
code=add_needed_imports_from_module(
357+
if code_context_type != CodeContextType.HASHING:
358+
code_context = add_needed_imports_from_module(
348359
src_module_code=original_code,
349360
dst_module_code=code_context,
350361
src_path=file_path,
351362
dst_path=file_path,
352363
project_root=project_root_path,
353364
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
354-
),
355-
file_path=file_path.relative_to(project_root_path),
356-
)
357-
code_context_markdown.code_strings.append(code_context_with_imports)
365+
)
366+
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
367+
code_context_markdown.code_strings.append(code_string_context)
358368
return code_context_markdown
359369

360370

@@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
492502
filtered_node, found_target = prune_cst_for_testgen_code(
493503
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
494504
)
505+
elif code_context_type == CodeContextType.HASHING:
506+
filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)
495507
else:
496508
raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102
497509

@@ -583,6 +595,90 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583595
return (node.with_changes(**updates) if updates else node), True
584596

585597

598+
def prune_cst_for_code_hashing( # noqa: PLR0911
599+
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
600+
) -> tuple[cst.CSTNode | None, bool]:
601+
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
602+
603+
Returns
604+
-------
605+
(filtered_node, found_target):
606+
filtered_node: The modified CST node or None if it should be removed.
607+
found_target: True if a target function was found in this node's subtree.
608+
609+
"""
610+
if isinstance(node, (cst.Import, cst.ImportFrom)):
611+
return None, False
612+
613+
if isinstance(node, cst.FunctionDef):
614+
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
615+
if qualified_name in target_functions:
616+
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
617+
return node.with_changes(body=new_body), True
618+
return None, False
619+
620+
if isinstance(node, cst.ClassDef):
621+
# Do not recurse into nested classes
622+
if prefix:
623+
return None, False
624+
# Assuming always an IndentedBlock
625+
if not isinstance(node.body, cst.IndentedBlock):
626+
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
627+
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
628+
new_class_body: list[cst.CSTNode] = []
629+
found_target = False
630+
631+
for stmt in node.body.body:
632+
if isinstance(stmt, cst.FunctionDef):
633+
qualified_name = f"{class_prefix}.{stmt.name.value}"
634+
if qualified_name in target_functions:
635+
stmt_with_changes = stmt.with_changes(
636+
body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
637+
)
638+
new_class_body.append(stmt_with_changes)
639+
found_target = True
640+
# If no target functions found, remove the class entirely
641+
if not new_class_body or not found_target:
642+
return None, False
643+
return node.with_changes(
644+
body=cst.IndentedBlock(cast("list[cst.BaseStatement]", new_class_body))
645+
) if new_class_body else None, found_target
646+
647+
# For other nodes, we preserve them only if they contain target functions in their children.
648+
section_names = get_section_names(node)
649+
if not section_names:
650+
return node, False
651+
652+
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
653+
found_any_target = False
654+
655+
for section in section_names:
656+
original_content = getattr(node, section, None)
657+
if isinstance(original_content, (list, tuple)):
658+
new_children = []
659+
section_found_target = False
660+
for child in original_content:
661+
filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix)
662+
if filtered:
663+
new_children.append(filtered)
664+
section_found_target |= found_target
665+
666+
if section_found_target:
667+
found_any_target = True
668+
updates[section] = new_children
669+
elif original_content is not None:
670+
filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix)
671+
if found_target:
672+
found_any_target = True
673+
if filtered:
674+
updates[section] = filtered
675+
676+
if not found_any_target:
677+
return None, False
678+
679+
return (node.with_changes(**updates) if updates else node), True
680+
681+
586682
def prune_cst_for_read_only_code( # noqa: PLR0911
587683
node: cst.CSTNode,
588684
target_functions: set[str],

0 commit comments

Comments
 (0)