Skip to content

Commit 4debe7e

Browse files
committed
introduce a new integrated "codeflash optimize" command
1 parent 73f6fa0 commit 4debe7e

File tree

3 files changed

+53
-11
lines changed

3 files changed

+53
-11
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ def parse_args() -> Namespace:
2222

2323
init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
2424
init_actions_parser.set_defaults(func=install_github_actions)
25+
26+
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize a Python project.")
27+
from codeflash.tracer import main as tracer_main
28+
29+
trace_optimize.set_defaults(func=tracer_main)
30+
2531
parser.add_argument("--file", help="Try to optimize only this file")
2632
parser.add_argument("--function", help="Try to optimize only this function within the given file path")
2733
parser.add_argument(
@@ -64,7 +70,8 @@ def parse_args() -> Namespace:
6470
)
6571
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
6672

67-
args: Namespace = parser.parse_args()
73+
args, unknown_args = parser.parse_known_args()
74+
sys.argv[:] = [sys.argv[0], *unknown_args]
6875
return process_and_validate_cmd_args(args)
6976

7077

codeflash/picklepatch/pickle_patcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ def _create_placeholder(obj: object, error_msg: str, path: list[str]) -> PickleP
7979
except: # noqa: E722
8080
obj_str = f"<unprintable object of type {obj_type.__name__}>"
8181

82-
print(f"Creating placeholder for {obj_type.__name__} at path {'->'.join(path) or 'root'}: {error_msg}")
83-
8482
placeholder = PicklePlaceholder(obj_type.__name__, obj_str, error_msg, path)
8583

8684
# Add this type to our known unpicklable types cache

codeflash/tracer.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def __init__(
123123
self.function_count = defaultdict(int)
124124
self.current_file_path = Path(__file__).resolve()
125125
self.ignored_qualified_functions = {
126-
f"{self.current_file_path}:Tracer:__exit__",
127-
f"{self.current_file_path}:Tracer:__enter__",
126+
f"{self.current_file_path}:Tracer.__exit__",
127+
f"{self.current_file_path}:Tracer.__enter__",
128128
}
129129
self.max_function_count = max_function_count
130130
self.config, found_config_path = parse_config_file(config_file_path)
@@ -133,6 +133,7 @@ def __init__(
133133
self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
134134

135135
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001
136+
self.replay_test_file_path: Path | None = None
136137

137138
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
138139
self.timeout = timeout
@@ -283,6 +284,7 @@ def __exit__(
283284

284285
with Path(test_file_path).open("w", encoding="utf8") as file:
285286
file.write(replay_test)
287+
self.replay_test_file_path = test_file_path
286288

287289
console.print(
288290
f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}",
@@ -347,8 +349,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
347349
try:
348350
function_qualified_name = f"{file_name}:{code.co_qualname}"
349351
except AttributeError:
350-
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
351-
352+
function_qualified_name = f"{file_name}:{(class_name + '.' if class_name else '')}{code.co_name}"
352353
if function_qualified_name in self.ignored_qualified_functions:
353354
return
354355
if function_qualified_name not in self.function_count:
@@ -701,7 +702,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None:
701702
border_style="blue",
702703
title="[bold]Function Profile[/bold] (ordered by internal time)",
703704
title_style="cyan",
704-
caption=f"Showing top 25 of {len(self.stats)} functions",
705+
caption=f"Showing top {min(25, len(self.stats))} of {len(self.stats)} functions",
705706
)
706707

707708
table.add_column("Calls", justify="right", style="green", width=10)
@@ -793,7 +794,7 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An
793794

794795
def main() -> ArgumentParser:
795796
parser = ArgumentParser(allow_abbrev=False)
796-
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", required=True)
797+
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.trace")
797798
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
798799
parser.add_argument(
799800
"--max-function-count",
@@ -815,6 +816,7 @@ def main() -> ArgumentParser:
815816
"with the codeflash config. Will be auto-discovered if not specified.",
816817
default=None,
817818
)
819+
parser.add_argument("--trace-only", action="store_true", help="Trace and create replay tests only, don't optimize")
818820

819821
if not sys.argv[1:]:
820822
parser.print_usage()
@@ -827,6 +829,7 @@ def main() -> ArgumentParser:
827829
# to the output file at startup.
828830
if args.outfile is not None:
829831
args.outfile = Path(args.outfile).resolve()
832+
outfile = args.outfile
830833

831834
if len(unknown_args) > 0:
832835
if args.module:
@@ -848,14 +851,48 @@ def main() -> ArgumentParser:
848851
"__cached__": None,
849852
}
850853
try:
851-
Tracer(
854+
tracer = Tracer(
852855
output=args.outfile,
853856
functions=args.only_functions,
854857
max_function_count=args.max_function_count,
855858
timeout=args.tracer_timeout,
856859
config_file_path=args.codeflash_config,
857860
command=" ".join(sys.argv),
858-
).runctx(code, globs, None)
861+
)
862+
tracer.runctx(code, globs, None)
863+
replay_test_path = tracer.replay_test_file_path
864+
if not args.trace_only and replay_test_path is not None:
865+
del tracer
866+
867+
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
868+
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
869+
from codeflash.cli_cmds.console import paneled_text
870+
from codeflash.telemetry import posthog_cf
871+
from codeflash.telemetry.sentry import init_sentry
872+
873+
sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
874+
875+
args = parse_args()
876+
paneled_text(
877+
CODEFLASH_LOGO,
878+
panel_args={"title": "https://codeflash.ai", "expand": False},
879+
text_args={"style": "bold gold3"},
880+
)
881+
882+
args = process_pyproject_config(args)
883+
args.previous_checkpoint_functions = None
884+
init_sentry(not args.disable_telemetry, exclude_errors=True)
885+
posthog_cf.initialize_posthog(not args.disable_telemetry)
886+
887+
from codeflash.optimization import optimizer
888+
889+
optimizer.run_with_args(args)
890+
891+
# Delete the trace file and the replay test file if they exist
892+
if outfile:
893+
outfile.unlink(missing_ok=True)
894+
if replay_test_path:
895+
replay_test_path.unlink(missing_ok=True)
859896

860897
except BrokenPipeError as exc:
861898
# Prevent "Exception ignored" during interpreter shutdown.

0 commit comments

Comments
 (0)