11
11
#
12
12
from __future__ import annotations
13
13
14
+ import contextlib
14
15
import importlib .machinery
15
16
import io
16
17
import json
@@ -99,6 +100,7 @@ def __init__(
99
100
)
100
101
disable = True
101
102
self .disable = disable
103
+ self ._db_lock : threading .Lock | None = None
102
104
if self .disable :
103
105
return
104
106
if sys .getprofile () is not None or sys .gettrace () is not None :
@@ -108,6 +110,9 @@ def __init__(
108
110
)
109
111
self .disable = True
110
112
return
113
+
114
+ self ._db_lock = threading .Lock ()
115
+
111
116
self .con = None
112
117
self .output_file = Path (output ).resolve ()
113
118
self .functions = functions
@@ -130,6 +135,7 @@ def __init__(
130
135
self .timeout = timeout
131
136
self .next_insert = 1000
132
137
self .trace_count = 0
138
+ self .path_cache = {} # Cache for resolved file paths
133
139
134
140
# Profiler variables
135
141
self .bias = 0 # calibration constant
@@ -178,34 +184,55 @@ def __enter__(self) -> None:
178
184
def __exit__ (
179
185
self , exc_type : type [BaseException ] | None , exc_val : BaseException | None , exc_tb : TracebackType | None
180
186
) -> None :
181
- if self .disable :
187
+ if self .disable or self . _db_lock is None :
182
188
return
183
189
sys .setprofile (None )
184
- self .con .commit ()
185
- console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
186
- self .create_stats ()
190
+ threading .setprofile (None )
187
191
188
- cur = self .con .cursor ()
189
- cur .execute (
190
- "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
191
- "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
192
- "cumulative_time_ns INTEGER, callers BLOB)"
193
- )
194
- for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
195
- remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
192
+ with self ._db_lock :
193
+ if self .con is None :
194
+ return
195
+
196
+ self .con .commit () # Commit any pending from tracer_logic
197
+ console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
198
+ self .create_stats () # This calls snapshot_stats which uses self.timings
199
+
200
+ cur = self .con .cursor ()
196
201
cur .execute (
197
- "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
198
- (str (Path (func [0 ]).resolve ()), func [1 ], func [2 ], func [3 ], cc , nc , tt , ct , json .dumps (remapped_callers )),
202
+ "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
203
+ "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
204
+ "cumulative_time_ns INTEGER, callers BLOB)"
199
205
)
200
- self .con .commit ()
206
+ # self.stats is populated by snapshot_stats() called within create_stats()
207
+ # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
208
+ # For now, assuming self.stats is primarily in-memory after create_stats()
209
+ for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
210
+ remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
211
+ cur .execute (
212
+ "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
213
+ (
214
+ str (Path (func [0 ]).resolve ()),
215
+ func [1 ],
216
+ func [2 ],
217
+ func [3 ],
218
+ cc ,
219
+ nc ,
220
+ tt ,
221
+ ct ,
222
+ json .dumps (remapped_callers ),
223
+ ),
224
+ )
225
+ self .con .commit ()
201
226
202
- self .make_pstats_compatible ()
203
- self .print_stats ("tottime" )
204
- cur = self .con .cursor ()
205
- cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
206
- cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
207
- self .con .commit ()
208
- self .con .close ()
227
+ self .make_pstats_compatible () # Modifies self.stats and self.timings in-memory
228
+ self .print_stats ("tottime" ) # Uses self.stats, prints to console
229
+
230
+ cur = self .con .cursor () # New cursor
231
+ cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
232
+ cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
233
+ self .con .commit ()
234
+ self .con .close ()
235
+ self .con = None # Mark connection as closed
209
236
210
237
# filter any functions where we did not capture the return
211
238
self .function_modules = [
@@ -245,18 +272,29 @@ def __exit__(
245
272
def tracer_logic (self , frame : FrameType , event : str ) -> None : # noqa: PLR0911
246
273
if event != "call" :
247
274
return
248
- if self . timeout is not None and (time .time () - self .start_time ) > self .timeout :
275
+ if None is not self . timeout and (time .time () - self .start_time ) > self .timeout :
249
276
sys .setprofile (None )
250
277
threading .setprofile (None )
251
278
console .print (f"Codeflash: Timeout reached! Stopping tracing at { self .timeout } seconds." )
252
279
return
253
- code = frame .f_code
280
+ if self .disable or self ._db_lock is None or self .con is None :
281
+ return
254
282
255
- file_name = Path (code .co_filename ).resolve ()
256
- # TODO : It currently doesn't log the last return call from the first function
283
+ code = frame .f_code
257
284
285
+ # Check function name first before resolving path
258
286
if code .co_name in self .ignored_functions :
259
287
return
288
+
289
+ # Now resolve file path only if we need it
290
+ co_filename = code .co_filename
291
+ if co_filename in self .path_cache :
292
+ file_name = self .path_cache [co_filename ]
293
+ else :
294
+ file_name = Path (co_filename ).resolve ()
295
+ self .path_cache [co_filename ] = file_name
296
+ # TODO : It currently doesn't log the last return call from the first function
297
+
260
298
if not file_name .is_relative_to (self .project_root ):
261
299
return
262
300
if not file_name .exists ():
@@ -266,18 +304,29 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
266
304
class_name = None
267
305
arguments = frame .f_locals
268
306
try :
269
- if (
270
- "self" in arguments
271
- and hasattr (arguments ["self" ], "__class__" )
272
- and hasattr (arguments ["self" ].__class__ , "__name__" )
273
- ):
274
- class_name = arguments ["self" ].__class__ .__name__
275
- elif "cls" in arguments and hasattr (arguments ["cls" ], "__name__" ):
276
- class_name = arguments ["cls" ].__name__
307
+ self_arg = arguments .get ("self" )
308
+ if self_arg is not None :
309
+ try :
310
+ class_name = self_arg .__class__ .__name__
311
+ except AttributeError :
312
+ cls_arg = arguments .get ("cls" )
313
+ if cls_arg is not None :
314
+ with contextlib .suppress (AttributeError ):
315
+ class_name = cls_arg .__name__
316
+ else :
317
+ cls_arg = arguments .get ("cls" )
318
+ if cls_arg is not None :
319
+ with contextlib .suppress (AttributeError ):
320
+ class_name = cls_arg .__name__
277
321
except : # noqa: E722
278
322
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
279
323
return
280
- function_qualified_name = f"{ file_name } :{ (class_name + ':' if class_name else '' )} { code .co_name } "
324
+
325
+ try :
326
+ function_qualified_name = f"{ file_name } :{ code .co_qualname } "
327
+ except AttributeError :
328
+ function_qualified_name = f"{ file_name } :{ (class_name + ':' if class_name else '' )} { code .co_name } "
329
+
281
330
if function_qualified_name in self .ignored_qualified_functions :
282
331
return
283
332
if function_qualified_name not in self .function_count :
@@ -310,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
310
359
311
360
# TODO: Also check if this function arguments are unique from the values logged earlier
312
361
313
- cur = self .con .cursor ()
362
+ with self ._db_lock :
363
+ # Check connection again inside lock, in case __exit__ closed it.
364
+ if self .con is None :
365
+ return
314
366
315
- t_ns = time .perf_counter_ns ()
316
- original_recursion_limit = sys .getrecursionlimit ()
317
- try :
318
- # pickling can be a recursive operator, so we need to increase the recursion limit
319
- sys .setrecursionlimit (10000 )
320
- # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
321
- # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
322
- # leaks, bad references or side effects when unpickling.
323
- arguments = dict (arguments .items ())
324
- if class_name and code .co_name == "__init__" :
325
- del arguments ["self" ]
326
- local_vars = pickle .dumps (arguments , protocol = pickle .HIGHEST_PROTOCOL )
327
- sys .setrecursionlimit (original_recursion_limit )
328
- except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
329
- # we retry with dill if pickle fails. It's slower but more comprehensive
367
+ cur = self .con .cursor ()
368
+
369
+ t_ns = time .perf_counter_ns ()
370
+ original_recursion_limit = sys .getrecursionlimit ()
330
371
try :
331
- local_vars = dill .dumps (arguments , protocol = dill .HIGHEST_PROTOCOL )
372
+ # pickling can be a recursive operator, so we need to increase the recursion limit
373
+ sys .setrecursionlimit (10000 )
374
+ # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
375
+ # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
376
+ # leaks, bad references or side effects when unpickling.
377
+ arguments_copy = dict (arguments .items ()) # Use the local 'arguments' from frame.f_locals
378
+ if class_name and code .co_name == "__init__" and "self" in arguments_copy :
379
+ del arguments_copy ["self" ]
380
+ local_vars = pickle .dumps (arguments_copy , protocol = pickle .HIGHEST_PROTOCOL )
332
381
sys .setrecursionlimit (original_recursion_limit )
382
+ except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
383
+ # we retry with dill if pickle fails. It's slower but more comprehensive
384
+ try :
385
+ sys .setrecursionlimit (10000 ) # Ensure limit is high for dill too
386
+ # arguments_copy should be used here as well if defined above
387
+ local_vars = dill .dumps (
388
+ arguments_copy if "arguments_copy" in locals () else dict (arguments .items ()),
389
+ protocol = dill .HIGHEST_PROTOCOL ,
390
+ )
391
+ sys .setrecursionlimit (original_recursion_limit )
392
+
393
+ except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
394
+ self .function_count [function_qualified_name ] -= 1
395
+ return
333
396
334
- except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
335
- # give up
336
- self .function_count [function_qualified_name ] -= 1
337
- return
338
- cur .execute (
339
- "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
340
- (
341
- event ,
342
- code .co_name ,
343
- class_name ,
344
- str (file_name ),
345
- frame .f_lineno ,
346
- frame .f_back .__hash__ (),
347
- t_ns ,
348
- local_vars ,
349
- ),
350
- )
351
- self .trace_count += 1
352
- self .next_insert -= 1
353
- if self .next_insert == 0 :
354
- self .next_insert = 1000
355
- self .con .commit ()
397
+ cur .execute (
398
+ "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
399
+ (
400
+ event ,
401
+ code .co_name ,
402
+ class_name ,
403
+ str (file_name ),
404
+ frame .f_lineno ,
405
+ frame .f_back .__hash__ (),
406
+ t_ns ,
407
+ local_vars ,
408
+ ),
409
+ )
410
+ self .trace_count += 1
411
+ self .next_insert -= 1
412
+ if self .next_insert == 0 :
413
+ self .next_insert = 1000
414
+ self .con .commit ()
356
415
357
416
def trace_callback (self , frame : FrameType , event : str , arg : str | None ) -> None :
358
417
# profiler section
@@ -475,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
475
534
cc = cc + 1
476
535
477
536
if pfn in callers :
478
- callers [pfn ] = callers [pfn ] + 1 # TODO: gather more
479
- # stats such as the amount of time added to ct courtesy
537
+ # Increment call count between these functions
538
+ callers [pfn ] = callers [pfn ] + 1
539
+ # TODO: gather more stats such as the amount of time added to ct courtesy
480
540
# of this specific call, and the contribution to cc
481
541
# courtesy of this call.
482
542
else :
@@ -703,7 +763,7 @@ def create_stats(self) -> None:
703
763
704
764
def snapshot_stats (self ) -> None :
705
765
self .stats = {}
706
- for func , (cc , _ns , tt , ct , caller_dict ) in self .timings .items ():
766
+ for func , (cc , _ns , tt , ct , caller_dict ) in list ( self .timings .items () ):
707
767
callers = caller_dict .copy ()
708
768
nc = 0
709
769
for callcnt in callers .values ():
0 commit comments