Skip to content

Commit 36ce827

Browse files
authored
Merge pull request #215 from codeflash-ai/tracer-optimization
Tracer optimization
2 parents 10e8a13 + fc4f2de commit 36ce827

File tree

6 files changed

+554
-106
lines changed

6 files changed

+554
-106
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,4 @@ jobs:
3232
run: uvx poetry install --with dev
3333

3434
- name: Unit tests
35-
run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip"
36-
37-
- name: Upload coverage reports to Codecov
38-
uses: codecov/codecov-action@v5
39-
if: matrix.python-version == '3.12.1'
40-
with:
41-
token: ${{ secrets.CODECOV_TOKEN }}
35+
run: uvx poetry run pytest tests/ --benchmark-skip -m "not ci_skip"

codeflash/discovery/functions_to_optimize.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -599,7 +599,14 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
599599

600600

601601
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
602-
return any(isinstance(node, ast.Return) for node in ast.walk(function_node))
602+
# Custom DFS, return True as soon as a Return node is found
603+
stack = [function_node]
604+
while stack:
605+
node = stack.pop()
606+
if isinstance(node, ast.Return):
607+
return True
608+
stack.extend(ast.iter_child_nodes(node))
609+
return False
603610

604611

605612
def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:

codeflash/tracer.py

Lines changed: 137 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
#
1212
from __future__ import annotations
1313

