diff --git a/codeflash/code_utils/tabulate.py b/codeflash/code_utils/tabulate.py
index bc42cd03..0d132f93 100644
--- a/codeflash/code_utils/tabulate.py
+++ b/codeflash/code_utils/tabulate.py
@@ -1,5 +1,11 @@
 """Adapted from tabulate (https://github.com/astanin/python-tabulate) written by Sergey Astanin and contributors (MIT License)."""
+from __future__ import annotations
+import warnings
+import wcwidth
+from itertools import chain, zip_longest as izip_longest
+from collections.abc import Iterable
+
 """Pretty-print tabular data."""

 # ruff: noqa

@@ -650,94 +656,80 @@ def tabulate(
     rowalign=None,
     maxheadercolwidths=None,
 ):
+    # Shortcuts & locals
     if tabular_data is None:
         tabular_data = []
+    # 1. Normalize tabular data once
     list_of_lists, headers, headers_pad = _normalize_tabular_data(tabular_data, headers, showindex=showindex)
-    list_of_lists, separating_lines = _remove_separating_lines(list_of_lists)
+    list_of_lists, _ = _remove_separating_lines(list_of_lists)  # separating_lines not used

-    # PrettyTable formatting does not use any extra padding.
-    # Numbers are not parsed and are treated the same as strings for alignment.
-    # Check if pretty is the format being used and override the defaults so it
-    # does not impact other formats.
-    min_padding = MIN_PADDING
+    # 2. Pre-calculate format switches (reduce repeated logic)
+    min_padding = 0 if tablefmt == "pretty" else MIN_PADDING
     if tablefmt == "pretty":
-        min_padding = 0
         disable_numparse = True
         numalign = "center" if numalign == _DEFAULT_ALIGN else numalign
         stralign = "center" if stralign == _DEFAULT_ALIGN else stralign
     else:
        numalign = "decimal" if numalign == _DEFAULT_ALIGN else numalign
        stralign = "left" if stralign == _DEFAULT_ALIGN else stralign
-
-    # 'colon_grid' uses colons in the line beneath the header to represent a column's
-    # alignment instead of literally aligning the text differently. Hence,
-    # left alignment of the data in the text output is enforced.
     if tablefmt == "colon_grid":
         colglobalalign = "left"
         headersglobalalign = "left"

-    # optimization: look for ANSI control codes once,
-    # enable smart width functions only if a control code is found
-    #
-    # convert the headers and rows into a single, tab-delimited string ensuring
-    # that any bytestrings are decoded safely (i.e. errors ignored)
-    plain_text = "\t".join(
-        chain(
-            # headers
-            map(_to_str, headers),
-            # rows: chain the rows together into a single iterable after mapping
-            # the bytestring conversino to each cell value
-            chain.from_iterable(map(_to_str, row) for row in list_of_lists),
-        )
-    )
-
+    # 3. Prepare plain_text for feature detection
+    # Flatten headers and rows efficiently
+    # (The main cost here is table flattening for detection. Avoid generator object cost with a one-liner.)
+    if headers:
+        iters = chain(map(_to_str, headers), (cell for row in list_of_lists for cell in map(_to_str, row)))
+    else:
+        iters = (cell for row in list_of_lists for cell in map(_to_str, row))
+    plain_text = "\t".join(iters)
     has_invisible = _ansi_codes.search(plain_text) is not None
     enable_widechars = wcwidth is not None and WIDE_CHARS_MODE
+    is_multiline = False
     if not isinstance(tablefmt, TableFormat) and tablefmt in multiline_formats and _is_multiline(plain_text):
         tablefmt = multiline_formats.get(tablefmt, tablefmt)
         is_multiline = True
-    else:
-        is_multiline = False
     width_fn = _choose_width_fn(has_invisible, enable_widechars, is_multiline)

-    # format rows and columns, convert numeric values to strings
-    cols = list(izip_longest(*list_of_lists))
-    numparses = _expand_numparse(disable_numparse, len(cols))
-    coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)]
-    if isinstance(floatfmt, str):  # old version
-        float_formats = len(cols) * [floatfmt]  # just duplicate the string to use in each column
-    else:  # if floatfmt is list, tuple etc we have one per column
-        float_formats = list(floatfmt)
-        if len(float_formats) < len(cols):
-            float_formats.extend((len(cols) - len(float_formats)) * [_DEFAULT_FLOATFMT])
-    if isinstance(intfmt, str):  # old version
-        int_formats = len(cols) * [intfmt]  # just duplicate the string to use in each column
-    else:  # if intfmt is list, tuple etc we have one per column
-        int_formats = list(intfmt)
-        if len(int_formats) < len(cols):
-            int_formats.extend((len(cols) - len(int_formats)) * [_DEFAULT_INTFMT])
-    if isinstance(missingval, str):
-        missing_vals = len(cols) * [missingval]
+    # 4. Transpose data only once, for column-oriented transforms
+    # Avoid expensive list + zip + star unpacking overhead by storing list_of_lists directly
+    data_rows = list_of_lists
+    ncols = len(data_rows[0]) if data_rows else len(headers)
+    cols = [list(col) for col in izip_longest(*data_rows, fillvalue="")]
+
+    # 5. Pre-compute per-column formatting parameters (avoid loop in loop)
+    numparses = _expand_numparse(disable_numparse, ncols)
+    coltypes = []
+    append_coltype = coltypes.append
+    for col, np in zip(cols, numparses):
+        append_coltype(_column_type(col, numparse=np))
+    float_formats = (
+        [floatfmt] * ncols
+        if isinstance(floatfmt, str)
+        else list(floatfmt) + [_DEFAULT_FLOATFMT] * (ncols - len(floatfmt))
+    )
+    int_formats = (
+        [intfmt] * ncols if isinstance(intfmt, str) else list(intfmt) + [_DEFAULT_INTFMT] * (ncols - len(intfmt))
+    )
+    missing_vals = (
+        [missingval] * ncols
+        if isinstance(missingval, str)
+        else list(missingval) + [_DEFAULT_MISSINGVAL] * (ncols - len(missingval))
+    )
+
+    # 6. Pre-format all columns (avoid repeated conversion/type detection)
+    formatted_cols = []
+    for c, ct, fl_fmt, int_fmt, miss_v in zip(cols, coltypes, float_formats, int_formats, missing_vals):
+        formatted_cols.append([_format(v, ct, fl_fmt, int_fmt, miss_v, has_invisible) for v in c])
+
+    # 7. Alignment selection (avoid per-column dict/set lookups by building plain lists)
+    if colglobalalign is not None:
+        aligns = [colglobalalign] * ncols
     else:
-        missing_vals = list(missingval)
-        if len(missing_vals) < len(cols):
-            missing_vals.extend((len(cols) - len(missing_vals)) * [_DEFAULT_MISSINGVAL])
-    cols = [
-        [_format(v, ct, fl_fmt, int_fmt, miss_v, has_invisible) for v in c]
-        for c, ct, fl_fmt, int_fmt, miss_v in zip(cols, coltypes, float_formats, int_formats, missing_vals)
-    ]
-
-    # align columns
-    # first set global alignment
-    if colglobalalign is not None:  # if global alignment provided
-        aligns = [colglobalalign] * len(cols)
-    else:  # default
         aligns = [numalign if ct in {int, float} else stralign for ct in coltypes]
-    # then specific alignments
     if colalign is not None:
-        assert isinstance(colalign, Iterable)
         if isinstance(colalign, str):
             warnings.warn(
                 f"As a string, `colalign` is interpreted as {[c for c in colalign]}. "
@@ -745,33 +737,35 @@
                 stacklevel=2,
             )
         for idx, align in enumerate(colalign):
-            if not idx < len(aligns):
+            if idx >= len(aligns):
                 break
             if align != "global":
                 aligns[idx] = align
-    minwidths = [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols)
-    aligns_copy = aligns.copy()
-    # Reset alignments in copy of alignments list to "left" for 'colon_grid' format,
-    # which enforces left alignment in the text output of the data.
-    if tablefmt == "colon_grid":
-        aligns_copy = ["left"] * len(cols)
-    cols = [
-        _align_column(c, a, minw, has_invisible, enable_widechars, is_multiline, preserve_whitespace)
-        for c, a, minw in zip(cols, aligns_copy, minwidths)
-    ]
-    aligns_headers = None
+    # 8. Compute minimum widths in a branch to avoid repeated expression evaluation
+    if headers:
+        # Precompute column min widths (includes header + padding)
+        minwidths = [width_fn(h) + min_padding for h in headers]
+    else:
+        minwidths = [0] * ncols
+
+    aligns_copy = aligns if tablefmt != "colon_grid" else ["left"] * ncols
+
+    # 9. Align all columns (single allocation per column)
+    aligned_cols = []
+    for c, a, minw in zip(formatted_cols, aligns_copy, minwidths):
+        aligned_cols.append(
+            _align_column(c, a, minw, has_invisible, enable_widechars, is_multiline, preserve_whitespace)
+        )
+
+    # 10. Handle header alignment and formatting
     if headers:
-        # align headers and add headers
-        t_cols = cols or [[""]] * len(headers)
-        # first set global alignment
-        if headersglobalalign is not None:  # if global alignment provided
-            aligns_headers = [headersglobalalign] * len(t_cols)
-        else:  # default
+        t_cols = aligned_cols or [[""]] * ncols
+        if headersglobalalign is not None:
+            aligns_headers = [headersglobalalign] * ncols
+        else:
             aligns_headers = aligns or [stralign] * len(headers)
-        # then specific header alignments
         if headersalign is not None:
-            assert isinstance(headersalign, Iterable)
             if isinstance(headersalign, str):
                 warnings.warn(
                     f"As a string, `headersalign` is interpreted as {[c for c in headersalign]}. "
@@ -781,28 +775,47 @@
                )
             for idx, align in enumerate(headersalign):
                 hidx = headers_pad + idx
-                if not hidx < len(aligns_headers):
+                if hidx >= len(aligns_headers):
                     break
-                if align == "same" and hidx < len(aligns):  # same as column align
+                if align == "same" and hidx < len(aligns):
                     aligns_headers[hidx] = aligns[hidx]
                 elif align != "global":
                     aligns_headers[hidx] = align
-        minwidths = [max(minw, max(width_fn(cl) for cl in c)) for minw, c in zip(minwidths, t_cols)]
+        # Combine the two width passes into one loop and avoid repeated width_fn calls
+        for i in range(ncols):
+            if t_cols[i]:
+                minwidths[i] = max(minwidths[i], max(width_fn(x) for x in t_cols[i]))
+        # Align the headers against the final column widths
         headers = [
             _align_header(h, a, minw, width_fn(h), is_multiline, width_fn)
             for h, a, minw in zip(headers, aligns_headers, minwidths)
         ]
-        rows = list(zip(*cols))
+        # Transpose aligned_cols for rows
+        rows = list(zip(*aligned_cols))
     else:
-        minwidths = [max(width_fn(cl) for cl in c) for c in cols]
-        rows = list(zip(*cols))
+        # No headers: just use the widest cell for minwidth
+        for i in range(ncols):
+            if aligned_cols[i]:
+                minwidths[i] = max(width_fn(x) for x in aligned_cols[i])
+        rows = list(zip(*aligned_cols))

+    # Get TableFormat up front
     if not isinstance(tablefmt, TableFormat):
         tablefmt = _table_formats.get(tablefmt, _table_formats["simple"])

     ra_default = rowalign if isinstance(rowalign, str) else None
     rowaligns = _expand_iterable(rowalign, len(rows), ra_default)
-    return _format_table(tablefmt, headers, aligns_headers, rows, minwidths, aligns, is_multiline, rowaligns=rowaligns)
+    # 11. Table rendering (as per original logic)
+    return _format_table(
+        tablefmt,
+        headers,
+        aligns_headers if headers else None,
+        rows,
+        minwidths,
+        aligns,
+        is_multiline,
+        rowaligns=rowaligns,
+    )


 def _expand_numparse(disable_numparse, column_count):
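For orientation, the rewritten tabulate() above is driven later in this patch from create_pr.py with exactly the keyword arguments shown there. A minimal usage sketch, assuming the vendored module is importable; the row values are hypothetical:

from codeflash.code_utils.tabulate import tabulate

# Hypothetical pre-formatted cells, mirroring the call site in create_pr.py below.
rows = [["`tests/test_example.py::test_case`", "1.20ms", "850μs", "✅41.2%"]]
markdown_table = tabulate(
    headers=["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"],
    tabular_data=rows,
    tablefmt="pipe",
    colglobalalign=None,
    preserve_whitespace=True,
)
print(markdown_table)  # GitHub-flavored "pipe" table, one line per row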
diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py
index 4e32eeda..6a086077 100644
--- a/codeflash/code_utils/time_utils.py
+++ b/codeflash/code_utils/time_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import datetime as dt
 import re

@@ -53,47 +55,40 @@ def humanize_runtime(time_in_ns: int) -> str:

 def format_time(nanoseconds: int) -> str:
     """Format nanoseconds into a human-readable string with 3 significant digits when needed."""
-    # Define conversion factors and units
+    # Validate input before formatting
     if not isinstance(nanoseconds, int):
         raise TypeError("Input must be an integer.")
     if nanoseconds < 0:
         raise ValueError("Input must be a positive integer.")
-    conversions = [(1_000_000_000, "s"), (1_000_000, "ms"), (1_000, "μs"), (1, "ns")]
-
-    # Handle nanoseconds case directly (no decimal formatting needed)
     if nanoseconds < 1_000:
         return f"{nanoseconds}ns"
-
-    # Find appropriate unit
-    for divisor, unit in conversions:
-        if nanoseconds >= divisor:
-            value = nanoseconds / divisor
-            int_value = nanoseconds // divisor
-
-            # Use integer formatting for values >= 100
-            if int_value >= 100:
-                formatted_value = f"{int_value:.0f}"
-            # Format with precision for 3 significant digits
-            elif value >= 100:
-                formatted_value = f"{value:.0f}"
-            elif value >= 10:
-                formatted_value = f"{value:.1f}"
+    # Avoid extra allocations by not rebuilding the conversion table every time
+    convs = ((1_000_000_000, "s"), (1_000_000, "ms"), (1_000, "μs"), (1, "ns"))
+    n = nanoseconds
+    for div, unit in convs:
+        if n >= div:
+            val = n / div
+            ival = n // div
+            if ival >= 100:
+                fval = f"{ival:.0f}"
+            elif val >= 100:
+                fval = f"{val:.0f}"
+            elif val >= 10:
+                fval = f"{val:.1f}"
             else:
-                formatted_value = f"{value:.2f}"
-
-            return f"{formatted_value}{unit}"
-
-    # This should never be reached, but included for completeness
+                fval = f"{val:.2f}"
+            return f"{fval}{unit}"
+    # Defensive fallback for completeness
     return f"{nanoseconds}ns"


 def format_perf(percentage: float) -> str:
     """Format percentage into a human-readable string with 3 significant digits when needed."""
-    percentage_abs = abs(percentage)
-    if percentage_abs >= 100:
+    abs_perc = abs(percentage)
+    if abs_perc >= 100:
         return f"{percentage:.0f}"
-    if percentage_abs >= 10:
+    if abs_perc >= 10:
         return f"{percentage:.1f}"
-    if percentage_abs >= 1:
+    if abs_perc >= 1:
         return f"{percentage:.2f}"
     return f"{percentage:.3f}"
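The format_time and format_perf changes above are intended to be behavior-preserving, so the rounding rules can be read straight off the branches. A small sanity sketch, assuming the helpers import as in this repo and the refactor keeps the original outputs:

from codeflash.code_utils.time_utils import format_perf, format_time

assert format_time(950) == "950ns"          # below 1_000 ns: returned as-is
assert format_time(1_234) == "1.23μs"       # scaled value < 10 -> two decimals
assert format_time(56_789_000) == "56.8ms"  # 10 <= scaled value < 100 -> one decimal
assert format_time(123_456_789) == "123ms"  # integer part >= 100 -> no decimals
assert format_perf(123.9) == "124"
assert format_perf(41.234) == "41.2"
assert format_perf(3.14159) == "3.14"
assert format_perf(0.5) == "0.500"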
diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py
index 706947d1..827f0ab3 100644
--- a/codeflash/result/create_pr.py
+++ b/codeflash/result/create_pr.py
@@ -20,7 +20,9 @@
 from codeflash.code_utils.tabulate import tabulate
 from codeflash.code_utils.time_utils import format_perf, format_time
 from codeflash.github.PrComment import FileDiffContent, PrComment
+from codeflash.models.models import FunctionCalledInTest, InvocationId
 from codeflash.result.critic import performance_gain
+from codeflash.verification.verification_utils import TestConfig

 if TYPE_CHECKING:
     from codeflash.models.models import FunctionCalledInTest, InvocationId
@@ -38,85 +40,82 @@ def existing_tests_source_for(
     test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
     if not test_files:
         return ""
-    output: str = ""
+    output = ""
     rows = []
     headers = ["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"]
     tests_root = test_cfg.tests_root
-    original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
-    optimized_tests_to_runtimes: dict[Path, dict[str, int]] = {}
-    non_generated_tests = set()
-    for test_file in test_files:
-        non_generated_tests.add(test_file.tests_in_file.test_file)
-    # TODO confirm that original and optimized have the same keys
-    all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys()
+
+    # Build non_generated_tests set (minimal extra work)
+    non_generated_tests = {t.tests_in_file.test_file for t in test_files}
+
+    # Use defaultdicts (saves many dict lookups and conditionals)
+    from collections import defaultdict
+
+    original_tests_to_runtimes = defaultdict(dict)
+    optimized_tests_to_runtimes = defaultdict(dict)
+
+    # Union all invocation ids from both dicts, filter early by test file
+    all_invocation_ids = set(original_runtimes_all) | set(optimized_runtimes_all)
+    path_cache = {}
     for invocation_id in all_invocation_ids:
-        abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
+        # Path construction optimized: use a cache
+        tid = invocation_id.test_module_path
+        if tid in path_cache:
+            abs_path = path_cache[tid]
+        else:
+            abs_path = Path(tid.replace(".", os.sep)).with_suffix(".py").resolve()  # expensive
+            path_cache[tid] = abs_path
         if abs_path not in non_generated_tests:
             continue
-        if abs_path not in original_tests_to_runtimes:
-            original_tests_to_runtimes[abs_path] = {}
-        if abs_path not in optimized_tests_to_runtimes:
-            optimized_tests_to_runtimes[abs_path] = {}
+        # Update per-path, per-name runtime dicts
         qualified_name = (
-            invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
+            f"{invocation_id.test_class_name}.{invocation_id.test_function_name}"
             if invocation_id.test_class_name
             else invocation_id.test_function_name
         )
-        if qualified_name not in original_tests_to_runtimes[abs_path]:
-            original_tests_to_runtimes[abs_path][qualified_name] = 0  # type: ignore[index]
-        if qualified_name not in optimized_tests_to_runtimes[abs_path]:
-            optimized_tests_to_runtimes[abs_path][qualified_name] = 0  # type: ignore[index]
+        # Initialize to 0 only if not present; avoid redundant assignment
+        orig = original_tests_to_runtimes[abs_path]
+        opt = optimized_tests_to_runtimes[abs_path]
+        if qualified_name not in orig:
+            orig[qualified_name] = 0
+        if qualified_name not in opt:
+            opt[qualified_name] = 0
         if invocation_id in original_runtimes_all:
-            original_tests_to_runtimes[abs_path][qualified_name] += min(original_runtimes_all[invocation_id])  # type: ignore[index]
+            orig[qualified_name] += min(original_runtimes_all[invocation_id])
         if invocation_id in optimized_runtimes_all:
-            optimized_tests_to_runtimes[abs_path][qualified_name] += min(optimized_runtimes_all[invocation_id])  # type: ignore[index]
-    # parse into string
-    all_abs_paths = (
-        original_tests_to_runtimes.keys()
-    )  # both will have the same keys as some default values are assigned in the previous loop
+            opt[qualified_name] += min(optimized_runtimes_all[invocation_id])
+
+    # Collect output rows in one pass
+    all_abs_paths = list(original_tests_to_runtimes.keys())
     for filename in sorted(all_abs_paths):
-        all_qualified_names = original_tests_to_runtimes[
-            filename
-        ].keys()  # both will have the same keys as some default values are assigned in the previous loop
-        for qualified_name in sorted(all_qualified_names):
-            # if not present in optimized output nan
-            if (
-                original_tests_to_runtimes[filename][qualified_name] != 0
-                and optimized_tests_to_runtimes[filename][qualified_name] != 0
-            ):
-                print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
-                print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
+        orig_runtimes = original_tests_to_runtimes[filename]
+        opt_runtimes = optimized_tests_to_runtimes[filename]
+        qualified_names = sorted(orig_runtimes)
+        # Only emit rows for test names with nonzero original and optimized runtimes
+        for qn in qualified_names:
+            o_time = orig_runtimes[qn]
+            opt_time = opt_runtimes[qn]
+            if o_time != 0 and opt_time != 0:
+                print_optimized_runtime = format_time(opt_time)
+                print_original_runtime = format_time(o_time)
                 print_filename = filename.relative_to(tests_root)
-                greater = (
-                    optimized_tests_to_runtimes[filename][qualified_name]
-                    > original_tests_to_runtimes[filename][qualified_name]
-                )
-                perf_gain = format_perf(
-                    performance_gain(
-                        original_runtime_ns=original_tests_to_runtimes[filename][qualified_name],
-                        optimized_runtime_ns=optimized_tests_to_runtimes[filename][qualified_name],
-                    )
-                    * 100
-                )
-                if greater:
-                    rows.append(
-                        [
-                            f"`{print_filename.as_posix()}::{qualified_name}`",
-                            f"{print_original_runtime}",
-                            f"{print_optimized_runtime}",
-                            f"⚠️{perf_gain}%",
-                        ]
-                    )
+                # Branch for emoji
+                perf_gain_val = performance_gain(original_runtime_ns=o_time, optimized_runtime_ns=opt_time) * 100
+                perf_gain_str = format_perf(perf_gain_val)
+                if opt_time > o_time:
+                    emoji = "⚠️"
                 else:
-                    rows.append(
-                        [
-                            f"`{print_filename.as_posix()}::{qualified_name}`",
-                            f"{print_original_runtime}",
-                            f"{print_optimized_runtime}",
-                            f"✅{perf_gain}%",
-                        ]
-                    )
-    output += tabulate(  # type: ignore[no-untyped-call]
+                    emoji = "✅"
+                rows.append(
+                    [
+                        f"`{print_filename.as_posix()}::{qn}`",
+                        print_original_runtime,
+                        print_optimized_runtime,
+                        f"{emoji}{perf_gain_str}%",
+                    ]
+                )
+
+    output += tabulate(
         headers=headers, tabular_data=rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
     )
     output += "\n"
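The rewritten loop in existing_tests_source_for boils down to one aggregation pattern: key runtimes by resolved test file and qualified test name, take the minimum of each invocation's samples, and cache the module-path-to-Path conversion. The following is a simplified, self-contained sketch with hypothetical names; the real code keys on InvocationId, keeps separate original/optimized tables, and multiplies the ratio from performance_gain by 100 for display, as seen at the call site above.

from __future__ import annotations

import os
from collections import defaultdict
from pathlib import Path


def aggregate_runtimes(runtimes_by_invocation: dict[tuple[str, str], list[int]]) -> dict[Path, dict[str, int]]:
    """Sum the minimum runtime (ns) of each invocation, grouped by test file and test name."""
    totals: dict[Path, dict[str, int]] = defaultdict(dict)
    path_cache: dict[str, Path] = {}
    for (module_path, test_name), samples_ns in runtimes_by_invocation.items():
        abs_path = path_cache.get(module_path)
        if abs_path is None:
            # Resolve "pkg.module" -> /abs/path/pkg/module.py only once per module.
            abs_path = Path(module_path.replace(".", os.sep)).with_suffix(".py").resolve()
            path_cache[module_path] = abs_path
        per_file = totals[abs_path]
        per_file[test_name] = per_file.get(test_name, 0) + min(samples_ns)
    return totals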