1
1
from __future__ import annotations
2
2
3
3
import sqlite3
4
- import textwrap
5
4
from pathlib import Path
6
5
from typing import TYPE_CHECKING , Any
7
6
@@ -43,6 +42,7 @@ def get_next_arg_and_return(
43
42
44
43
45
44
def get_function_alias (module : str , function_name : str ) -> str :
45
+ # This is already pretty optimal.
46
46
return "_" .join (module .split ("." )) + "_" + function_name
47
47
48
48
@@ -66,152 +66,144 @@ def create_trace_replay_test_code(
66
66
A string containing the test code
67
67
68
68
"""
69
- assert test_framework in [ "pytest" , "unittest" ]
69
+ assert test_framework in ( "pytest" , "unittest" )
70
70
71
- # Create Imports
72
- imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
73
- { "import unittest" if test_framework == "unittest" else "" }
74
- from codeflash.benchmarking.replay_test import get_next_arg_and_return
75
- """
71
+ # Precompute aliases and filepaths
72
+ func_aliases , class_aliases , classfunc_aliases , file_paths = _get_aliases_and_paths (functions_data )
76
73
74
+ # Build function imports in one pass
77
75
function_imports = []
78
76
for func in functions_data :
79
77
module_name = func .get ("module_name" )
80
78
function_name = func .get ("function_name" )
81
79
class_name = func .get ("class_name" , "" )
82
80
if class_name :
83
- function_imports .append (
84
- f"from { module_name } import { class_name } as { get_function_alias (module_name , class_name )} "
85
- )
81
+ cname_alias = class_aliases [class_name ]
82
+ function_imports .append (f"from { module_name } import { class_name } as { cname_alias } " )
86
83
else :
87
- function_imports .append (
88
- f"from { module_name } import { function_name } as { get_function_alias (module_name , function_name )} "
89
- )
90
-
91
- imports += "\n " .join (function_imports )
92
-
93
- functions_to_optimize = sorted (
94
- {func .get ("function_name" ) for func in functions_data if func .get ("function_name" ) != "__init__" }
84
+ alias = func_aliases [(module_name , function_name )]
85
+ function_imports .append (f"from { module_name } import { function_name } as { alias } " )
86
+ imports = (
87
+ "from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle\n "
88
+ f"{ 'import unittest' if test_framework == 'unittest' else '' } \n "
89
+ "from codeflash.benchmarking.replay_test import get_next_arg_and_return\n " + "\n " .join (function_imports )
95
90
)
91
+
92
+ # Precompute functions_to_optimize efficiently using set and list since sorted(set(...))
93
+ functions_set = {func ["function_name" ] for func in functions_data if func ["function_name" ] != "__init__" }
94
+ functions_to_optimize = sorted (functions_set )
96
95
metadata = f"""functions = { functions_to_optimize }
97
96
trace_file_path = r"{ trace_file } "
98
97
"""
99
- # Templates for different types of tests
100
- test_function_body = textwrap .dedent (
101
- """\
102
- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
103
- args = pickle.loads(args_pkl)
104
- kwargs = pickle.loads(kwargs_pkl)
105
- ret = {function_name}(*args, **kwargs)
106
- """
107
- )
108
98
109
- test_method_body = textwrap .dedent (
110
- """\
111
- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
112
- args = pickle.loads(args_pkl)
113
- kwargs = pickle.loads(kwargs_pkl){filter_variables}
114
- function_name = "{orig_function_name}"
115
- if not args:
116
- raise ValueError("No arguments provided for the method.")
117
- if function_name == "__init__":
118
- ret = {class_name_alias}(*args[1:], **kwargs)
119
- else:
120
- ret = {class_name_alias}{method_name}(*args, **kwargs)
121
- """
99
+ # Prepare templates only once
100
+ test_function_body = (
101
+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
102
+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
103
+ 'file_path=r"{file_path}", num_to_get={max_run_count}):\n '
104
+ " args = pickle.loads(args_pkl)\n "
105
+ " kwargs = pickle.loads(kwargs_pkl)\n "
106
+ " ret = {function_name}(*args, **kwargs)\n "
122
107
)
123
-
124
- test_class_method_body = textwrap .dedent (
125
- """\
126
- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
127
- args = pickle.loads(args_pkl)
128
- kwargs = pickle.loads(kwargs_pkl){filter_variables}
129
- if not args:
130
- raise ValueError("No arguments provided for the method.")
131
- ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
132
- """
108
+ test_method_body = (
109
+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
110
+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
111
+ 'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n '
112
+ " args = pickle.loads(args_pkl)\n "
113
+ " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n "
114
+ ' function_name = "{orig_function_name}"\n '
115
+ " if not args:\n "
116
+ ' raise ValueError("No arguments provided for the method.")\n '
117
+ ' if function_name == "__init__":\n '
118
+ " ret = {class_name_alias}(*args[1:], **kwargs)\n "
119
+ " else:\n "
120
+ " ret = {class_name_alias}{method_name}(*args, **kwargs)\n "
133
121
)
134
- test_static_method_body = textwrap .dedent (
135
- """\
136
- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
137
- args = pickle.loads(args_pkl)
138
- kwargs = pickle.loads(kwargs_pkl){filter_variables}
139
- ret = {class_name_alias}{method_name}(*args, **kwargs)
140
- """
122
+ test_class_method_body = (
123
+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
124
+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
125
+ 'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n '
126
+ " args = pickle.loads(args_pkl)\n "
127
+ " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n "
128
+ " if not args:\n "
129
+ ' raise ValueError("No arguments provided for the method.")\n '
130
+ " ret = {class_name_alias}{method_name}(*args[1:], **kwargs)\n "
141
131
)
142
-
143
- # Create main body
144
-
132
+ test_static_method_body = (
133
+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
134
+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
135
+ 'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n '
136
+ " args = pickle.loads(args_pkl)\n "
137
+ " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n "
138
+ " ret = {class_name_alias}{method_name}(*args, **kwargs)\n "
139
+ )
140
+ test_bodies = {
141
+ "function" : test_function_body ,
142
+ "method" : test_method_body ,
143
+ "classmethod" : test_class_method_body ,
144
+ "staticmethod" : test_static_method_body ,
145
+ }
146
+
147
+ # Precompute the format values up-front for all functions
145
148
if test_framework == "unittest" :
146
- self = "self"
147
- test_template = "\n class TestTracedFunctions(unittest.TestCase):\n "
149
+ self_str = "self"
150
+ test_template_list = ["\n class TestTracedFunctions(unittest.TestCase):\n " ]
151
+ indent_level = " "
152
+ def_line = " "
148
153
else :
149
- test_template = ""
150
- self = ""
154
+ self_str = ""
155
+ test_template_list = []
156
+ indent_level = " "
157
+ def_line = ""
151
158
152
159
for func in functions_data :
153
- module_name = func . get ( "module_name" )
154
- function_name = func . get ( "function_name" )
160
+ module_name = func [ "module_name" ]
161
+ function_name = func [ "function_name" ]
155
162
class_name = func .get ("class_name" )
156
- file_path = func .get ("file_path" )
157
- benchmark_function_name = func .get ("benchmark_function_name" )
158
- function_properties = func .get ("function_properties" )
163
+ file_path = func ["file_path" ]
164
+ file_path_posix = file_paths [file_path ]
165
+ benchmark_function_name = func ["benchmark_function_name" ]
166
+ function_properties = func ["function_properties" ]
159
167
if not class_name :
160
- alias = get_function_alias (module_name , function_name )
161
- test_body = test_function_body .format (
168
+ alias = func_aliases [(module_name , function_name )]
169
+ template = test_bodies ["function" ]
170
+ test_body_filled = template .format (
162
171
benchmark_function_name = benchmark_function_name ,
163
172
orig_function_name = function_name ,
164
173
function_name = alias ,
165
- file_path = Path ( file_path ). as_posix () ,
174
+ file_path = file_path_posix ,
166
175
max_run_count = max_run_count ,
167
176
)
168
177
else :
169
- class_name_alias = get_function_alias (module_name , class_name )
170
- alias = get_function_alias (module_name , class_name + "_" + function_name )
171
-
178
+ class_name_alias = class_aliases [class_name ]
179
+ alias = classfunc_aliases [(module_name , class_name , function_name )]
172
180
filter_variables = ""
173
- # filter_variables = '\n args.pop("cls", None)'
174
181
method_name = "." + function_name if function_name != "__init__" else ""
175
182
if function_properties .is_classmethod :
176
- test_body = test_class_method_body .format (
177
- benchmark_function_name = benchmark_function_name ,
178
- orig_function_name = function_name ,
179
- file_path = Path (file_path ).as_posix (),
180
- class_name_alias = class_name_alias ,
181
- class_name = class_name ,
182
- method_name = method_name ,
183
- max_run_count = max_run_count ,
184
- filter_variables = filter_variables ,
185
- )
183
+ template = test_bodies ["classmethod" ]
186
184
elif function_properties .is_staticmethod :
187
- test_body = test_static_method_body .format (
188
- benchmark_function_name = benchmark_function_name ,
189
- orig_function_name = function_name ,
190
- file_path = Path (file_path ).as_posix (),
191
- class_name_alias = class_name_alias ,
192
- class_name = class_name ,
193
- method_name = method_name ,
194
- max_run_count = max_run_count ,
195
- filter_variables = filter_variables ,
196
- )
185
+ template = test_bodies ["staticmethod" ]
197
186
else :
198
- test_body = test_method_body .format (
199
- benchmark_function_name = benchmark_function_name ,
200
- orig_function_name = function_name ,
201
- file_path = Path (file_path ).as_posix (),
202
- class_name_alias = class_name_alias ,
203
- class_name = class_name ,
204
- method_name = method_name ,
205
- max_run_count = max_run_count ,
206
- filter_variables = filter_variables ,
207
- )
187
+ template = test_bodies ["method" ]
188
+ test_body_filled = template .format (
189
+ benchmark_function_name = benchmark_function_name ,
190
+ orig_function_name = function_name ,
191
+ file_path = file_path_posix ,
192
+ class_name_alias = class_name_alias ,
193
+ class_name = class_name ,
194
+ method_name = method_name ,
195
+ max_run_count = max_run_count ,
196
+ filter_variables = filter_variables ,
197
+ )
208
198
209
- formatted_test_body = textwrap .indent (test_body , " " if test_framework == "unittest" else " " )
199
+ # No repeated indent/dedent. Do indent directly, as we know where to indent.
200
+ formatted_test_body = "" .join (
201
+ indent_level + line if line .strip () else line for line in test_body_filled .splitlines (True )
202
+ )
210
203
211
- test_template += " " if test_framework == "unittest" else ""
212
- test_template += f"def test_{ alias } ({ self } ):\n { formatted_test_body } \n "
204
+ test_template_list .append (f"{ def_line } def test_{ alias } ({ self_str } ):\n { formatted_test_body } \n " )
213
205
214
- return imports + "\n " + metadata + "\n " + test_template
206
+ return imports + "\n " + metadata + "\n " + "" . join ( test_template_list )
215
207
216
208
217
209
def generate_replay_test (
@@ -294,3 +286,29 @@ def generate_replay_test(
294
286
logger .info (f"Error generating replay tests: { e } " )
295
287
296
288
return count
289
+
290
+
291
+ def _get_aliases_and_paths (functions_data ):
292
+ # Precompute all needed aliases and file posix paths up front in a single pass
293
+ func_aliases = {}
294
+ class_aliases = {}
295
+ classfunc_aliases = {}
296
+ file_paths = {}
297
+ for func in functions_data :
298
+ module_name = func .get ("module_name" )
299
+ function_name = func .get ("function_name" )
300
+ class_name = func .get ("class_name" , "" )
301
+ file_path = func .get ("file_path" )
302
+ # Precompute Path(file_path).as_posix() once per unique file_path
303
+ if file_path not in file_paths :
304
+ file_paths [file_path ] = Path (file_path ).as_posix ()
305
+ if class_name :
306
+ # avoid re-calculating class alias if already done
307
+ if class_name not in class_aliases :
308
+ class_aliases [class_name ] = get_function_alias (module_name , class_name )
309
+ classfunc_key = (module_name , class_name , function_name )
310
+ classfunc_aliases [classfunc_key ] = get_function_alias (module_name , class_name + "_" + function_name )
311
+ else :
312
+ # alias for global function
313
+ func_aliases [(module_name , function_name )] = get_function_alias (module_name , function_name )
314
+ return func_aliases , class_aliases , classfunc_aliases , file_paths
0 commit comments