Skip to content

Commit 2e394f6

Browse files
tests and fix global assignments imports
1 parent 2acda6a commit 2e394f6

File tree

6 files changed

+73
-122
lines changed

6 files changed

+73
-122
lines changed
Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,25 @@
11
from os import getenv
2-
from typing import Optional
32

4-
from attrs import define, evolve, field
3+
from attrs import define, evolve
54

6-
from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL
5+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
76

87

98
@define
10-
class GalileoApiClient():
11-
"""A Client which has been authenticated for use on secured endpoints
12-
13-
The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
14-
15-
``base_url``: The base URL for the API, all requests are made to a relative path to this URL
16-
This can also be set via the GALILEO_CONSOLE_URL environment variable
17-
18-
``api_key``: The API key to be sent with every request
19-
This can also be set via the GALILEO_API_KEY environment variable
20-
21-
``cookies``: A dictionary of cookies to be sent with every request
22-
23-
``headers``: A dictionary of headers to be sent with every request
24-
25-
``timeout``: The maximum amount of a time a request can take. API functions will raise
26-
httpx.TimeoutException if this is exceeded.
27-
28-
``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
29-
but can be set to False for testing purposes.
30-
31-
``follow_redirects``: Whether or not to follow redirects. Default value is False.
32-
33-
``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
34-
35-
Attributes:
36-
raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
37-
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
38-
argument to the constructor.
39-
token: The token to use for authentication
40-
prefix: The prefix to use for the Authorization header
41-
auth_header_name: The name of the Authorization header
42-
"""
43-
44-
_base_url: Optional[str] = field(factory=lambda: GalileoApiClient.get_api_url(), kw_only=True, alias="base_url")
45-
_api_key: Optional[str] = field(factory=lambda: getenv("GALILEO_API_KEY", None), kw_only=True, alias="api_key")
46-
token: Optional[str] = None
47-
48-
api_key_header_name: str = "Galileo-API-Key"
9+
class ApiClient():
10+
api_key_header_name: str = "API-Key"
4911
client_type_header_name: str = "client-type"
5012
client_type_header_value: str = "sdk-python"
5113

5214
@staticmethod
5315
def get_console_url() -> str:
54-
console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
16+
console_url = getenv("CONSOLE_URL", DEFAULT_API_URL)
5517
if DEFAULT_API_URL == console_url:
5618
return DEFAULT_APP_URL
5719

5820
return console_url
5921

60-
def with_api_key(self, api_key: str) -> "GalileoApiClient":
22+
def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file.
6123
"""Get a new client matching this one with a new API key"""
62-
if self._client is not None:
63-
self._client.headers.update({self.api_key_header_name: api_key})
64-
if self._async_client is not None:
65-
self._async_client.headers.update({self.api_key_header_name: api_key})
6624
return evolve(self, api_key=api_key)
6725

68-
@staticmethod
69-
def get_api_url(base_url: Optional[str] = None) -> str:
70-
api_url = base_url or getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
71-
if api_url is None:
72-
raise ValueError("base_url or GALILEO_CONSOLE_URL must be set")
73-
if any(map(api_url.__contains__, ["localhost", "127.0.0.1"])):
74-
api_url = "http://localhost:8088"
75-
else:
76-
api_url = api_url.replace("app.galileo.ai", "api.galileo.ai").replace("console", "api")
77-
return api_url
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
import urllib.parse
4+
from os import getenv
5+
6+
from attrs import define
7+
from api_client import ApiClient
8+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
9+
10+
11+
@define
12+
class ApiClient():
13+
14+
@staticmethod
15+
def get_console_url() -> str:
16+
# Cache env lookup for speed
17+
console_url = getenv("CONSOLE_URL")
18+
if not console_url or console_url == DEFAULT_API_URL:
19+
return DEFAULT_APP_URL
20+
return console_url
21+
22+
# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly
23+
_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
24+
_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
25+
26+
def get_dest_url(url: str) -> str:
27+
destination = url if url else ApiClient.get_console_url()
28+
# Replace only if 'console.' is at the beginning to avoid partial matches
29+
if destination.startswith("console."):
30+
destination = "api." + destination[len("console."):]
31+
else:
32+
destination = destination.replace("console.", "api.", 1)
33+
34+
parsed_url = urllib.parse.urlparse(destination)
35+
if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC:
36+
return f"{DEFAULT_APP_URL}api/traces"
37+
return f"{parsed_url.scheme}://{parsed_url.netloc}/traces"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[tool.codeflash]
2+
# All paths are relative to this pyproject.toml's directory.
3+
module-root = "."
4+
tests-root = "tests"
5+
test-framework = "pytest"
6+
ignore-paths = []
7+
formatter-cmds = ["black $file"]

