Skip to content

Commit 059b4dc

Browse files
authored
Merge branch 'main' into trace-and-optimize
2 parents 0b4fcb6 + b759192 commit 059b4dc

File tree

10 files changed

+209
-24
lines changed

10 files changed

+209
-24
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from os import getenv
2+
3+
from attrs import define, evolve
4+
5+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
6+
7+
8+
@define
9+
class ApiClient():
10+
api_key_header_name: str = "API-Key"
11+
client_type_header_name: str = "client-type"
12+
client_type_header_value: str = "sdk-python"
13+
14+
@staticmethod
15+
def get_console_url() -> str:
16+
console_url = getenv("CONSOLE_URL", DEFAULT_API_URL)
17+
if DEFAULT_API_URL == console_url:
18+
return DEFAULT_APP_URL
19+
20+
return console_url
21+
22+
def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file.
23+
"""Get a new client matching this one with a new API key"""
24+
return evolve(self, api_key=api_key)
25+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DEFAULT_API_URL = "https://api.galileo.ai/"
2+
DEFAULT_APP_URL = "https://app.galileo.ai/"
3+
4+
5+
# function_names: GalileoApiClient.get_console_url
6+
# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py
7+
# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}
8+
# project_root_path: /home/mohammed/Work/galileo-python/src
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
import urllib.parse
4+
from os import getenv
5+
6+
from attrs import define
7+
from api_client import ApiClient
8+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
9+
10+
11+
@define
12+
class ApiClient():
13+
14+
@staticmethod
15+
def get_console_url() -> str:
16+
# Cache env lookup for speed
17+
console_url = getenv("CONSOLE_URL")
18+
if not console_url or console_url == DEFAULT_API_URL:
19+
return DEFAULT_APP_URL
20+
return console_url
21+
22+
# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly
23+
_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
24+
_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
25+
26+
def get_dest_url(url: str) -> str:
27+
destination = url if url else ApiClient.get_console_url()
28+
# Replace only if 'console.' is at the beginning to avoid partial matches
29+
if destination.startswith("console."):
30+
destination = "api." + destination[len("console."):]
31+
else:
32+
destination = destination.replace("console.", "api.", 1)
33+
34+
parsed_url = urllib.parse.urlparse(destination)
35+
if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC:
36+
return f"{DEFAULT_APP_URL}api/traces"
37+
return f"{parsed_url.scheme}://{parsed_url.netloc}/traces"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[tool.codeflash]
2+
# All paths are relative to this pyproject.toml's directory.
3+
module-root = "."
4+
tests-root = "tests"
5+
test-framework = "pytest"
6+
ignore-paths = []
7+
formatter-cmds = ["black $file"]

codeflash/code_utils/code_extractor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def add_needed_imports_from_module(
331331
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
332332
for mod, obj_seq in gatherer.object_mapping.items():
333333
for obj in obj_seq:
334-
if f"{mod}.{obj}" in helper_functions_fqn:
334+
if (
335+
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
336+
):
335337
continue # Skip adding imports for helper functions already in the context
336338
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
337339
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,16 @@ def replace_function_definitions_in_module(
415415
) -> bool:
416416
source_code: str = module_abspath.read_text(encoding="utf8")
417417
new_code: str = replace_functions_and_add_imports(
418-
source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path
418+
add_global_assignments(optimized_code, source_code),
419+
function_names,
420+
optimized_code,
421+
module_abspath,
422+
preexisting_objects,
423+
project_root_path,
419424
)
420425
if is_zero_diff(source_code, new_code):
421426
return False
422-
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
423-
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
427+
module_abspath.write_text(new_code, encoding="utf8")
424428
return True
425429

426430

codeflash/lsp/beta.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,27 +36,71 @@ def get_optimizable_functions(
3636
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
3737
) -> dict[str, list[str]]:
3838
file_path = Path(uris.to_fs_path(params.textDocument.uri))
39-
server.optimizer.args.file = file_path
40-
server.optimizer.args.previous_checkpoint_functions = False
41-
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
42-
path_to_qualified_names = {}
43-
for path, functions in optimizable_funcs.items():
44-
path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]
45-
return path_to_qualified_names
39+
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
40+
41+
# Save original args to restore later
42+
original_file = getattr(server.optimizer.args, "file", None)
43+
original_function = getattr(server.optimizer.args, "function", None)
44+
original_checkpoint = getattr(server.optimizer.args, "previous_checkpoint_functions", None)
45+
46+
server.show_message_log(f"Original args - file: {original_file}, function: {original_function}", "Info")
47+
48+
try:
49+
# Set temporary args for this request only
50+
server.optimizer.args.file = file_path
51+
server.optimizer.args.function = None # Always get ALL functions, not just one
52+
server.optimizer.args.previous_checkpoint_functions = False
53+
54+
server.show_message_log("Calling get_optimizable_functions...", "Info")
55+
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
56+
57+
path_to_qualified_names = {}
58+
for path, functions in optimizable_funcs.items():
59+
path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]
60+
61+
server.show_message_log(
62+
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
63+
)
64+
return path_to_qualified_names
65+
finally:
66+
# Restore original args to prevent state corruption
67+
if original_file is not None:
68+
server.optimizer.args.file = original_file
69+
if original_function is not None:
70+
server.optimizer.args.function = original_function
71+
else:
72+
server.optimizer.args.function = None
73+
if original_checkpoint is not None:
74+
server.optimizer.args.previous_checkpoint_functions = original_checkpoint
75+
76+
server.show_message_log(
77+
f"Restored args - file: {server.optimizer.args.file}, function: {server.optimizer.args.function}", "Info"
78+
)
4679

