Skip to content

Commit 2b5fa6e

Browse files
Merge pull request #448 from codeflash-ai/fix/cicular-imports-in-helper-context
[FIX] Circular dependency and global assignments imports
2 parents 132b92d + 0b18bca commit 2b5fa6e

File tree

8 files changed

+115
-12
lines changed

8 files changed

+115
-12
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from os import getenv
2+
3+
from attrs import define, evolve
4+
5+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
6+
7+
8+
@define
9+
class ApiClient():
10+
api_key_header_name: str = "API-Key"
11+
client_type_header_name: str = "client-type"
12+
client_type_header_value: str = "sdk-python"
13+
14+
@staticmethod
15+
def get_console_url() -> str:
16+
console_url = getenv("CONSOLE_URL", DEFAULT_API_URL)
17+
if DEFAULT_API_URL == console_url:
18+
return DEFAULT_APP_URL
19+
20+
return console_url
21+
22+
def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file.
23+
"""Get a new client matching this one with a new API key"""
24+
return evolve(self, api_key=api_key)
25+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DEFAULT_API_URL = "https://api.galileo.ai/"
2+
DEFAULT_APP_URL = "https://app.galileo.ai/"
3+
4+
5+
# function_names: GalileoApiClient.get_console_url
6+
# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py
7+
# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}
8+
# project_root_path: /home/mohammed/Work/galileo-python/src
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
import urllib.parse
4+
from os import getenv
5+
6+
from attrs import define
7+
from api_client import ApiClient
8+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
9+
10+
11+
@define
12+
class ApiClient():
13+
14+
@staticmethod
15+
def get_console_url() -> str:
16+
# Cache env lookup for speed
17+
console_url = getenv("CONSOLE_URL")
18+
if not console_url or console_url == DEFAULT_API_URL:
19+
return DEFAULT_APP_URL
20+
return console_url
21+
22+
# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly
23+
_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
24+
_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
25+
26+
def get_dest_url(url: str) -> str:
27+
destination = url if url else ApiClient.get_console_url()
28+
# Replace only if 'console.' is at the beginning to avoid partial matches
29+
if destination.startswith("console."):
30+
destination = "api." + destination[len("console."):]
31+
else:
32+
destination = destination.replace("console.", "api.", 1)
33+
34+
parsed_url = urllib.parse.urlparse(destination)
35+
if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC:
36+
return f"{DEFAULT_APP_URL}api/traces"
37+
return f"{parsed_url.scheme}://{parsed_url.netloc}/traces"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[tool.codeflash]
2+
# All paths are relative to this pyproject.toml's directory.
3+
module-root = "."
4+
tests-root = "tests"
5+
test-framework = "pytest"
6+
ignore-paths = []
7+
formatter-cmds = ["black $file"]

codeflash/code_utils/code_extractor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def add_needed_imports_from_module(
331331
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
332332
for mod, obj_seq in gatherer.object_mapping.items():
333333
for obj in obj_seq:
334-
if f"{mod}.{obj}" in helper_functions_fqn:
334+
if (
335+
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
336+
):
335337
continue # Skip adding imports for helper functions already in the context
336338
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
337339
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,16 @@ def replace_function_definitions_in_module(
415415
) -> bool:
416416
source_code: str = module_abspath.read_text(encoding="utf8")
417417
new_code: str = replace_functions_and_add_imports(
418-
source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path
418+
add_global_assignments(optimized_code, source_code),
419+
function_names,
420+
optimized_code,
421+
module_abspath,
422+
preexisting_objects,
423+
project_root_path,
419424
)
420425
if is_zero_diff(source_code, new_code):
421426
return False
422-
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
423-
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
427+
module_abspath.write_text(new_code, encoding="utf8")
424428
return True
425429

426430

tests/test_code_context_extractor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1212
from codeflash.models.models import FunctionParent
1313
from codeflash.optimization.optimizer import Optimizer
14+
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
15+
from codeflash.code_utils.code_extractor import add_global_assignments
1416

1517

1618
class HelperClass:
@@ -2434,3 +2436,22 @@ def simple_method(self):
24342436
assert "class SimpleClass:" in code_content
24352437
assert "def simple_method(self):" in code_content
24362438
assert "return 42" in code_content
2439+
2440+
2441+
2442+
def test_replace_functions_and_add_imports():
2443+
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
2444+
file_abs_path = path_to_root / "api_client.py"
2445+
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
2446+
content = Path(file_abs_path).read_text(encoding="utf-8")
2447+
new_code = replace_functions_and_add_imports(
2448+
source_code= add_global_assignments(optimized_code, content),
2449+
function_names= ["ApiClient.get_console_url"],
2450+
optimized_code= optimized_code,
2451+
module_abspath= Path(file_abs_path),
2452+
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
2453+
project_root_path= Path(path_to_root),
2454+
)
2455+
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
2456+
2457+
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"

tests/test_code_replacement.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,8 +1693,8 @@ def new_function2(value):
16931693
print("Hello world")
16941694
"""
16951695
expected_code = """import numpy as np
1696-
print("Hello world")
16971696
1697+
print("Hello world")
16981698
a=2
16991699
print("Hello world")
17001700
def some_fn():
@@ -1712,8 +1712,7 @@ def __init__(self, name):
17121712
def __call__(self, value):
17131713
return "I am still old"
17141714
def new_function2(value):
1715-
return cst.ensure_type(value, str)
1716-
"""
1715+
return cst.ensure_type(value, str)"""
17171716
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
17181717
code_path.write_text(original_code, encoding="utf-8")
17191718
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
@@ -1769,8 +1768,8 @@ def new_function2(value):
17691768
print("Hello world")
17701769
"""
17711770
expected_code = """import numpy as np
1772-
print("Hello world")
17731771
1772+
print("Hello world")
17741773
print("Hello world")
17751774
def some_fn():
17761775
a=np.zeros(10)
@@ -1846,8 +1845,8 @@ def new_function2(value):
18461845
print("Hello world")
18471846
"""
18481847
expected_code = """import numpy as np
1849-
print("Hello world")
18501848
1849+
print("Hello world")
18511850
a=3
18521851
print("Hello world")
18531852
def some_fn():
@@ -1922,8 +1921,8 @@ def new_function2(value):
19221921
print("Hello world")
19231922
"""
19241923
expected_code = """import numpy as np
1925-
print("Hello world")
19261924
1925+
print("Hello world")
19271926
a=2
19281927
print("Hello world")
19291928
def some_fn():
@@ -1999,8 +1998,8 @@ def new_function2(value):
19991998
print("Hello world")
20001999
"""
20012000
expected_code = """import numpy as np
2002-
print("Hello world")
20032001
2002+
print("Hello world")
20042003
a=3
20052004
print("Hello world")
20062005
def some_fn():
@@ -2082,8 +2081,8 @@ def new_function2(value):
20822081
print("Hello world")
20832082
"""
20842083
expected_code = """import numpy as np
2085-
print("Hello world")
20862084
2085+
print("Hello world")
20872086
if 2<3:
20882087
a=4
20892088
else:

0 commit comments

Comments
 (0)