Skip to content

Commit 4cf8ea5

Browse files
authored
Merge pull request #309 from codeflash-ai/updated-vsc-extension
barebones LSP Server implementation
2 parents 2344c2c + 6347a8a commit 4cf8ea5

File tree

6 files changed

+772
-314
lines changed

6 files changed

+772
-314
lines changed

codeflash/lsp/__init__.py

Whitespace-only changes.

codeflash/lsp/beta.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING
6+
7+
from pygls import uris
8+
9+
from codeflash.either import is_successful
10+
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
11+
12+
if TYPE_CHECKING:
13+
from lsprotocol import types
14+
15+
from codeflash.models.models import GeneratedTestsList, OptimizationSet
16+
17+
18+
@dataclass
class OptimizableFunctionsParams:
    """Request parameters for the custom ``getOptimizableFunctions`` LSP feature."""

    # camelCase (not snake_case) to match the LSP wire format — hence the noqa.
    textDocument: types.TextDocumentIdentifier  # noqa: N815
21+
22+
23+
@dataclass
class FunctionOptimizationParams:
    """Request parameters for the per-function optimization LSP features."""

    # camelCase (not snake_case) to match the LSP wire format — hence the noqa.
    textDocument: types.TextDocumentIdentifier  # noqa: N815
    # Name of the function to optimize within the given document.
    functionName: str  # noqa: N815
27+
28+
29+
# Module-level server instance; the @server.feature handlers below register on it.
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
30+
31+
32+
@server.feature("getOptimizableFunctions")
33+
def get_optimizable_functions(
34+
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
35+
) -> dict[str, list[str]]:
36+
file_path = Path(uris.to_fs_path(params.textDocument.uri))
37+
server.optimizer.args.file = file_path
38+
server.optimizer.args.previous_checkpoint_functions = False
39+
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
40+
path_to_qualified_names = {}
41+
for path, functions in optimizable_funcs.items():
42+
path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]
43+
return path_to_qualified_names
44+
45+
46+
@server.feature("initializeFunctionOptimization")
47+
def initialize_function_optimization(
48+
server: CodeflashLanguageServer, params: FunctionOptimizationParams
49+
) -> dict[str, str]:
50+
file_path = Path(uris.to_fs_path(params.textDocument.uri))
51+
server.optimizer.args.function = params.functionName
52+
server.optimizer.args.file = file_path
53+
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
54+
if not optimizable_funcs:
55+
return {"functionName": params.functionName, "status": "not found", "args": None}
56+
fto = optimizable_funcs.popitem()[1][0]
57+
server.optimizer.current_function_being_optimized = fto
58+
return {"functionName": params.functionName, "status": "success", "info": fto.server_info}
59+
60+
61+
@server.feature("discoverFunctionTests")
62+
def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
63+
current_function = server.optimizer.current_function_being_optimized
64+
65+
optimizable_funcs = {current_function.file_path: [current_function]}
66+
67+
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
68+
# mocking in order to get things going
69+
return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)}
70+
71+
72+
@server.feature("prepareOptimization")
73+
def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
74+
current_function = server.optimizer.current_function_being_optimized
75+
76+
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
77+
validated_original_code, original_module_ast = module_prep_result
78+
79+
function_optimizer = server.optimizer.create_function_optimizer(
80+
current_function,
81+
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
82+
original_module_ast=original_module_ast,
83+
original_module_path=current_function.file_path,
84+
)
85+
86+
server.optimizer.current_function_optimizer = function_optimizer
87+
if not function_optimizer:
88+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
89+
90+
initialization_result = function_optimizer.can_be_optimized()
91+
if not is_successful(initialization_result):
92+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
93+
94+
return {"functionName": params.functionName, "status": "success", "message": "Optimization preparation completed"}
95+
96+
97+
@server.feature("generateTests")
98+
def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
99+
function_optimizer = server.optimizer.current_function_optimizer
100+
if not function_optimizer:
101+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
102+
103+
initialization_result = function_optimizer.can_be_optimized()
104+
if not is_successful(initialization_result):
105+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
106+
107+
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
108+
109+
test_setup_result = function_optimizer.generate_and_instrument_tests(
110+
code_context, should_run_experiment=should_run_experiment
111+
)
112+
if not is_successful(test_setup_result):
113+
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
114+
generated_tests_list: GeneratedTestsList
115+
optimizations_set: OptimizationSet
116+
generated_tests_list, _, concolic__test_str, optimizations_set = test_setup_result.unwrap()
117+
118+
generated_tests: list[str] = [
119+
generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests
120+
]
121+
optimizations_dict = {
122+
candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation}
123+
for candidate in optimizations_set.control + optimizations_set.experiment
124+
}
125+
126+
return {
127+
"functionName": params.functionName,
128+
"status": "success",
129+
"message": {"generated_tests": generated_tests, "optimizations": optimizations_dict},
130+
}
131+
132+
133+
@server.feature("performFunctionOptimization")
134+
def perform_function_optimization(
135+
server: CodeflashLanguageServer, params: FunctionOptimizationParams
136+
) -> dict[str, str]:
137+
current_function = server.optimizer.current_function_being_optimized
138+
139+
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
140+
141+
validated_original_code, original_module_ast = module_prep_result
142+
143+
function_optimizer = server.optimizer.create_function_optimizer(
144+
current_function,
145+
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
146+
original_module_ast=original_module_ast,
147+
original_module_path=current_function.file_path,
148+
)
149+
150+
server.optimizer.current_function_optimizer = function_optimizer
151+
if not function_optimizer:
152+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
153+
154+
initialization_result = function_optimizer.can_be_optimized()
155+
if not is_successful(initialization_result):
156+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
157+
158+
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
159+
160+
test_setup_result = function_optimizer.generate_and_instrument_tests(
161+
code_context, should_run_experiment=should_run_experiment
162+
)
163+
if not is_successful(test_setup_result):
164+
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
165+
(
166+
generated_tests,
167+
function_to_concolic_tests,
168+
concolic_test_str,
169+
optimizations_set,
170+
generated_test_paths,
171+
generated_perf_test_paths,
172+
instrumented_unittests_created_for_function,
173+
original_conftest_content,
174+
) = test_setup_result.unwrap()
175+
176+
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
177+
code_context=code_context,
178+
original_helper_code=original_helper_code,
179+
function_to_concolic_tests=function_to_concolic_tests,
180+
generated_test_paths=generated_test_paths,
181+
generated_perf_test_paths=generated_perf_test_paths,
182+
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
183+
original_conftest_content=original_conftest_content,
184+
)
185+
186+
if not is_successful(baseline_setup_result):
187+
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
188+
189+
(
190+
function_to_optimize_qualified_name,
191+
function_to_all_tests,
192+
original_code_baseline,
193+
test_functions_to_remove,
194+
file_path_to_helper_classes,
195+
) = baseline_setup_result.unwrap()
196+
197+
best_optimization = function_optimizer.find_and_process_best_optimization(
198+
optimizations_set=optimizations_set,
199+
code_context=code_context,
200+
original_code_baseline=original_code_baseline,
201+
original_helper_code=original_helper_code,
202+
file_path_to_helper_classes=file_path_to_helper_classes,
203+
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
204+
function_to_all_tests=function_to_all_tests,
205+
generated_tests=generated_tests,
206+
test_functions_to_remove=test_functions_to_remove,
207+
concolic_test_str=concolic_test_str,
208+
)
209+
210+
if not best_optimization:
211+
return {
212+
"functionName": params.functionName,
213+
"status": "error",
214+
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
215+
}
216+
217+
optimized_source = best_optimization.candidate.source_code # noqa: F841
218+
219+
return {
220+
"functionName": params.functionName,
221+
"status": "success",
222+
"message": "Optimization completed successfully",
223+
"extra": f"Speedup: {original_code_baseline.runtime / best_optimization.runtime:.2f}x faster",
224+
}
225+
226+
227+
if __name__ == "__main__":
228+
from codeflash.cli_cmds.console import console
229+
230+
console.quiet = True
231+
server.start_io()