4780

4881
@server.feature("initializeFunctionOptimization")
4982
def initialize_function_optimization(
5083
server: CodeflashLanguageServer, params: FunctionOptimizationParams
5184
) -> dict[str, str]:
5285
file_path = Path(uris.to_fs_path(params.textDocument.uri))
86+
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")
87+
88+
# IMPORTANT: Store the specific function for optimization, but don't corrupt global state
5389
server.optimizer.args.function = params.functionName
5490
server.optimizer.args.file = file_path
91+
92+
server.show_message_log(
93+
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
94+
)
95+
5596
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
5697
if not optimizable_funcs:
98+
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
5799
return {"functionName": params.functionName, "status": "not found", "args": None}
100+
58101
fto = optimizable_funcs.popitem()[1][0]
59102
server.optimizer.current_function_being_optimized = fto
103+
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
60104
return {"functionName": params.functionName, "status": "success"}
61105

62106

@@ -136,11 +180,20 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization
136180

137181

138182
@server.feature("performFunctionOptimization")
139-
def perform_function_optimization(
183+
def perform_function_optimization( # noqa: PLR0911
140184
server: CodeflashLanguageServer, params: FunctionOptimizationParams
141185
) -> dict[str, str]:
186+
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
142187
current_function = server.optimizer.current_function_being_optimized
143188

189+
if not current_function:
190+
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
191+
return {
192+
"functionName": params.functionName,
193+
"status": "error",
194+
"message": "No function currently being optimized",
195+
}
196+
144197
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
145198

146199
validated_original_code, original_module_ast = module_prep_result
@@ -214,19 +267,29 @@ def perform_function_optimization(
214267
)
215268

216269
if not best_optimization:
270+
server.show_message_log(
271+
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
272+
)
217273
return {
218274
"functionName": params.functionName,
219275
"status": "error",
220276
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
221277
}
222278

223279
optimized_source = best_optimization.candidate.source_code
280+
speedup = original_code_baseline.runtime / best_optimization.runtime
281+
282+
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
283+
284+
# CRITICAL: Clear the function filter after optimization to prevent state corruption
285+
server.optimizer.args.function = None
286+
server.show_message_log("Cleared function filter to prevent state corruption", "Info")
224287

225288
return {
226289
"functionName": params.functionName,
227290
"status": "success",
228291
"message": "Optimization completed successfully",
229-
"extra": f"Speedup: {original_code_baseline.runtime / best_optimization.runtime:.2f}x faster",
292+
"extra": f"Speedup: {speedup:.2f}x faster",
230293
"optimization": optimized_source,
231294
}
232295

codeflash/optimization/optimizer.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
1313
from codeflash.cli_cmds.console import console, logger, progress_bar
1414
from codeflash.code_utils import env_utils
15+
from codeflash.code_utils.code_utils import cleanup_paths
1516
from codeflash.code_utils.env_utils import get_pr_number
1617
from codeflash.either import is_successful
1718
from codeflash.models.models import ValidCode
@@ -248,10 +249,10 @@ def run(self) -> None:
248249
return
249250
if not env_utils.check_formatter_installed(self.args.formatter_cmds):
250251
return
251-
252252
if self.args.no_draft and is_pr_draft():
253253
logger.warning("PR is in draft mode, skipping optimization")
254254
return
255+
cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root))
255256

256257
function_optimizer = None
257258
file_to_funcs_to_optimize, num_optimizable_functions = self.get_optimizable_functions()
@@ -326,9 +327,27 @@ def run(self) -> None:
326327

327328
self.cleanup_temporary_paths()
328329

