diff --git a/code_to_optimize/code_directories/circular_deps/api_client.py b/code_to_optimize/code_directories/circular_deps/api_client.py new file mode 100644 index 00000000..bc93193d --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/api_client.py @@ -0,0 +1,25 @@ +from os import getenv + +from attrs import define, evolve + +from constants import DEFAULT_API_URL, DEFAULT_APP_URL + + +@define +class ApiClient(): + api_key_header_name: str = "API-Key" + client_type_header_name: str = "client-type" + client_type_header_value: str = "sdk-python" + + @staticmethod + def get_console_url() -> str: + console_url = getenv("CONSOLE_URL", DEFAULT_API_URL) + if DEFAULT_API_URL == console_url: + return DEFAULT_APP_URL + + return console_url + + def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file. + """Get a new client matching this one with a new API key""" + return evolve(self, api_key=api_key) + diff --git a/code_to_optimize/code_directories/circular_deps/constants.py b/code_to_optimize/code_directories/circular_deps/constants.py new file mode 100644 index 00000000..dc4b0638 --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/constants.py @@ -0,0 +1,8 @@ +DEFAULT_API_URL = "https://api.galileo.ai/" +DEFAULT_APP_URL = "https://app.galileo.ai/" + + +# function_names: GalileoApiClient.get_console_url +# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py +# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))} +# project_root_path: /home/mohammed/Work/galileo-python/src diff --git a/code_to_optimize/code_directories/circular_deps/optimized.py b/code_to_optimize/code_directories/circular_deps/optimized.py new file mode 100644 index 00000000..2fa5d9bd --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/optimized.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import urllib.parse +from os import getenv + +from attrs import define +from api_client import ApiClient +from constants import DEFAULT_API_URL, DEFAULT_APP_URL + + +@define +class ApiClient(): + + @staticmethod + def get_console_url() -> str: + # Cache env lookup for speed + console_url = getenv("CONSOLE_URL") + if not console_url or console_url == DEFAULT_API_URL: + return DEFAULT_APP_URL + return console_url + +# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly +_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc +_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc + +def get_dest_url(url: str) -> str: + destination = url if url else ApiClient.get_console_url() + # Replace only if 'console.' is at the beginning to avoid partial matches + if destination.startswith("console."): + destination = "api." + destination[len("console."):] + else: + destination = destination.replace("console.", "api.", 1) + + parsed_url = urllib.parse.urlparse(destination) + if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC: + return f"{DEFAULT_APP_URL}api/traces" + return f"{parsed_url.scheme}://{parsed_url.netloc}/traces" \ No newline at end of file diff --git a/code_to_optimize/code_directories/circular_deps/pyproject.toml b/code_to_optimize/code_directories/circular_deps/pyproject.toml new file mode 100644 index 00000000..bddef0ed --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/pyproject.toml @@ -0,0 +1,7 @@ +[tool.codeflash] +# All paths are relative to this pyproject.toml's directory. +module-root = "." +tests-root = "tests" +test-framework = "pytest" +ignore-paths = [] +formatter-cmds = ["black $file"] diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 0dcc2357..73a1c326 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -331,7 +331,9 @@ def add_needed_imports_from_module( RemoveImportsVisitor.remove_unused_import(dst_context, mod) for mod, obj_seq in gatherer.object_mapping.items(): for obj in obj_seq: - if f"{mod}.{obj}" in helper_functions_fqn: + if ( + f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps + ): continue # Skip adding imports for helper functions already in the context AddImportsVisitor.add_needed_import(dst_context, mod, obj) RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index f0964aae..3c73c591 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -415,12 +415,16 @@ def replace_function_definitions_in_module( ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") new_code: str = replace_functions_and_add_imports( - source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path + add_global_assignments(optimized_code, source_code), + function_names, + optimized_code, + module_abspath, + preexisting_objects, + project_root_path, ) if is_zero_diff(source_code, new_code): return False - code_with_global_assignments = add_global_assignments(optimized_code, new_code) - module_abspath.write_text(code_with_global_assignments, encoding="utf8") + module_abspath.write_text(new_code, encoding="utf8") return True diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 010d3bc6..25200cb9 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -11,6 +11,8 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent from codeflash.optimization.optimizer import Optimizer +from codeflash.code_utils.code_replacer import replace_functions_and_add_imports +from codeflash.code_utils.code_extractor import add_global_assignments class HelperClass: @@ -2434,3 +2436,22 @@ def simple_method(self): assert "class SimpleClass:" in code_content assert "def simple_method(self):" in code_content assert "return 42" in code_content + + + +def test_replace_functions_and_add_imports(): + path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps" + file_abs_path = path_to_root / "api_client.py" + optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8") + content = Path(file_abs_path).read_text(encoding="utf-8") + new_code = replace_functions_and_add_imports( + source_code= add_global_assignments(optimized_code, content), + function_names= ["ApiClient.get_console_url"], + optimized_code= optimized_code, + module_abspath= Path(file_abs_path), + preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))}, + project_root_path= Path(path_to_root), + ) + assert "import ApiClient" not in new_code, "Error: Circular dependency found" + + assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 363dbaee..7272163d 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1693,8 +1693,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=2 print("Hello world") def some_fn(): @@ -1712,8 +1712,7 @@ def __init__(self, name): def __call__(self, value): return "I am still old" def new_function2(value): - return cst.ensure_type(value, str) -""" + return cst.ensure_type(value, str)""" code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() code_path.write_text(original_code, encoding="utf-8") tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") @@ -1769,8 +1768,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") print("Hello world") def some_fn(): a=np.zeros(10) @@ -1846,8 +1845,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=3 print("Hello world") def some_fn(): @@ -1922,8 +1921,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=2 print("Hello world") def some_fn(): @@ -1999,8 +1998,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=3 print("Hello world") def some_fn(): @@ -2082,8 +2081,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") if 2<3: a=4 else: