57
57
from codeflash .models .models import (
58
58
BestOptimization ,
59
59
CodeOptimizationContext ,
60
- FunctionCalledInTest ,
61
60
GeneratedTests ,
62
61
GeneratedTestsList ,
63
62
OptimizationSet ,
87
86
88
87
from codeflash .discovery .functions_to_optimize import FunctionToOptimize
89
88
from codeflash .either import Result
90
- from codeflash .models .models import BenchmarkKey , CoverageData , FunctionSource , OptimizedCandidate
89
+ from codeflash .models .models import (
90
+ BenchmarkKey ,
91
+ CoverageData ,
92
+ FunctionCalledInTest ,
93
+ FunctionSource ,
94
+ OptimizedCandidate ,
95
+ )
91
96
from codeflash .verification .verification_utils import TestConfig
92
97
93
98
@@ -97,7 +102,7 @@ def __init__(
97
102
function_to_optimize : FunctionToOptimize ,
98
103
test_cfg : TestConfig ,
99
104
function_to_optimize_source_code : str = "" ,
100
- function_to_tests : dict [str , list [FunctionCalledInTest ]] | None = None ,
105
+ function_to_tests : dict [str , set [FunctionCalledInTest ]] | None = None ,
101
106
function_to_optimize_ast : ast .FunctionDef | None = None ,
102
107
aiservice_client : AiServiceClient | None = None ,
103
108
function_benchmark_timings : dict [BenchmarkKey , int ] | None = None ,
@@ -213,7 +218,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
213
218
214
219
function_to_optimize_qualified_name = self .function_to_optimize .qualified_name
215
220
function_to_all_tests = {
216
- key : self .function_to_tests .get (key , []) + function_to_concolic_tests .get (key , [] )
221
+ key : self .function_to_tests .get (key , set ()) | function_to_concolic_tests .get (key , set () )
217
222
for key in set (self .function_to_tests ) | set (function_to_concolic_tests )
218
223
}
219
224
instrumented_unittests_created_for_function = self .instrument_existing_tests (function_to_all_tests )
@@ -690,7 +695,7 @@ def cleanup_leftover_test_return_values() -> None:
690
695
get_run_tmp_file (Path ("test_return_values_0.bin" )).unlink (missing_ok = True )
691
696
get_run_tmp_file (Path ("test_return_values_0.sqlite" )).unlink (missing_ok = True )
692
697
693
- def instrument_existing_tests (self , function_to_all_tests : dict [str , list [FunctionCalledInTest ]]) -> set [Path ]:
698
+ def instrument_existing_tests (self , function_to_all_tests : dict [str , set [FunctionCalledInTest ]]) -> set [Path ]:
694
699
existing_test_files_count = 0
695
700
replay_test_files_count = 0
696
701
concolic_coverage_test_files_count = 0
@@ -701,7 +706,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi
701
706
logger .info (f"Did not find any pre-existing tests for '{ func_qualname } ', will only use generated tests." )
702
707
console .rule ()
703
708
else :
704
- test_file_invocation_positions = defaultdict (list [ FunctionCalledInTest ] )
709
+ test_file_invocation_positions = defaultdict (list )
705
710
for tests_in_file in function_to_all_tests .get (func_qualname ):
706
711
test_file_invocation_positions [
707
712
(tests_in_file .tests_in_file .test_file , tests_in_file .tests_in_file .test_type )
@@ -787,7 +792,7 @@ def generate_tests_and_optimizations(
787
792
generated_test_paths : list [Path ],
788
793
generated_perf_test_paths : list [Path ],
789
794
run_experiment : bool = False , # noqa: FBT001, FBT002
790
- ) -> Result [tuple [GeneratedTestsList , dict [str , list [FunctionCalledInTest ]], OptimizationSet ], str ]:
795
+ ) -> Result [tuple [GeneratedTestsList , dict [str , set [FunctionCalledInTest ]], OptimizationSet ], str ]:
791
796
assert len (generated_test_paths ) == N_TESTS_TO_GENERATE
792
797
max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3
793
798
console .rule ()
0 commit comments