Skip to content

[FIX] Circular dependency and global assignments imports #448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions code_to_optimize/code_directories/circular_deps/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from os import getenv

from attrs import define, evolve

from constants import DEFAULT_API_URL, DEFAULT_APP_URL


@define
class ApiClient():
api_key_header_name: str = "API-Key"
client_type_header_name: str = "client-type"
client_type_header_value: str = "sdk-python"

@staticmethod
def get_console_url() -> str:
console_url = getenv("CONSOLE_URL", DEFAULT_API_URL)
if DEFAULT_API_URL == console_url:
return DEFAULT_APP_URL

return console_url

def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file.
"""Get a new client matching this one with a new API key"""
return evolve(self, api_key=api_key)

8 changes: 8 additions & 0 deletions code_to_optimize/code_directories/circular_deps/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
DEFAULT_API_URL = "https://api.galileo.ai/"
DEFAULT_APP_URL = "https://app.galileo.ai/"


# function_names: GalileoApiClient.get_console_url
# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py
# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}
# project_root_path: /home/mohammed/Work/galileo-python/src
37 changes: 37 additions & 0 deletions code_to_optimize/code_directories/circular_deps/optimized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import urllib.parse
from os import getenv

from attrs import define
from api_client import ApiClient
from constants import DEFAULT_API_URL, DEFAULT_APP_URL


@define
class ApiClient():

@staticmethod
def get_console_url() -> str:
# Cache env lookup for speed
console_url = getenv("CONSOLE_URL")
if not console_url or console_url == DEFAULT_API_URL:
return DEFAULT_APP_URL
return console_url

# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly
_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc

def get_dest_url(url: str) -> str:
destination = url if url else ApiClient.get_console_url()
# Replace only if 'console.' is at the beginning to avoid partial matches
if destination.startswith("console."):
destination = "api." + destination[len("console."):]
else:
destination = destination.replace("console.", "api.", 1)

parsed_url = urllib.parse.urlparse(destination)
if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC:
return f"{DEFAULT_APP_URL}api/traces"
return f"{parsed_url.scheme}://{parsed_url.netloc}/traces"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[tool.codeflash]
# All paths are relative to this pyproject.toml's directory.
module-root = "."
tests-root = "tests"
test-framework = "pytest"
ignore-paths = []
formatter-cmds = ["black $file"]
4 changes: 3 additions & 1 deletion codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ def add_needed_imports_from_module(
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
for mod, obj_seq in gatherer.object_mapping.items():
for obj in obj_seq:
if f"{mod}.{obj}" in helper_functions_fqn:
if (
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
Expand Down
10 changes: 7 additions & 3 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,16 @@ def replace_function_definitions_in_module(
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
new_code: str = replace_functions_and_add_imports(
source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path
add_global_assignments(optimized_code, source_code),
function_names,
optimized_code,
module_abspath,
preexisting_objects,
project_root_path,
)
if is_zero_diff(source_code, new_code):
return False
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
module_abspath.write_text(new_code, encoding="utf8")
return True


Expand Down
21 changes: 21 additions & 0 deletions tests/test_code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_extractor import add_global_assignments


class HelperClass:
Expand Down Expand Up @@ -2434,3 +2436,22 @@ def simple_method(self):
assert "class SimpleClass:" in code_content
assert "def simple_method(self):" in code_content
assert "return 42" in code_content



def test_replace_functions_and_add_imports():
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
file_abs_path = path_to_root / "api_client.py"
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
content = Path(file_abs_path).read_text(encoding="utf-8")
new_code = replace_functions_and_add_imports(
source_code= add_global_assignments(optimized_code, content),
function_names= ["ApiClient.get_console_url"],
optimized_code= optimized_code,
module_abspath= Path(file_abs_path),
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
project_root_path= Path(path_to_root),
)
assert "import ApiClient" not in new_code, "Error: Circular dependency found"

assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
15 changes: 7 additions & 8 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,8 +1693,8 @@ def new_function2(value):
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")

print("Hello world")
a=2
print("Hello world")
def some_fn():
Expand All @@ -1712,8 +1712,7 @@ def __init__(self, name):
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
"""
return cst.ensure_type(value, str)"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
Expand Down Expand Up @@ -1769,8 +1768,8 @@ def new_function2(value):
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")

print("Hello world")
print("Hello world")
def some_fn():
a=np.zeros(10)
Expand Down Expand Up @@ -1846,8 +1845,8 @@ def new_function2(value):
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")

print("Hello world")
a=3
print("Hello world")
def some_fn():
Expand Down Expand Up @@ -1922,8 +1921,8 @@ def new_function2(value):
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")

print("Hello world")
a=2
print("Hello world")
def some_fn():
Expand Down Expand Up @@ -1999,8 +1998,8 @@ def new_function2(value):
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")

print("Hello world")
a=3
print("Hello world")
def some_fn():
Expand Down Expand Up @@ -2082,8 +2081,8 @@ def new_function2(value):
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")

print("Hello world")
if 2<3:
a=4
else:
Expand Down
Loading