codeflash/code_utils/code_extractor.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -325,36 +325,30 @@ def add_needed_imports_from_module(
325325
)
326326
)
327327
cst.parse_module(src_module_code).visit(gatherer)
328-
scheduled_unused_imports = []
329328
try:
330329
for mod in gatherer.module_imports:
331330
AddImportsVisitor.add_needed_import(dst_context, mod)
332-
scheduled_unused_imports.append((mod, "", ""))
331+
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
333332
for mod, obj_seq in gatherer.object_mapping.items():
334-
logger.debug(f"dst_context.full_module_name: {dst_context.full_module_name}")
335-
logger.debug(f"mod: {mod}")
336-
logger.debug(f"obj_seq: {obj_seq}")
337-
logger.debug(f"helper_functions_fqn: {helper_functions_fqn}")
338333
for obj in obj_seq:
339334
if (
340-
f"{mod}.{obj}" in helper_functions_fqn
341-
or dst_context.full_module_name == mod # avoid circular imports
335+
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
342336
):
343337
continue # Skip adding imports for helper functions already in the context
344338
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
345-
scheduled_unused_imports.append((mod, obj, ""))
339+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
346340
except Exception as e:
347341
logger.exception(f"Error adding imports to destination module code: {e}")
348342
return dst_module_code
349343
for mod, asname in gatherer.module_aliases.items():
350344
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
351-
scheduled_unused_imports.append((mod, "", asname))
345+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
352346
for mod, alias_pairs in gatherer.alias_mapping.items():
353347
for alias_pair in alias_pairs:
354348
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
355349
continue
356350
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
357-
scheduled_unused_imports.append((mod, alias_pair[0], alias_pair[1]))
351+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
358352

359353
try:
360354
parsed_module = cst.parse_module(dst_module_code)
@@ -363,9 +357,6 @@ def add_needed_imports_from_module(
363357
return dst_module_code # Return the original code if there's a syntax error
364358
try:
365359
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
366-
for _import in scheduled_unused_imports:
367-
(_module, _obj, _alias) = _import
368-
RemoveImportsVisitor.remove_unused_import(dst_context, module=_module, obj=_obj, asname=_alias)
369360
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
370361
return transformed_module.code.lstrip("\n")
371362
except Exception as e:

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,6 @@ def replace_functions_and_add_imports(
397397
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
398398
project_root_path: Path,
399399
) -> str:
400-
logger.debug("start from here,...")
401-
logger.debug(f"source_code: {source_code}")
402-
logger.debug(f"function_names: {function_names}")
403-
logger.debug(f"optimized_code: {optimized_code}")
404-
logger.debug(f"module_abspath: {module_abspath}")
405-
logger.debug(f"preexisting_objects: {preexisting_objects}")
406-
logger.debug(f"project_root_path: {project_root_path}")
407400
return add_needed_imports_from_module(
408401
optimized_code,
409402
replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects),
@@ -422,12 +415,16 @@ def replace_function_definitions_in_module(
422415
) -> bool:
423416
source_code: str = module_abspath.read_text(encoding="utf8")
424417
new_code: str = replace_functions_and_add_imports(
425-
source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path
418+
add_global_assignments(optimized_code, source_code),
419+
function_names,
420+
optimized_code,
421+
module_abspath,
422+
preexisting_objects,
423+
project_root_path,
426424
)
427425
if is_zero_diff(source_code, new_code):
428426
return False
429-
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
430-
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
427+
module_abspath.write_text(new_code, encoding="utf8")
431428
return True
432429

433430

tests/test_code_context_extractor.py

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1212
from codeflash.models.models import FunctionParent
1313
from codeflash.optimization.optimizer import Optimizer
14+
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
15+
from codeflash.code_utils.code_extractor import add_global_assignments
1416

1517

1618
class HelperClass:
@@ -2436,51 +2438,20 @@ def simple_method(self):
24362438
assert "return 42" in code_content
24372439

24382440

2439-
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
2441+
24402442
def test_replace_functions_and_add_imports():
24412443
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
2442-
optimized_code = '''from __future__ import annotations
2443-
2444-
import urllib.parse
2445-
from os import getenv
2446-
2447-
from attrs import define
2448-
from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL
2449-
2450-
# Precompute constant netlocs for set membership test
2451-
_DEFAULT_APP_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
2452-
_DEFAULT_API_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
2453-
_NETLOC_SET = {_DEFAULT_APP_NETLOC, _DEFAULT_API_NETLOC}
2454-
2455-
@define
2456-
class GalileoApiClient():
2457-
2458-
@staticmethod
2459-
def get_console_url() -> str:
2460-
# Return DEFAULT_APP_URL if the env var is not set or set to DEFAULT_API_URL
2461-
console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
2462-
if console_url == DEFAULT_API_URL:
2463-
return DEFAULT_APP_URL
2464-
return console_url
2465-
2466-
def _set_destination(console_url: str) -> str:
2467-
"""
2468-
Parse the console_url and return the destination for the OpenTelemetry traces.
2469-
"""
2470-
destination = (console_url or GalileoApiClient.get_console_url()).replace("console.", "api.")
2471-
parsed_url = urllib.parse.urlparse(destination)
2472-
if parsed_url.netloc in _NETLOC_SET:
2473-
return f"{DEFAULT_APP_URL}api/galileo/otel/traces"
2474-
return f"{parsed_url.scheme}://{parsed_url.netloc}/otel/traces"'''
24752444
file_abs_path = path_to_root / "api_client.py"
2445+
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
24762446
content = Path(file_abs_path).read_text(encoding="utf-8")
24772447
new_code = replace_functions_and_add_imports(
2478-
source_code= content,
2479-
function_names= ["GalileoApiClient.get_console_url"],
2448+
source_code= add_global_assignments(optimized_code, content),
2449+
function_names= ["ApiClient.get_console_url"],
24802450
optimized_code= optimized_code,
2481-
module_abspath= file_abs_path,
2482-
preexisting_objects= {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))},
2451+
module_abspath= Path(file_abs_path),
2452+
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
24832453
project_root_path= Path(path_to_root),
24842454
)
2485-
print(new_code)
2486-
assert 1 == 1
2455+
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
2456+
2457+
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"

0 commit comments

Comments
 (0)