1
1
from __future__ import annotations
2
2
3
+ import difflib
3
4
import os
5
+ import re
4
6
import shlex
7
+ import shutil
5
8
import subprocess
6
- from typing import TYPE_CHECKING
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import Optional , Union
7
12
8
13
import isort
9
14
10
15
from codeflash .cli_cmds .console import console , logger
11
16
12
- if TYPE_CHECKING :
13
- from pathlib import Path
14
17
18
def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
    """Render a unified diff between two pieces of text.

    Both inputs are split into lines on CRLF, LF or CR, keeping each line's
    terminator, and compared with ``difflib.unified_diff`` using five lines of
    context.

    Args:
        original: Text before the change.
        modified: Text after the change.
        from_file: Label placed on the ``---`` header line.
        to_file: Label placed on the ``+++`` header line.

    Returns:
        The diff as a single string; empty when the inputs split into
        identical line lists. Every emitted diff line that does not already
        end with a line feed gets one appended, followed by the
        "No newline at end of file" marker line.
    """
    terminated_line = re.compile(r"(.*?(?:\r\n|\n|\r|$))")

    def keep_ends_split(text: str) -> list[str]:
        pieces = [found[0] for found in terminated_line.finditer(text)]
        # The pattern can produce a trailing empty match at end of text; it is
        # not a real line, so leave it out.
        if pieces and pieces[-1] == "":
            return pieces[:-1]
        return pieces

    before = keep_ends_split(original)
    after = keep_ends_split(modified)

    rendered: list[str] = []
    for entry in difflib.unified_diff(before, after, fromfile=from_file, tofile=to_file, n=5):
        if not entry.endswith("\n"):
            # Line without a trailing line feed: terminate it and flag it.
            rendered.extend((entry + "\n", "\\ No newline at end of file\n"))
            continue
        rendered.append(entry)

    return "".join(rendered)
39
+
40
+
41
+ def apply_formatter_cmds (
42
+ cmds : list [str ],
43
+ path : Path ,
44
+ test_dir_str : Optional [str ],
45
+ print_status : bool , # noqa
46
+ ) -> tuple [Path , str ]:
17
47
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
18
- formatter_name = formatter_cmds [0 ].lower ()
48
+ formatter_name = cmds [0 ].lower ()
49
+ should_make_copy = False
50
+ file_path = path
51
+
52
+ if test_dir_str :
53
+ should_make_copy = True
54
+ file_path = Path (test_dir_str ) / "temp.py"
55
+
56
+ if not cmds or formatter_name == "disabled" :
57
+ return path , path .read_text (encoding = "utf8" )
58
+
19
59
if not path .exists ():
20
- msg = f"File { path } does not exist. Cannot format the file ."
60
+ msg = f"File { path } does not exist. Cannot apply formatter commands ."
21
61
raise FileNotFoundError (msg )
22
- if formatter_name == "disabled" :
23
- return path .read_text (encoding = "utf8" )
62
+
63
+ if should_make_copy :
64
+ shutil .copy2 (path , file_path )
65
+
24
66
file_token = "$file" # noqa: S105
25
- for command in formatter_cmds :
67
+
68
+ for command in cmds :
26
69
formatter_cmd_list = shlex .split (command , posix = os .name != "nt" )
27
- formatter_cmd_list = [path .as_posix () if chunk == file_token else chunk for chunk in formatter_cmd_list ]
70
+ formatter_cmd_list = [file_path .as_posix () if chunk == file_token else chunk for chunk in formatter_cmd_list ]
28
71
try :
29
72
result = subprocess .run (formatter_cmd_list , capture_output = True , check = False )
30
73
if result .returncode == 0 :
31
74
if print_status :
32
- console .rule (f"Formatted Successfully with: { formatter_name .replace ('$file' , path .name )} " )
75
+ console .rule (f"Formatted Successfully with: { command .replace ('$file' , path .name )} " )
33
76
else :
34
77
logger .error (f"Failed to format code with { ' ' .join (formatter_cmd_list )} " )
35
78
except FileNotFoundError as e :
@@ -44,7 +87,60 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
44
87
45
88
raise e from None
46
89
47
- return path .read_text (encoding = "utf8" )
90
+ return file_path , file_path .read_text (encoding = "utf8" )
91
+
92
+
93
def get_diff_lines_count(diff_output: str) -> int:
    """Count the added and removed lines in a unified diff.

    The text is split on line feeds. A line counts when it starts with ``+``
    or ``-``; lines starting with ``+++`` or ``---`` are treated as file
    headers and ignored.

    Args:
        diff_output: Unified diff text.

    Returns:
        Number of lines that count as additions or removals.
    """
    changed = 0
    for row in diff_output.split("\n"):
        # Check the three-character header prefixes first so they are never
        # mistaken for single-character change markers.
        if row[:3] in ("+++", "---"):
            continue
        if row[:1] in ("+", "-"):
            changed += 1
    return changed
101
+
102
+
103
def format_code(
    formatter_cmds: list[str],
    path: Union[str, Path],
    optimized_function: str = "",
    check_diff: bool = False,  # noqa
    print_status: bool = True,  # noqa
) -> str:
    """Format the file at ``path`` with the configured formatter commands.

    When ``check_diff`` is set and the file has more than 50 lines, the
    formatters are first run on a scratch copy (with ``optimized_function``
    removed from the text) to measure how many lines they would change. If
    that exceeds ``min(30% of the file's line count, 50)``, formatting is
    skipped and the original text is returned unchanged.

    Args:
        formatter_cmds: Formatter command lines, passed to ``apply_formatter_cmds``.
        path: File to format; a ``str`` is converted to ``Path``.
        optimized_function: Source text excluded from the diff measurement.
        check_diff: Whether to run the dry-run size check before formatting.
        print_status: Forwarded to ``apply_formatter_cmds`` for the real run.

    Returns:
        The formatted file content, or the original content when the dry-run
        diff is too large.
    """
    min_lines_for_diff_check = 50
    max_diff_ratio = 0.3
    max_diff_cap = 50

    if isinstance(path, str):
        path = Path(path)

    original_code = path.read_text(encoding="utf8")
    original_code_lines = len(original_code.split("\n"))

    if check_diff and original_code_lines > min_lines_for_diff_check:
        # The scratch directory is only needed for the dry run, so it is
        # created here rather than on every call.
        with tempfile.TemporaryDirectory() as test_dir_str:
            # We don't count the formatting diff for the optimized function, as it should be well-formatted.
            original_code_without_opfunc = original_code.replace(optimized_function, "")

            original_temp = Path(test_dir_str) / "original_temp.py"
            original_temp.write_text(original_code_without_opfunc, encoding="utf8")

            formatted_temp, formatted_code = apply_formatter_cmds(
                formatter_cmds, original_temp, test_dir_str, print_status=False
            )

            diff_output = generate_unified_diff(
                original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
            )
            diff_lines_count = get_diff_lines_count(diff_output)

            max_diff_lines = min(int(original_code_lines * max_diff_ratio), max_diff_cap)

            if diff_lines_count > max_diff_lines:
                logger.debug(
                    f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
                )
                return original_code

    # TODO: Avoid formatting the whole file again; format only the optimized code standalone and
    # splice it into the formatted file produced above.
    _, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status)
    logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
    return formatted_code
48
144
49
145
50
146
def sort_imports (code : str ) -> str :
0 commit comments