Skip to content

Replay tests and tracer improvments #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.version import __version__ as version

Expand Down Expand Up @@ -42,7 +43,7 @@ def parse_args() -> Namespace:
)
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from")
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
parser.add_argument(
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
)
Expand Down Expand Up @@ -83,25 +84,22 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
sys.exit()
if not check_running_in_git_repo(module_root=args.module_root):
if not confirm_proceeding_with_no_git_repo():
logger.critical("No git repository detected and user aborted run. Exiting...")
sys.exit(1)
exit_with_message("No git repository detected and user aborted run. Exiting...", error_on_exit=True)
args.no_pr = True
if args.function and not args.file:
logger.error("If you specify a --function, you must specify the --file it is in")
sys.exit(1)
exit_with_message("If you specify a --function, you must specify the --file it is in", error_on_exit=True)
if args.file:
if not Path(args.file).exists():
logger.error(f"File {args.file} does not exist")
sys.exit(1)
exit_with_message(f"File {args.file} does not exist", error_on_exit=True)
args.file = Path(args.file).resolve()
if not args.no_pr:
owner, repo = get_repo_owner_and_name()
require_github_app_or_exit(owner, repo)
if args.replay_test:
if not Path(args.replay_test).is_file():
logger.error(f"Replay test file {args.replay_test} does not exist")
sys.exit(1)
args.replay_test = Path(args.replay_test).resolve()
for test_path in args.replay_test:
if not Path(test_path).is_file():
exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True)
args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test]

return args

Expand All @@ -110,8 +108,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
try:
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
except ValueError as e:
logger.error(e)
sys.exit(1)
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
supported_keys = [
"module_root",
"tests_root",
Expand Down Expand Up @@ -206,8 +203,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
)
apologize_and_exit()
if not args.no_pr and not check_and_push_branch(git_repo):
logger.critical("❌ Branch is not pushed. Exiting...")
sys.exit(1)
exit_with_message("Branch is not pushed...", error_on_exit=True)
owner, repo = get_repo_owner_and_name(git_repo)
if not args.no_pr:
require_github_app_or_exit(owner, repo)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def custom_addopts() -> None:
# Backup original addopts
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
# nothing to do if no addopts present
if original_addopts != "":
if original_addopts != "" and isinstance(original_addopts, list):
original_addopts = [x.strip() for x in original_addopts]
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
Expand Down
8 changes: 4 additions & 4 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:

def get_functions_to_optimize(
optimize_all: str | None,
replay_test: str | None,
replay_test: list[Path] | None,
file: Path | None,
only_get_this_function: str | None,
test_cfg: TestConfig,
Expand All @@ -169,7 +169,7 @@ def get_functions_to_optimize(
logger.info("Finding all functions in the module '%s'…", optimize_all)
console.rule()
functions = get_all_files_and_functions(Path(optimize_all))
elif replay_test is not None:
elif replay_test:
functions = get_all_replay_test_functions(
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
)
Expand Down Expand Up @@ -271,9 +271,9 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt


def get_all_replay_test_functions(
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
) -> dict[Path, list[FunctionToOptimize]]:
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test)
# Get the absolute file paths for each function, excluding class name if present
filtered_valid_functions = defaultdict(list)
file_to_functions_map = defaultdict(list)
Expand Down
4 changes: 2 additions & 2 deletions codeflash/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
del arguments_copy["self"]
local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
except Exception:
# we retry with dill if pickle fails. It's slower but more comprehensive
try:
sys.setrecursionlimit(10000) # Ensure limit is high for dill too
Expand All @@ -390,7 +390,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
)
sys.setrecursionlimit(original_recursion_limit)

except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
except Exception:
self.function_count[function_qualified_name] -= 1
return

Expand Down
Loading