Skip to content

Commit e202289

Browse files
committed
minor fixes
1 parent 4a68aa0 commit e202289

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from codeflash.cli_cmds.console import logger
88
from codeflash.code_utils.time_utils import format_time
9-
from codeflash.models.models import GeneratedTests, GeneratedTestsList
9+
from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
1010
from codeflash.result.critic import performance_gain
1111
from codeflash.verification.verification_utils import TestConfig
1212

@@ -37,7 +37,10 @@ def remove_functions_from_generated_tests(
3737

3838

3939
def add_runtime_comments_to_generated_tests(
40-
test_cfg: TestConfig, generated_tests: GeneratedTestsList, original_runtimes: dict, optimized_runtimes: dict
40+
test_cfg: TestConfig,
41+
generated_tests: GeneratedTestsList,
42+
original_runtimes: dict[InvocationId, list[int]],
43+
optimized_runtimes: dict[InvocationId, list[int]],
4144
) -> GeneratedTestsList:
4245
"""Add runtime performance comments to function calls in generated tests."""
4346
tests_root = test_cfg.tests_root
@@ -48,7 +51,7 @@ def add_runtime_comments_to_generated_tests(
4851
class RuntimeCommentTransformer(cst.CSTTransformer):
4952
def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
5053
self.test = test
51-
self.context_stack = []
54+
self.context_stack: list[str] = []
5255
self.tests_root = tests_root
5356
self.rel_tests_root = rel_tests_root
5457

@@ -93,7 +96,7 @@ def leave_SimpleStatementLine(
9396
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
9497
for invocation_id, runtimes in original_runtimes.items():
9598
qualified_name = (
96-
invocation_id.test_class_name + "." + invocation_id.test_function_name
99+
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
97100
if invocation_id.test_class_name
98101
else invocation_id.test_function_name
99102
)
@@ -110,7 +113,7 @@ def leave_SimpleStatementLine(
110113

111114
for invocation_id, runtimes in optimized_runtimes.items():
112115
qualified_name = (
113-
invocation_id.test_class_name + "." + invocation_id.test_function_name
116+
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
114117
if invocation_id.test_class_name
115118
else invocation_id.test_function_name
116119
)

codeflash/result/create_pr.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from codeflash.result.critic import performance_gain
2323

2424
if TYPE_CHECKING:
25-
from codeflash.models.models import FunctionCalledInTest
25+
from codeflash.models.models import FunctionCalledInTest, InvocationId
2626
from codeflash.result.explanation import Explanation
2727
from codeflash.verification.verification_utils import TestConfig
2828

@@ -31,8 +31,8 @@ def existing_tests_source_for(
3131
function_qualified_name_with_modules_from_root: str,
3232
function_to_tests: dict[str, set[FunctionCalledInTest]],
3333
test_cfg: TestConfig,
34-
original_runtimes_all: dict,
35-
optimized_runtimes_all: dict,
34+
original_runtimes_all: dict[InvocationId, list[int]],
35+
optimized_runtimes_all: dict[InvocationId, list[int]],
3636
) -> str:
3737
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
3838
if not test_files:
@@ -41,8 +41,8 @@ def existing_tests_source_for(
4141
tests_root = test_cfg.tests_root
4242
module_root = test_cfg.project_root_path
4343
rel_tests_root = tests_root.relative_to(module_root)
44-
original_tests_to_runtimes = {}
45-
optimized_tests_to_runtimes = {}
44+
original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
45+
optimized_tests_to_runtimes: dict[Path, dict[str, int]] = {}
4646
non_generated_tests = set()
4747
for test_file in test_files:
4848
non_generated_tests.add(Path(test_file.tests_in_file.test_file).relative_to(tests_root))
@@ -59,18 +59,18 @@ def existing_tests_source_for(
5959
if rel_path not in optimized_tests_to_runtimes:
6060
optimized_tests_to_runtimes[rel_path] = {}
6161
qualified_name = (
62-
invocation_id.test_class_name + "." + invocation_id.test_function_name
62+
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
6363
if invocation_id.test_class_name
6464
else invocation_id.test_function_name
6565
)
6666
if qualified_name not in original_tests_to_runtimes[rel_path]:
67-
original_tests_to_runtimes[rel_path][qualified_name] = 0
67+
original_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
6868
if qualified_name not in optimized_tests_to_runtimes[rel_path]:
69-
optimized_tests_to_runtimes[rel_path][qualified_name] = 0
69+
optimized_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
7070
if invocation_id in original_runtimes_all:
71-
original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id])
71+
original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id]) # type: ignore[index]
7272
if invocation_id in optimized_runtimes_all:
73-
optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id])
73+
optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) # type: ignore[index]
7474
# parse into string
7575
all_rel_paths = (
7676
original_tests_to_runtimes.keys()

0 commit comments

Comments
 (0)