Skip to content

Commit 702bee2

Browse files
Merge pull request #303 from codeflash-ai/filter-tests-for-pr-test-discovery
Test discovery in GHA optimization by pre-filtering test files
2 parents a75a660 + fd06dca commit 702bee2

File tree

8 files changed

+718
-147
lines changed

8 files changed

+718
-147
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 226 additions & 87 deletions
Large diffs are not rendered by default.

codeflash/discovery/functions_to_optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
268268
def get_all_replay_test_functions(
269269
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
270270
) -> dict[Path, list[FunctionToOptimize]]:
271-
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
271+
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
272272
# Get the absolute file paths for each function, excluding class name if present
273273
filtered_valid_functions = defaultdict(list)
274274
file_to_functions_map = defaultdict(list)

codeflash/optimization/function_optimizer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from codeflash.models.models import (
5858
BestOptimization,
5959
CodeOptimizationContext,
60-
FunctionCalledInTest,
6160
GeneratedTests,
6261
GeneratedTestsList,
6362
OptimizationSet,
@@ -87,7 +86,13 @@
8786

8887
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
8988
from codeflash.either import Result
90-
from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate
89+
from codeflash.models.models import (
90+
BenchmarkKey,
91+
CoverageData,
92+
FunctionCalledInTest,
93+
FunctionSource,
94+
OptimizedCandidate,
95+
)
9196
from codeflash.verification.verification_utils import TestConfig
9297

9398

@@ -97,7 +102,7 @@ def __init__(
97102
function_to_optimize: FunctionToOptimize,
98103
test_cfg: TestConfig,
99104
function_to_optimize_source_code: str = "",
100-
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
105+
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
101106
function_to_optimize_ast: ast.FunctionDef | None = None,
102107
aiservice_client: AiServiceClient | None = None,
103108
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
@@ -213,7 +218,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
213218

214219
function_to_optimize_qualified_name = self.function_to_optimize.qualified_name
215220
function_to_all_tests = {
216-
key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, [])
221+
key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set())
217222
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
218223
}
219224
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
@@ -690,7 +695,7 @@ def cleanup_leftover_test_return_values() -> None:
690695
get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True)
691696
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)
692697

693-
def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]:
698+
def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]:
694699
existing_test_files_count = 0
695700
replay_test_files_count = 0
696701
concolic_coverage_test_files_count = 0
@@ -701,7 +706,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi
701706
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
702707
console.rule()
703708
else:
704-
test_file_invocation_positions = defaultdict(list[FunctionCalledInTest])
709+
test_file_invocation_positions = defaultdict(list)
705710
for tests_in_file in function_to_all_tests.get(func_qualname):
706711
test_file_invocation_positions[
707712
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
@@ -787,7 +792,7 @@ def generate_tests_and_optimizations(
787792
generated_test_paths: list[Path],
788793
generated_perf_test_paths: list[Path],
789794
run_experiment: bool = False, # noqa: FBT001, FBT002
790-
) -> Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]:
795+
) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str]:
791796
assert len(generated_test_paths) == N_TESTS_TO_GENERATE
792797
max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3
793798
console.rule()

codeflash/optimization/optimizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def create_function_optimizer(
4848
self,
4949
function_to_optimize: FunctionToOptimize,
5050
function_to_optimize_ast: ast.FunctionDef | None = None,
51-
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
51+
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
5252
function_to_optimize_source_code: str | None = "",
5353
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
5454
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
@@ -162,8 +162,9 @@ def run(self) -> None:
162162

163163
console.rule()
164164
start_time = time.time()
165-
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
166-
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
165+
function_to_tests, num_discovered_tests = discover_unit_tests(
166+
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
167+
)
167168
console.rule()
168169
logger.info(
169170
f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"

codeflash/result/create_pr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
def existing_tests_source_for(
2727
function_qualified_name_with_modules_from_root: str,
28-
function_to_tests: dict[str, list[FunctionCalledInTest]],
28+
function_to_tests: dict[str, set[FunctionCalledInTest]],
2929
tests_root: Path,
3030
) -> str:
3131
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)

codeflash/verification/concolic_testing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
def generate_concolic_tests(
2626
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
27-
) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
27+
) -> tuple[dict[str, set[FunctionCalledInTest]], str]:
2828
start_time = time.perf_counter()
2929
function_to_concolic_tests = {}
3030
concolic_test_suite_code = ""
@@ -78,8 +78,7 @@ def generate_concolic_tests(
7878
test_framework=args.test_framework,
7979
pytest_cmd=args.pytest_cmd,
8080
)
81-
function_to_concolic_tests = discover_unit_tests(concolic_test_cfg)
82-
num_discovered_concolic_tests: int = sum([len(value) for value in function_to_concolic_tests.values()])
81+
function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests(concolic_test_cfg)
8382
logger.info(
8483
f"Created {num_discovered_concolic_tests} "
8584
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "

tests/test_static_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import ast
1+
import ast
22
from pathlib import Path
33

44
from codeflash.code_utils.static_analysis import (

0 commit comments

Comments
 (0)