|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import contextlib |
| 4 | +import json |
| 5 | +import os |
3 | 6 | from dataclasses import dataclass
|
4 | 7 | from pathlib import Path
|
5 | 8 | from typing import TYPE_CHECKING
|
@@ -55,18 +58,21 @@ def initialize_function_optimization(
|
55 | 58 | return {"functionName": params.functionName, "status": "not found", "args": None}
|
56 | 59 | fto = optimizable_funcs.popitem()[1][0]
|
57 | 60 | server.optimizer.current_function_being_optimized = fto
|
58 |
| - return {"functionName": params.functionName, "status": "success", "info": fto.server_info} |
| 61 | + return {"functionName": params.functionName, "status": "success"} |
59 | 62 |
|
60 | 63 |
|
61 | 64 | @server.feature("discoverFunctionTests")
|
62 | 65 | def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
|
63 |
| - current_function = server.optimizer.current_function_being_optimized |
| 66 | + fto = server.optimizer.current_function_being_optimized |
| 67 | + optimizable_funcs = {fto.file_path: [fto]} |
| 68 | + |
| 69 | + devnull_writer = open(os.devnull, "w") # noqa |
| 70 | + with contextlib.redirect_stdout(devnull_writer): |
| 71 | + function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) |
64 | 72 |
|
65 |
| - optimizable_funcs = {current_function.file_path: [current_function]} |
| 73 | + server.optimizer.discovered_tests = function_to_tests |
66 | 74 |
|
67 |
| - function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) |
68 |
| - # mocking in order to get things going |
69 |
| - return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)} |
| 75 | + return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests} |
70 | 76 |
|
71 | 77 |
|
72 | 78 | @server.feature("prepareOptimization")
|
@@ -145,6 +151,7 @@ def perform_function_optimization(
|
145 | 151 | function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
|
146 | 152 | original_module_ast=original_module_ast,
|
147 | 153 | original_module_path=current_function.file_path,
|
| 154 | + function_to_tests=server.optimizer.discovered_tests or {}, |
148 | 155 | )
|
149 | 156 |
|
150 | 157 | server.optimizer.current_function_optimizer = function_optimizer
|
@@ -214,13 +221,14 @@ def perform_function_optimization(
|
214 | 221 | "message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
|
215 | 222 | }
|
216 | 223 |
|
217 |
| - optimized_source = best_optimization.candidate.source_code # noqa: F841 |
| 224 | + optimized_source = best_optimization.candidate.source_code |
218 | 225 |
|
219 | 226 | return {
|
220 | 227 | "functionName": params.functionName,
|
221 | 228 | "status": "success",
|
222 | 229 | "message": "Optimization completed successfully",
|
223 | 230 | "extra": f"Speedup: {original_code_baseline.runtime / best_optimization.runtime:.2f}x faster",
|
| 231 | + "optimization": json.dumps(optimized_source, indent=None), |
224 | 232 | }
|
225 | 233 |
|
226 | 234 |
|
|
0 commit comments