Skip to content

Commit fc4f2de

Browse files
authored
Merge branch 'main' into tracer-optimization
2 parents edfc240 + 10e8a13 commit fc4f2de

File tree

3 files changed

+711
-18
lines changed

3 files changed

+711
-18
lines changed

codeflash/code_utils/formatter.py

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,78 @@
11
from __future__ import annotations
22

3+
import difflib
34
import os
5+
import re
46
import shlex
7+
import shutil
58
import subprocess
6-
from typing import TYPE_CHECKING
9+
import tempfile
10+
from pathlib import Path
11+
from typing import Optional, Union
712

813
import isort
914

1015
from codeflash.cli_cmds.console import console, logger
1116

12-
if TYPE_CHECKING:
13-
from pathlib import Path
1417

18+
def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
    """Return a unified diff (5 lines of context) between two source strings.

    Both inputs are split into lines that keep their own terminator
    (``\\r\\n``, ``\\n`` or ``\\r``), so files with different newline styles
    are compared faithfully.

    Args:
        original: Text before the change.
        modified: Text after the change.
        from_file: Label used in the ``---`` header line.
        to_file: Label used in the ``+++`` header line.

    Returns:
        The diff as a single string in which every line ends with ``\\n``.
        An empty string is returned when the inputs are identical. A line
        that had no terminator at all (the last line of a file without a
        trailing newline) is followed by ``\\ No newline at end of file``.
    """
    line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")

    def split_lines(text: str) -> list[str]:
        lines = [match[0] for match in line_pattern.finditer(text)]
        # finditer always yields one empty match at the very end of the text.
        if lines and lines[-1] == "":
            lines.pop()
        return lines

    diff_output = []
    for line in difflib.unified_diff(
        split_lines(original), split_lines(modified), fromfile=from_file, tofile=to_file, n=5
    ):
        if line.endswith("\n"):
            diff_output.append(line)
        elif line.endswith("\r"):
            # A bare carriage return IS a line terminator (split_lines treats
            # it as one), so only normalise the output line ending; the line
            # must not be reported as lacking a newline.
            diff_output.append(line + "\n")
        else:
            diff_output.append(line + "\n")
            diff_output.append("\\ No newline at end of file\n")

    return "".join(diff_output)
39+
40+
41+
def apply_formatter_cmds(
42+
cmds: list[str],
43+
path: Path,
44+
test_dir_str: Optional[str],
45+
print_status: bool, # noqa
46+
) -> tuple[Path, str]:
1747
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
18-
formatter_name = formatter_cmds[0].lower()
48+
formatter_name = cmds[0].lower()
49+
should_make_copy = False
50+
file_path = path
51+
52+
if test_dir_str:
53+
should_make_copy = True
54+
file_path = Path(test_dir_str) / "temp.py"
55+
56+
if not cmds or formatter_name == "disabled":
57+
return path, path.read_text(encoding="utf8")
58+
1959
if not path.exists():
20-
msg = f"File {path} does not exist. Cannot format the file."
60+
msg = f"File {path} does not exist. Cannot apply formatter commands."
2161
raise FileNotFoundError(msg)
22-
if formatter_name == "disabled":
23-
return path.read_text(encoding="utf8")
62+
63+
if should_make_copy:
64+
shutil.copy2(path, file_path)
65+
2466
file_token = "$file" # noqa: S105
25-
for command in formatter_cmds:
67+
68+
for command in cmds:
2669
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
27-
formatter_cmd_list = [path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
70+
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
2871
try:
2972
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
3073
if result.returncode == 0:
3174
if print_status:
32-
console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}")
75+
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
3376
else:
3477
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
3578
except FileNotFoundError as e:
@@ -44,7 +87,60 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
4487

4588
raise e from None
4689

47-
return path.read_text(encoding="utf8")
90+
return file_path, file_path.read_text(encoding="utf8")
91+
92+
93+
def get_diff_lines_count(diff_output: str) -> int:
    """Count the added and removed lines in a unified diff.

    Hunk extents are taken from the ``@@ -a,b +c,d @@`` headers, so a changed
    line whose own content starts with ``--`` or ``++`` (and therefore shows
    up as ``---...`` / ``+++...`` in the diff) is still counted instead of
    being mistaken for a file header.

    Outside of any hunk (including input without hunk headers at all),
    lines starting with ``+`` or ``-`` are counted unless they start with
    ``+++`` or ``---``.

    Args:
        diff_output: Unified diff text with ``\\n``-separated lines.

    Returns:
        The number of ``+`` and ``-`` change lines.
    """
    hunk_header = re.compile(r"@@ -\d+(?:,(\d+))? \+\d+(?:,(\d+))? @@")

    changed = 0
    old_remaining = 0  # lines of the "from" side still expected in this hunk
    new_remaining = 0  # lines of the "to" side still expected in this hunk

    for line in diff_output.split("\n"):
        # Hunk content always starts with "+", "-", " " or "\", so a line
        # matching the header pattern can only be a (new) hunk header.
        header = hunk_header.match(line)
        if header is not None:
            # An omitted length in the header means exactly one line.
            old_remaining = 1 if header[1] is None else int(header[1])
            new_remaining = 1 if header[2] is None else int(header[2])
            continue

        if old_remaining > 0 or new_remaining > 0:
            if line.startswith("\\"):
                # "\ No newline at end of file" belongs to neither side.
                continue
            if line.startswith("-"):
                old_remaining -= 1
                changed += 1
            elif line.startswith("+"):
                new_remaining -= 1
                changed += 1
            else:
                # Context line: present on both sides.
                old_remaining -= 1
                new_remaining -= 1
            continue

        if line.startswith(("+", "-")) and not line.startswith(("+++", "---")):
            changed += 1

    return changed
