Skip to content

add utility to get installed libraries #373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic.json import pydantic_encoder

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import get_installed_packages
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
from codeflash.models.models import OptimizedCandidate
Expand All @@ -27,6 +28,7 @@ class AiServiceClient:
def __init__(self) -> None:
self.base_url = self.get_aiservice_base_url()
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
self.installed_packages = get_installed_packages()

def get_aiservice_base_url(self) -> str:
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
Expand Down Expand Up @@ -66,6 +68,9 @@ def make_ai_service_request(
"""
url = f"{self.base_url}/ai{endpoint}"
if method.upper() == "POST":
assert payload is not None, "Payload must be provided for POST requests"
if self.installed_packages:
payload["installed_packages"] = self.installed_packages
json_payload = json.dumps(payload, indent=None, default=pydantic_encoder)
headers = {**self.headers, "Content-Type": "application/json"}
response = requests.post(url, data=json_payload, headers=headers, timeout=timeout)
Expand Down
47 changes: 45 additions & 2 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING

import tomlkit

Expand All @@ -19,8 +20,12 @@
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)


if TYPE_CHECKING:
from collections.abc import Generator


@contextmanager
def custom_addopts() -> None:
def custom_addopts() -> Generator[None, None, None]:
pyproject_file = find_pyproject_toml()
original_content = None
non_blacklist_plugin_args = ""
Expand Down Expand Up @@ -58,7 +63,7 @@ def custom_addopts() -> None:


@contextmanager
def add_addopts_to_pyproject() -> None:
def add_addopts_to_pyproject() -> Generator[None, None, None]:
pyproject_file = find_pyproject_toml()
original_content = None
try:
Expand Down Expand Up @@ -220,3 +225,41 @@ def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
paneled_text(message, panel_args={"style": "red"})

sys.exit(1 if error_on_exit else 0)


blacklist_installed_pkgs = {
"codeflash",
"pytest",
"coverage",
"__", # this is for private packages or ones that contain "__" in order to mangle names i.e 3204bda914b7f2c6f497__mypyc
"setuptools",
"pip",
"wheel",
"importlib_metadata",
"importlib_resources",
"isort",
"black",
"tomlkit",
"stubs",
}


def get_installed_packages() -> list[str]:
try:
try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata
except ImportError:
return []

try:
pkgs = importlib_metadata.packages_distributions().keys()
except AttributeError:
pkgs = [dist.metadata.get("Name", "") for dist in importlib_metadata.distributions()]

return [
pkg
for pkg in pkgs
if pkg and not pkg.startswith("_") and not any(blacklisted in pkg for blacklisted in blacklist_installed_pkgs)
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ ignore = [
"S301",
"D104",
"PERF203",
"LOG015"
"LOG015",
"PLC0415"
]

[tool.ruff.lint.flake8-type-checking]
Expand Down
Loading