Skip to content

Commit 60922b8

Browse files
committed
allow args for the optimize command too
1 parent c09f32e commit 60922b8

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,34 @@ def parse_args() -> Namespace:
2424
init_actions_parser.set_defaults(func=install_github_actions)
2525

2626
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize a Python project.")
27+
2728
from codeflash.tracer import main as tracer_main
2829

2930
trace_optimize.set_defaults(func=tracer_main)
3031

32+
trace_optimize.add_argument(
33+
"--max-function-count",
34+
type=int,
35+
default=100,
36+
help="The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.",
37+
)
38+
trace_optimize.add_argument(
39+
"--timeout",
40+
type=int,
41+
help="The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows, to not wait indefinitely.",
42+
)
43+
trace_optimize.add_argument(
44+
"--output",
45+
type=str,
46+
default="codeflash.trace",
47+
help="The file to save the trace to. Default is codeflash.trace.",
48+
)
49+
trace_optimize.add_argument(
50+
"--config-file-path",
51+
type=str,
52+
help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
53+
)
54+
3155
parser.add_argument("--file", help="Try to optimize only this file")
3256
parser.add_argument("--function", help="Try to optimize only this function within the given file path")
3357
parser.add_argument(

codeflash/tracer.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from codeflash.verification.verification_utils import get_test_file_path
4747

4848
if TYPE_CHECKING:
49+
from argparse import Namespace
4950
from types import FrameType, TracebackType
5051

5152

@@ -798,7 +799,7 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An
798799
return self
799800

800801

801-
def main() -> ArgumentParser:
802+
def main(args: Namespace | None = None) -> ArgumentParser:
802803
parser = ArgumentParser(allow_abbrev=False)
803804
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.trace")
804805
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
@@ -824,18 +825,34 @@ def main() -> ArgumentParser:
824825
)
825826
parser.add_argument("--trace-only", action="store_true", help="Trace and create replay tests only, don't optimize")
826827

827-
if not sys.argv[1:]:
828-
parser.print_usage()
829-
sys.exit(2)
828+
if args is not None:
829+
parsed_args = args
830+
parsed_args.outfile = getattr(args, "output", "codeflash.trace")
831+
parsed_args.only_functions = getattr(args, "only_functions", None)
832+
parsed_args.max_function_count = getattr(args, "max_function_count", 100)
833+
parsed_args.tracer_timeout = getattr(args, "timeout", None)
834+
parsed_args.codeflash_config = getattr(args, "config_file_path", None)
835+
parsed_args.trace_only = getattr(args, "trace_only", False)
836+
parsed_args.module = False
837+
838+
if getattr(args, "disable", False):
839+
console.rule("Codeflash: Tracer disabled by --disable option", style="bold red")
840+
return parser
841+
842+
unknown_args = []
843+
else:
844+
if not sys.argv[1:]:
845+
parser.print_usage()
846+
sys.exit(2)
830847

831-
args, unknown_args = parser.parse_known_args()
832-
sys.argv[:] = unknown_args
848+
parsed_args, unknown_args = parser.parse_known_args()
849+
sys.argv[:] = unknown_args
833850

834851
# The script that we're profiling may chdir, so capture the absolute path
835852
# to the output file at startup.
836-
if args.outfile is not None:
837-
args.outfile = Path(args.outfile).resolve()
838-
outfile = args.outfile
853+
if parsed_args.outfile is not None:
854+
parsed_args.outfile = Path(parsed_args.outfile).resolve()
855+
outfile = parsed_args.outfile
839856

840857
if len(unknown_args) > 0:
841858
if args.module:

0 commit comments

Comments
 (0)