329-
def cleanup_temporary_paths(self) -> None:
330-
from codeflash.code_utils.code_utils import cleanup_paths
330+
@staticmethod
331+
def find_leftover_instrumented_test_files(test_root: Path) -> list[Path]:
332+
"""Search for all paths within the test_root that match the following patterns.
331333
334+
- 'test.*__perf_test_{0,1}.py'
335+
- 'test_.*__unit_test_{0,1}.py'
336+
- 'test_.*__perfinstrumented.py'
337+
- 'test_.*__perfonlyinstrumented.py'
338+
Returns a list of matching file paths.
339+
"""
340+
import re
341+
342+
pattern = re.compile(
343+
r"(?:test.*__perf_test_\d?\.py|test_.*__unit_test_\d?\.py|test_.*__perfinstrumented\.py|test_.*__perfonlyinstrumented\.py)$"
344+
)
345+
346+
return [
347+
file_path for file_path in test_root.rglob("*") if file_path.is_file() and pattern.match(file_path.name)
348+
]
349+
350+
def cleanup_temporary_paths(self) -> None:
332351
if self.current_function_optimizer:
333352
self.current_function_optimizer.cleanup_generated_files()
334353

tests/test_code_context_extractor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1212
from codeflash.models.models import FunctionParent
1313
from codeflash.optimization.optimizer import Optimizer
14+
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
15+
from codeflash.code_utils.code_extractor import add_global_assignments
1416

1517

1618
class HelperClass:
@@ -2434,3 +2436,22 @@ def simple_method(self):
24342436
assert "class SimpleClass:" in code_content
24352437
assert "def simple_method(self):" in code_content
24362438
assert "return 42" in code_content
2439+
2440+
2441+
2442+
def test_replace_functions_and_add_imports():
2443+
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
2444+
file_abs_path = path_to_root / "api_client.py"
2445+
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
2446+
content = Path(file_abs_path).read_text(encoding="utf-8")
2447+
new_code = replace_functions_and_add_imports(
2448+
source_code= add_global_assignments(optimized_code, content),
2449+
function_names= ["ApiClient.get_console_url"],
2450+
optimized_code= optimized_code,
2451+
module_abspath= Path(file_abs_path),
2452+
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
2453+
project_root_path= Path(path_to_root),
2454+
)
2455+
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
2456+
2457+
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"

tests/test_code_replacement.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,8 +1693,8 @@ def new_function2(value):
16931693
print("Hello world")
16941694
"""
16951695
expected_code = """import numpy as np
1696-
print("Hello world")
16971696
1697+
print("Hello world")
16981698
a=2
16991699
print("Hello world")
17001700
def some_fn():
@@ -1712,8 +1712,7 @@ def __init__(self, name):
17121712
def __call__(self, value):
17131713
return "I am still old"
17141714
def new_function2(value):
1715-
return cst.ensure_type(value, str)
1716-
"""
1715+
return cst.ensure_type(value, str)"""
17171716
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
17181717
code_path.write_text(original_code, encoding="utf-8")
17191718
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
@@ -1769,8 +1768,8 @@ def new_function2(value):
17691768
print("Hello world")
17701769
"""
17711770
expected_code = """import numpy as np
1772-
print("Hello world")
17731771
1772+
print("Hello world")
17741773
print("Hello world")
17751774
def some_fn():
17761775
a=np.zeros(10)
@@ -1846,8 +1845,8 @@ def new_function2(value):
18461845
print("Hello world")
18471846
"""
18481847
expected_code = """import numpy as np
1849-
print("Hello world")
18501848
1849+
print("Hello world")
18511850
a=3
18521851
print("Hello world")
18531852
def some_fn():
@@ -1922,8 +1921,8 @@ def new_function2(value):
19221921
print("Hello world")
19231922
"""
19241923
expected_code = """import numpy as np
1925-
print("Hello world")
19261924
1925+
print("Hello world")
19271926
a=2
19281927
print("Hello world")
19291928
def some_fn():
@@ -1999,8 +1998,8 @@ def new_function2(value):
19991998
print("Hello world")
20001999
"""
20012000
expected_code = """import numpy as np
2002-
print("Hello world")
20032001
2002+
print("Hello world")
20042003
a=3
20052004
print("Hello world")
20062005
def some_fn():
@@ -2082,8 +2081,8 @@ def new_function2(value):
20822081
print("Hello world")
20832082
"""
20842083
expected_code = """import numpy as np
2085-
print("Hello world")
20862084
2085+
print("Hello world")
20872086
if 2<3:
20882087
a=4
20892088
else:

0 commit comments

Comments
 (0)