diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 81ab84d4..0c4bb252 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -10,6 +10,7 @@ from pydantic.json import pydantic_encoder from codeflash.cli_cmds.console import console, logger +from codeflash.code_utils.code_utils import get_installed_packages from codeflash.code_utils.env_utils import get_codeflash_api_key from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name from codeflash.models.models import OptimizedCandidate @@ -27,6 +28,7 @@ class AiServiceClient: def __init__(self) -> None: self.base_url = self.get_aiservice_base_url() self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"} + self.installed_packages = get_installed_packages() def get_aiservice_base_url(self) -> str: if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local": @@ -66,6 +68,9 @@ def make_ai_service_request( """ url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": + assert payload is not None, "Payload must be provided for POST requests" + if self.installed_packages: + payload["installed_packages"] = self.installed_packages json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 82a5b979..0c860298 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -10,6 +10,7 @@ from functools import lru_cache from pathlib import Path from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING import tomlkit @@ -19,8 +20,12 @@ ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE) +if TYPE_CHECKING: + from collections.abc import Generator + + @contextmanager -def custom_addopts() -> None: +def custom_addopts() -> Generator[None, None, None]: pyproject_file = find_pyproject_toml() original_content = None non_blacklist_plugin_args = "" @@ -58,7 +63,7 @@ def custom_addopts() -> None: @contextmanager -def add_addopts_to_pyproject() -> None: +def add_addopts_to_pyproject() -> Generator[None, None, None]: pyproject_file = find_pyproject_toml() original_content = None try: @@ -220,3 +225,41 @@ def exit_with_message(message: str, *, error_on_exit: bool = False) -> None: paneled_text(message, panel_args={"style": "red"}) sys.exit(1 if error_on_exit else 0) + + +blacklist_installed_pkgs = { + "codeflash", + "pytest", + "coverage", + "__", # this is for private packages or ones that contain "__" in order to mangle names i.e 3204bda914b7f2c6f497__mypyc + "setuptools", + "pip", + "wheel", + "importlib_metadata", + "importlib_resources", + "isort", + "black", + "tomlkit", + "stubs", +} + + +def get_installed_packages() -> list[str]: + try: + try: + import importlib.metadata as importlib_metadata + except ImportError: + import importlib_metadata + except ImportError: + return [] + + try: + pkgs = importlib_metadata.packages_distributions().keys() + except AttributeError: + pkgs = [dist.metadata.get("Name", "") for dist in importlib_metadata.distributions()] + + return [ + pkg + for pkg in pkgs + if pkg and not pkg.startswith("_") and not any(blacklisted in pkg for blacklisted in blacklist_installed_pkgs) + ] diff --git a/pyproject.toml b/pyproject.toml index 3a990197..e056740f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,7 +237,8 @@ ignore = [ "S301", "D104", "PERF203", - "LOG015" + "LOG015", + "PLC0415" ] [tool.ruff.lint.flake8-type-checking]