101+
102+
103+
def format_code(
    formatter_cmds: list[str],
    path: Union[str, Path],
    optimized_function: str = "",
    check_diff: bool = False,  # noqa
    print_status: bool = True,  # noqa
) -> str:
    """Run the formatter commands on ``path`` and return the resulting code.

    When ``check_diff`` is set and the file has more than 50 lines, the
    formatter is first tried on a temporary copy of the file with
    ``optimized_function`` removed. If that trial run would change more than
    ``min(30% of the file's line count, 50)`` lines, the file is left
    untouched and its original content is returned.

    Args:
        formatter_cmds: Formatter command lines, passed to
            ``apply_formatter_cmds``.
        path: File to format; a ``str`` is converted to ``Path``.
        optimized_function: Source text excluded from the trial diff.
        check_diff: Enable the "too many formatting changes" guard.
        print_status: Forwarded to ``apply_formatter_cmds`` for the real run.

    Returns:
        The formatted file content, or the original content when the guard
        decides to skip formatting.
    """
    if isinstance(path, str):
        path = Path(path)

    original_code = path.read_text(encoding="utf8")
    original_code_lines = len(original_code.split("\n"))

    if check_diff and original_code_lines > 50:
        # We don't count the formatting diff for the optimized function as it should be well-formatted
        original_code_without_opfunc = original_code.replace(optimized_function, "")

        # The scratch directory is only needed for the trial run, so it is
        # created here rather than on every call.
        with tempfile.TemporaryDirectory() as test_dir_str:
            original_temp = Path(test_dir_str) / "original_temp.py"
            original_temp.write_text(original_code_without_opfunc, encoding="utf8")

            formatted_temp, formatted_code = apply_formatter_cmds(
                formatter_cmds, original_temp, test_dir_str, print_status=False
            )

            diff_output = generate_unified_diff(
                original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
            )

        diff_lines_count = get_diff_lines_count(diff_output)
        max_diff_lines = min(int(original_code_lines * 0.3), 50)

        if diff_lines_count > max_diff_lines:
            logger.debug(
                f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
            )
            return original_code

    # TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
    _, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status)
    logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
    return formatted_code
48144

49145

50146
def sort_imports(code: str) -> str:

codeflash/optimization/function_optimizer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
335335
)
336336

337337
new_code, new_helper_code = self.reformat_code_and_helpers(
338-
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
338+
code_context.helper_functions,
339+
explanation.file_path,
340+
self.function_to_optimize_source_code,
341+
optimized_function=best_optimization.candidate.source_code,
339342
)
340343

341344
existing_tests = existing_tests_source_for(
@@ -642,20 +645,23 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
642645
f.write(helper_code)
643646

644647
def reformat_code_and_helpers(
645-
self, helper_functions: list[FunctionSource], path: Path, original_code: str
648+
self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_function: str
646649
) -> tuple[str, dict[Path, str]]:
647650
should_sort_imports = not self.args.disable_imports_sorting
648651
if should_sort_imports and isort.code(original_code) != original_code:
649652
should_sort_imports = False
650653

651-
new_code = format_code(self.args.formatter_cmds, path)
654+
new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function, check_diff=True)
652655
if should_sort_imports:
653656
new_code = sort_imports(new_code)
654657

655658
new_helper_code: dict[Path, str] = {}
656-
helper_functions_paths = {hf.file_path for hf in helper_functions}
657-
for module_abspath in helper_functions_paths:
658-
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
659+
for hp in helper_functions:
660+
module_abspath = hp.file_path
661+
hp_source_code = hp.source_code
662+
formatted_helper_code = format_code(
663+
self.args.formatter_cmds, module_abspath, optimized_function=hp_source_code, check_diff=True
664+
)
659665
if should_sort_imports:
660666
formatted_helper_code = sort_imports(formatted_helper_code)
661667
new_helper_code[module_abspath] = formatted_helper_code

0 commit comments

Comments
 (0)