Skip to content

Commit dae4afc

Browse files
committed
Merge branch 'ux-changes' into test-filter-cleanup
2 parents 0fcd59c + a928fc0 commit dae4afc

File tree

16 files changed

+969
-66
lines changed

16 files changed

+969
-66
lines changed

codeflash/LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Business Source License 1.1
33
Parameters
44

55
Licensor: CodeFlash Inc.
6-
Licensed Work: Codeflash Client version 0.13.x
6+
Licensed Work: Codeflash Client version 0.14.x
77
The Licensed Work is (c) 2024 CodeFlash Inc.
88

99
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
1313
Platform. Please visit codeflash.ai for further
1414
information.
1515

16-
Change Date: 2029-06-03
16+
Change Date: 2029-06-09
1717

1818
Change License: MIT
1919

codeflash/api/aiservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def optimize_python_code( # noqa: D417
118118

119119
if response.status_code == 200:
120120
optimizations_json = response.json()["optimizations"]
121-
logger.info(f"Generated {len(optimizations_json)} candidates.")
121+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
122122
console.rule()
123123
end_time = time.perf_counter()
124124
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
@@ -189,7 +189,7 @@ def optimize_python_code_line_profiler( # noqa: D417
189189

190190
if response.status_code == 200:
191191
optimizations_json = response.json()["optimizations"]
192-
logger.info(f"Generated {len(optimizations_json)} candidates.")
192+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
193193
console.rule()
194194
return [
195195
OptimizedCandidate(

codeflash/api/cfapi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from typing import TYPE_CHECKING, Any, Optional
99

10+
import git
1011
import requests
1112
import sentry_sdk
1213
from pydantic.json import pydantic_encoder
@@ -191,3 +192,35 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
191192
return {}
192193

193194
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
195+
196+
197+
def is_function_being_optimized_again(
198+
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
199+
) -> Any: # noqa: ANN401
200+
"""Check if the function being optimized is being optimized again."""
201+
response = make_cfapi_request(
202+
"/is-already-optimized",
203+
"POST",
204+
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_contexts": code_contexts},
205+
)
206+
response.raise_for_status()
207+
return response.json()
208+
209+
210+
def add_code_context_hash(code_context_hash: str) -> None:
211+
"""Add code context to the DB cache."""
212+
pr_number = get_pr_number()
213+
if pr_number is None:
214+
return
215+
try:
216+
owner, repo = get_repo_owner_and_name()
217+
pr_number = get_pr_number()
218+
except git.exc.InvalidGitRepositoryError:
219+
return
220+
221+
if owner and repo and pr_number is not None:
222+
make_cfapi_request(
223+
"/add-code-hash",
224+
"POST",
225+
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
226+
)

codeflash/cli_cmds/console.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,34 @@ def code_print(code_str: str) -> None:
6666

6767

6868
@contextmanager
69-
def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]:
70-
"""Display a progress bar with a spinner and elapsed time."""
71-
progress = Progress(
72-
SpinnerColumn(next(spinners)),
73-
*Progress.get_default_columns(),
74-
TimeElapsedColumn(),
75-
console=console,
76-
transient=transient,
77-
)
78-
task = progress.add_task(message, total=None)
79-
with progress:
80-
yield task
69+
def progress_bar(
70+
message: str, *, transient: bool = False, revert_to_print: bool = False
71+
) -> Generator[TaskID, None, None]:
72+
"""Display a progress bar with a spinner and elapsed time.
73+
74+
If revert_to_print is True, falls back to printing a single logger.info message
75+
instead of showing a progress bar.
76+
"""
77+
if revert_to_print:
78+
logger.info(message)
79+
80+
# Create a fake task ID since we still need to yield something
81+
class DummyTask:
82+
def __init__(self) -> None:
83+
self.id = 0
84+
85+
yield DummyTask().id
86+
else:
87+
progress = Progress(
88+
SpinnerColumn(next(spinners)),
89+
*Progress.get_default_columns(),
90+
TimeElapsedColumn(),
91+
console=console,
92+
transient=transient,
93+
)
94+
task = progress.add_task(message, total=None)
95+
with progress:
96+
yield task
8197

8298

8399
@contextmanager

codeflash/code_utils/code_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import re
66
import shutil
77
import site
8+
import sys
89
from contextlib import contextmanager
910
from functools import lru_cache
1011
from pathlib import Path
1112
from tempfile import TemporaryDirectory
1213

1314
import tomlkit
1415

15-
from codeflash.cli_cmds.console import logger
16+
from codeflash.cli_cmds.console import logger, paneled_text
1617
from codeflash.code_utils.config_parser import find_pyproject_toml
1718

1819
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
@@ -213,3 +214,9 @@ def cleanup_paths(paths: list[Path]) -> None:
213214
def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
214215
for path, file_content in path_to_content_map.items():
215216
path.write_text(file_content, encoding="utf8")
217+
218+
219+
def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
220+
paneled_text(message, panel_args={"style": "red"})
221+
222+
sys.exit(1 if error_on_exit else 0)

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
1010
COVERAGE_THRESHOLD = 60.0
1111
MIN_TESTCASE_PASSED_THRESHOLD = 6
12+
REPEAT_OPTIMIZATION_PROBABILITY = 0.1

codeflash/code_utils/env_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
import os
4-
import sys
54
import tempfile
65
from functools import lru_cache
76
from pathlib import Path
87
from typing import Optional
98

109
from codeflash.cli_cmds.console import logger
10+
from codeflash.code_utils.code_utils import exit_with_message
1111
from codeflash.code_utils.formatter import format_code
1212
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
1313

@@ -24,11 +24,11 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
2424
try:
2525
format_code(formatter_cmds, tmp_file, print_status=False)
2626
except Exception:
27-
print(
28-
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again."
27+
exit_with_message(
28+
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
29+
error_on_exit=True,
2930
)
30-
if exit_on_failure:
31-
sys.exit(1)
31+
3232
return return_code
3333

3434

codeflash/code_utils/git_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import tempfile
77
import time
8+
from functools import cache
89
from io import StringIO
910
from pathlib import Path
1011
from typing import TYPE_CHECKING
@@ -79,6 +80,7 @@ def get_git_remotes(repo: Repo) -> list[str]:
7980
return [remote.name for remote in repository.remotes]
8081

8182

83+
@cache
8284
def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "origin") -> tuple[str, str]:
8385
remote_url = get_remote_url(repo, git_remote) # call only once
8486
remote_url = remote_url.removesuffix(".git") if remote_url.endswith(".git") else remote_url

0 commit comments

Comments
 (0)