diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index fddc5c18a..b22fbc361 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -204,6 +204,60 @@ def optimize_python_code_line_profiler(
         console.rule()
         return []
+    def get_new_explanation(
+        self,
+        source_code: str,
+        optimized_code: str,
+        dependency_code: str,
+        trace_id: str,
+        existing_explanation: str,
+    ) -> str:
+        """Generate a new explanation for the optimized code by making a request to the AI service.
+
+        Parameters
+        ----------
+        - source_code (str): The original python code that was optimized.
+        - optimized_code (str): The optimized version of the python code.
+        - dependency_code (str): The dependency code used as read-only context for the optimization
+        - trace_id (str): Trace id of optimization run
+        - existing_explanation (str): Existing explanation for the optimization,
+          generated by a previous AI service call
+
+        Returns
+        -------
+        - str: The new explanation, or an empty string if the request fails.
+
+        """
+        payload = {
+            "trace_id": trace_id,
+            "source_code": source_code,
+            "optimized_code": optimized_code,
+            "existing_explanation": existing_explanation,
+            "dependency_code": dependency_code,
+        }
+        logger.info("Generating new explanation…")
+        console.rule()
+        try:
+            response = self.make_ai_service_request("/explain", payload=payload, timeout=600)
+        except requests.exceptions.RequestException as e:
+            logger.exception(f"Error generating new explanation: {e}")
+            ph("cli-optimize-error-caught", {"error": str(e)})
+            return ""
+
+        if response.status_code == 200:
+            explanation = response.json()["explanation"]
+            logger.info(f"New Explanation: {explanation}")
+            console.rule()
+            return explanation
+        try:
+            error = response.json()["error"]
+        except Exception:
+            error = response.text
+        logger.error(f"Error generating new explanation: {response.status_code} - {error}")
+        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+        console.rule()
+        return ""
+
     def log_results(
         self,
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 93def83c0..faf0d762a 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -8,6 +8,7 @@
 import time
 import uuid
 from collections import defaultdict, deque
+from dataclasses import replace
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -254,13 +255,12 @@ def optimize_function(self) -> Result[BestOptimization, str]:
             )
 
            if best_optimization:
-                logger.info("Best candidate:")
-                code_print(best_optimization.candidate.source_code)
-                console.print(
-                    Panel(
-                        best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
-                    )
-                )
+                new_explanation = self.aiservice_client.get_new_explanation(source_code=code_context.read_writable_code,
+                    dependency_code=code_context.read_only_context_code,
+                    trace_id=self.function_trace_id,
+
                    optimized_code=best_optimization.candidate.source_code,
+                    existing_explanation=best_optimization.candidate.explanation)
+                best_optimization.candidate = replace(best_optimization.candidate, explanation=new_explanation if new_explanation != "" else best_optimization.candidate.explanation)
                 explanation = Explanation(
                     raw_explanation_message=best_optimization.candidate.explanation,
                     winning_behavioral_test_results=best_optimization.winning_behavioral_test_results,
@@ -270,7 +270,13 @@
                     function_name=function_to_optimize_qualified_name,
                     file_path=self.function_to_optimize.file_path,
                 )
-
+                logger.info("Best candidate:")
+                code_print(best_optimization.candidate.source_code)
+                console.print(
+                    Panel(
+                        best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
+                    )
+                )
                 self.log_successful_optimization(explanation, generated_tests)
 
                 self.replace_function_and_helpers_with_optimized_code(
diff --git a/tests/test_explain_api.py b/tests/test_explain_api.py
new file mode 100644
index 000000000..390b532b4
--- /dev/null
+++ b/tests/test_explain_api.py
@@ -0,0 +1,22 @@
+from codeflash.api.aiservice import AiServiceClient
+def test_explain_api():
+    aiservice = AiServiceClient()
+    source_code: str = """def bubble_sort(arr):
+    n = len(arr)
+    for i in range(n):
+        for j in range(0, n-i-1):
+            if arr[j] > arr[j+1]:
+                arr[j], arr[j+1] = arr[j+1], arr[j]
+    return arr
+"""
+    dependency_code: str = "def helper(): return 1"
+    trace_id: str = "d5822364-7617-4389-a4fc-64602a00b714"
+    existing_explanation: str = "I used numpy to optimize it"
+    optimized_code: str = """def bubble_sort(arr):
+    return arr.sort()
+"""
+    new_explanation = aiservice.get_new_explanation(source_code=source_code, optimized_code=optimized_code,
+                                                    existing_explanation=existing_explanation, dependency_code=dependency_code,
+                                                    trace_id=trace_id)
+    print("\nNew explanation: \n", new_explanation)
+    assert len(new_explanation) > 0
\ No newline
at end of file