codeflash/lsp/server.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING, Any
5+
6+
from lsprotocol.types import INITIALIZE
7+
from pygls import uris
8+
from pygls.protocol import LanguageServerProtocol, lsp_method
9+
from pygls.server import LanguageServer
10+
11+
if TYPE_CHECKING:
12+
from lsprotocol.types import InitializeParams, InitializeResult
13+
14+
15+
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
    """LSP protocol subclass that wires up the optimizer during ``initialize``.

    On initialize it searches the workspace for a ``pyproject.toml`` and, when
    one is found, asks the server to build its optimizer from it.
    """

    _server: CodeflashLanguageServer

    @lsp_method(INITIALIZE)
    def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
        server = self._server
        # Let pygls perform the standard capability negotiation first.
        initialize_result: InitializeResult = super().lsp_initialize(params)

        workspace_uri = params.root_uri
        if not workspace_uri:
            server.show_message("No workspace URI provided.")
            return initialize_result

        workspace_path = uris.to_fs_path(workspace_uri)
        pyproject_toml_path = self._find_pyproject_toml(workspace_path)
        if not pyproject_toml_path:
            server.show_message("No pyproject.toml found in workspace.")
            return initialize_result

        server.initialize_optimizer(pyproject_toml_path)
        server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}")
        return initialize_result

    def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
        """Return the first ``pyproject.toml`` found under *workspace_path*, or None."""
        root = Path(workspace_path)
        return next((candidate.resolve() for candidate in root.rglob("pyproject.toml")), None)
42+
43+
44+
class CodeflashLanguageServer(LanguageServer):
    """Language server that lazily builds a codeflash Optimizer once configured."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
        super().__init__(*args, **kwargs)
        # Populated by initialize_optimizer() after a pyproject.toml is located.
        self.optimizer = None

    def initialize_optimizer(self, config_file: Path) -> None:
        """Create the Optimizer from CLI args merged with *config_file* settings."""
        # Imported lazily so plain server startup stays light.
        from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
        from codeflash.optimization.optimizer import Optimizer

        cli_args = parse_args()
        cli_args.config_file = config_file
        cli_args = process_pyproject_config(cli_args)
        self.optimizer = Optimizer(cli_args)

0 commit comments

Comments
 (0)