14+
import contextlib
1415
import importlib.machinery
1516
import io
1617
import json
@@ -99,6 +100,7 @@ def __init__(
99100
)
100101
disable = True
101102
self.disable = disable
103+
self._db_lock: threading.Lock | None = None
102104
if self.disable:
103105
return
104106
if sys.getprofile() is not None or sys.gettrace() is not None:
@@ -108,6 +110,9 @@ def __init__(
108110
)
109111
self.disable = True
110112
return
113+
114+
self._db_lock = threading.Lock()
115+
111116
self.con = None
112117
self.output_file = Path(output).resolve()
113118
self.functions = functions
@@ -130,6 +135,7 @@ def __init__(
130135
self.timeout = timeout
131136
self.next_insert = 1000
132137
self.trace_count = 0
138+
self.path_cache = {} # Cache for resolved file paths
133139

134140
# Profiler variables
135141
self.bias = 0 # calibration constant
@@ -178,34 +184,55 @@ def __enter__(self) -> None:
178184
def __exit__(
179185
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
180186
) -> None:
181-
if self.disable:
187+
if self.disable or self._db_lock is None:
182188
return
183189
sys.setprofile(None)
184-
self.con.commit()
185-
console.rule("Codeflash: Traced Program Output End", style="bold blue")
186-
self.create_stats()
190+
threading.setprofile(None)
187191

188-
cur = self.con.cursor()
189-
cur.execute(
190-
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
191-
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
192-
"cumulative_time_ns INTEGER, callers BLOB)"
193-
)
194-
for func, (cc, nc, tt, ct, callers) in self.stats.items():
195-
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
192+
with self._db_lock:
193+
if self.con is None:
194+
return
195+
196+
self.con.commit() # Commit any pending from tracer_logic
197+
console.rule("Codeflash: Traced Program Output End", style="bold blue")
198+
self.create_stats() # This calls snapshot_stats which uses self.timings
199+
200+
cur = self.con.cursor()
196201
cur.execute(
197-
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
198-
(str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
202+
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
203+
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
204+
"cumulative_time_ns INTEGER, callers BLOB)"
199205
)
200-
self.con.commit()
206+
# self.stats is populated by snapshot_stats() called within create_stats()
207+
# Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
208+
# For now, assuming self.stats is primarily in-memory after create_stats()
209+
for func, (cc, nc, tt, ct, callers) in self.stats.items():
210+
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
211+
cur.execute(
212+
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
213+
(
214+
str(Path(func[0]).resolve()),
215+
func[1],
216+
func[2],
217+
func[3],
218+
cc,
219+
nc,
220+
tt,
221+
ct,
222+
json.dumps(remapped_callers),
223+
),
224+
)
225+
self.con.commit()
201226

202-
self.make_pstats_compatible()
203-
self.print_stats("tottime")
204-
cur = self.con.cursor()
205-
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
206-
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
207-
self.con.commit()
208-
self.con.close()
227+
self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory
228+
self.print_stats("tottime") # Uses self.stats, prints to console
229+
230+
cur = self.con.cursor() # New cursor
231+
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
232+
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
233+
self.con.commit()
234+
self.con.close()
235+
self.con = None # Mark connection as closed
209236

210237
# filter any functions where we did not capture the return
211238
self.function_modules = [
@@ -245,18 +272,29 @@ def __exit__(
245272
def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
246273
if event != "call":
247274
return
248-
if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
275+
if None is not self.timeout and (time.time() - self.start_time) > self.timeout:
249276
sys.setprofile(None)
250277
threading.setprofile(None)
251278
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
252279
return
253-
code = frame.f_code
280+
if self.disable or self._db_lock is None or self.con is None:
281+
return
254282

255-
file_name = Path(code.co_filename).resolve()
256-
# TODO : It currently doesn't log the last return call from the first function
283+
code = frame.f_code
257284

285+
# Check function name first before resolving path
258286
if code.co_name in self.ignored_functions:
259287
return
288+
289+
# Now resolve file path only if we need it
290+
co_filename = code.co_filename
291+
if co_filename in self.path_cache:
292+
file_name = self.path_cache[co_filename]
293+
else:
294+
file_name = Path(co_filename).resolve()
295+
self.path_cache[co_filename] = file_name
296+
# TODO : It currently doesn't log the last return call from the first function
297+
260298
if not file_name.is_relative_to(self.project_root):
261299
return
262300
if not file_name.exists():
@@ -266,18 +304,29 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
266304
class_name = None
267305
arguments = frame.f_locals
268306
try:
269-
if (
270-
"self" in arguments
271-
and hasattr(arguments["self"], "__class__")
272-
and hasattr(arguments["self"].__class__, "__name__")
273-
):
274-
class_name = arguments["self"].__class__.__name__
275-
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
276-
class_name = arguments["cls"].__name__
307+
self_arg = arguments.get("self")
308+
if self_arg is not None:
309+
try:
310+
class_name = self_arg.__class__.__name__
311+
except AttributeError:
312+
cls_arg = arguments.get("cls")
313+
if cls_arg is not None:
314+
with contextlib.suppress(AttributeError):
315+
class_name = cls_arg.__name__
316+
else:
317+
cls_arg = arguments.get("cls")
318+
if cls_arg is not None:
319+
with contextlib.suppress(AttributeError):
320+
class_name = cls_arg.__name__
277321
except: # noqa: E722
278322
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
279323
return
280-
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
324+
325+
try:
326+
function_qualified_name = f"{file_name}:{code.co_qualname}"
327+
except AttributeError:
328+
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
329+
281330
if function_qualified_name in self.ignored_qualified_functions:
282331
return
283332
if function_qualified_name not in self.function_count:
@@ -310,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
310359

311360
# TODO: Also check if this function arguments are unique from the values logged earlier
312361

313-
cur = self.con.cursor()
362+
with self._db_lock:
363+
# Check connection again inside lock, in case __exit__ closed it.
364+
if self.con is None:
365+
return
314366

315-
t_ns = time.perf_counter_ns()
316-
original_recursion_limit = sys.getrecursionlimit()
317-
try:
318-
# pickling can be a recursive operator, so we need to increase the recursion limit
319-
sys.setrecursionlimit(10000)
320-
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
321-
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
322-
# leaks, bad references or side effects when unpickling.
323-
arguments = dict(arguments.items())
324-
if class_name and code.co_name == "__init__":
325-
del arguments["self"]
326-
local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
327-
sys.setrecursionlimit(original_recursion_limit)
328-
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
329-
# we retry with dill if pickle fails. It's slower but more comprehensive
367+
cur = self.con.cursor()
368+
369+
t_ns = time.perf_counter_ns()
370+
original_recursion_limit = sys.getrecursionlimit()
330371
try:
331-
local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
372+
# pickling can be a recursive operator, so we need to increase the recursion limit
373+
sys.setrecursionlimit(10000)
374+
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
375+
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
376+
# leaks, bad references or side effects when unpickling.
377+
arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals
378+
if class_name and code.co_name == "__init__" and "self" in arguments_copy:
379+
del arguments_copy["self"]
380+
local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
332381
sys.setrecursionlimit(original_recursion_limit)
382+
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
383+
# we retry with dill if pickle fails. It's slower but more comprehensive
384+
try:
385+
sys.setrecursionlimit(10000) # Ensure limit is high for dill too
386+
# arguments_copy should be used here as well if defined above
387+
local_vars = dill.dumps(
388+
arguments_copy if "arguments_copy" in locals() else dict(arguments.items()),
389+
protocol=dill.HIGHEST_PROTOCOL,
390+
)
391+
sys.setrecursionlimit(original_recursion_limit)
392+
393+
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
394+
self.function_count[function_qualified_name] -= 1
395+
return
333396

334-
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
335-
# give up
336-
self.function_count[function_qualified_name] -= 1
337-
return
338-
cur.execute(
339-
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
340-
(
341-
event,
342-
code.co_name,
343-
class_name,
344-
str(file_name),
345-
frame.f_lineno,
346-
frame.f_back.__hash__(),
347-
t_ns,
348-
local_vars,
349-
),
350-
)
351-
self.trace_count += 1
352-
self.next_insert -= 1
353-
if self.next_insert == 0:
354-
self.next_insert = 1000
355-
self.con.commit()
397+
cur.execute(
398+
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
399+
(
400+
event,
401+
code.co_name,
402+
class_name,
403+
str(file_name),
404+
frame.f_lineno,
405+
frame.f_back.__hash__(),
406+
t_ns,
407+
local_vars,
408+
),
409+
)
410+
self.trace_count += 1
411+
self.next_insert -= 1
412+
if self.next_insert == 0:
413+
self.next_insert = 1000
414+
self.con.commit()
356415

357416
def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
358417
# profiler section
@@ -475,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
475534
cc = cc + 1
476535

477536
if pfn in callers:
478-
callers[pfn] = callers[pfn] + 1 # TODO: gather more
479-
# stats such as the amount of time added to ct courtesy
537+
# Increment call count between these functions
538+
callers[pfn] = callers[pfn] + 1
539+
# Note: This tracks stats such as the amount of time added to ct
480540
# of this specific call, and the contribution to cc
481541
# courtesy of this call.
482542
else:
@@ -703,7 +763,7 @@ def create_stats(self) -> None:
703763

704764
def snapshot_stats(self) -> None:
705765
self.stats = {}
706-
for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
766+
for func, (cc, _ns, tt, ct, caller_dict) in list(self.timings.items()):
707767
callers = caller_dict.copy()
708768
nc = 0
709769
for callcnt in callers.values():

poetry.lock

Lines changed: 1 addition & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,6 @@ types-cffi = ">=1.16.0.20240331"
115115
types-openpyxl = ">=3.1.5.20241020"
116116
types-regex = ">=2024.9.11.20240912"
117117
types-python-dateutil = ">=2.9.0.20241003"
118-
pytest-cov = "^6.0.0"
119118
pytest-benchmark = ">=5.1.0"
120119
types-gevent = "^24.11.0.20241230"
121120
types-greenlet = "^3.1.0.20241221"

0 commit comments

Comments
 (0)