1
1
from __future__ import annotations
2
2
3
+ import hashlib
3
4
import os
4
5
from collections import defaultdict
5
6
from itertools import chain
6
- from typing import TYPE_CHECKING
7
+ from typing import TYPE_CHECKING , cast
7
8
8
9
import libcst as cst
9
10
31
32
def get_code_optimization_context (
32
33
function_to_optimize : FunctionToOptimize ,
33
34
project_root_path : Path ,
34
- optim_token_limit : int = 8000 ,
35
- testgen_token_limit : int = 8000 ,
35
+ optim_token_limit : int = 16000 ,
36
+ testgen_token_limit : int = 16000 ,
36
37
) -> CodeOptimizationContext :
37
38
# Get FunctionSource representation of helpers of FTO
38
39
helpers_of_fto_dict , helpers_of_fto_list = get_function_sources_from_jedi (
@@ -73,6 +74,13 @@ def get_code_optimization_context(
73
74
remove_docstrings = False ,
74
75
code_context_type = CodeContextType .READ_ONLY ,
75
76
)
77
+ hashing_code_context = extract_code_markdown_context_from_files (
78
+ helpers_of_fto_dict ,
79
+ helpers_of_helpers_dict ,
80
+ project_root_path ,
81
+ remove_docstrings = True ,
82
+ code_context_type = CodeContextType .HASHING ,
83
+ )
76
84
77
85
# Handle token limits
78
86
final_read_writable_tokens = encoded_tokens_len (final_read_writable_code )
@@ -125,11 +133,15 @@ def get_code_optimization_context(
125
133
testgen_context_code_tokens = encoded_tokens_len (testgen_context_code )
126
134
if testgen_context_code_tokens > testgen_token_limit :
127
135
raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
136
+ code_hash_context = hashing_code_context .markdown
137
+ code_hash = hashlib .sha256 (code_hash_context .encode ("utf-8" )).hexdigest ()
128
138
129
139
return CodeOptimizationContext (
130
140
testgen_context_code = testgen_context_code ,
131
141
read_writable_code = final_read_writable_code ,
132
142
read_only_context_code = read_only_context_code ,
143
+ hashing_code_context = code_hash_context ,
144
+ hashing_code_context_hash = code_hash ,
133
145
helper_functions = helpers_of_fto_list ,
134
146
preexisting_objects = preexisting_objects ,
135
147
)
@@ -309,8 +321,8 @@ def extract_code_markdown_context_from_files(
309
321
logger .debug (f"Error while getting read-only code: { e } " )
310
322
continue
311
323
if code_context .strip ():
312
- code_context_with_imports = CodeString (
313
- code = add_needed_imports_from_module (
324
+ if code_context_type != CodeContextType . HASHING :
325
+ code_context = add_needed_imports_from_module (
314
326
src_module_code = original_code ,
315
327
dst_module_code = code_context ,
316
328
src_path = file_path ,
@@ -319,10 +331,9 @@ def extract_code_markdown_context_from_files(
319
331
helper_functions = list (
320
332
helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ())
321
333
),
322
- ),
323
- file_path = file_path .relative_to (project_root_path ),
324
- )
325
- code_context_markdown .code_strings .append (code_context_with_imports )
334
+ )
335
+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
336
+ code_context_markdown .code_strings .append (code_string_context )
326
337
# Extract code from file paths containing helpers of helpers
327
338
for file_path , helper_function_sources in helpers_of_helpers_no_overlap .items ():
328
339
try :
@@ -343,18 +354,17 @@ def extract_code_markdown_context_from_files(
343
354
continue
344
355
345
356
if code_context .strip ():
346
- code_context_with_imports = CodeString (
347
- code = add_needed_imports_from_module (
357
+ if code_context_type != CodeContextType . HASHING :
358
+ code_context = add_needed_imports_from_module (
348
359
src_module_code = original_code ,
349
360
dst_module_code = code_context ,
350
361
src_path = file_path ,
351
362
dst_path = file_path ,
352
363
project_root = project_root_path ,
353
364
helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
354
- ),
355
- file_path = file_path .relative_to (project_root_path ),
356
- )
357
- code_context_markdown .code_strings .append (code_context_with_imports )
365
+ )
366
+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
367
+ code_context_markdown .code_strings .append (code_string_context )
358
368
return code_context_markdown
359
369
360
370
@@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
492
502
filtered_node , found_target = prune_cst_for_testgen_code (
493
503
module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
494
504
)
505
+ elif code_context_type == CodeContextType .HASHING :
506
+ filtered_node , found_target = prune_cst_for_code_hashing (module , target_functions )
495
507
else :
496
508
raise ValueError (f"Unknown code_context_type: { code_context_type } " ) # noqa: EM102
497
509
@@ -583,6 +595,90 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583
595
return (node .with_changes (** updates ) if updates else node ), True
584
596
585
597
598
def prune_cst_for_code_hashing(  # noqa: PLR0911
    node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node and its children down to the code used for hashing.

    Import statements are dropped. A function whose qualified name is in
    ``target_functions`` is kept, with its body passed through
    ``remove_docstring_from_body``; every other function is dropped. A class is
    reduced to only those of its direct methods that are targets, and is dropped
    entirely when it has none. Any other node is kept only if a target function
    was found somewhere beneath it.

    Parameters
    ----------
    node:
        The CST node to filter.
    target_functions:
        Qualified names to keep: ``"func"`` for a function, ``"Class.method"``
        for a method defined directly in a class body.
    prefix:
        Qualified-name prefix of the enclosing scope. A class is only processed
        when this is empty; with a non-empty prefix the class is dropped.

    Returns
    -------
    (filtered_node, found_target):
        filtered_node: The modified CST node or None if it should be removed.
        found_target: True if a target function was found in this node's subtree.

    Raises
    ------
    ValueError
        If a class definition's body is not a ``cst.IndentedBlock``.

    """
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        if qualified_name not in target_functions:
            return None, False
        # A one-line function (`def f(): ...`) has a SimpleStatementSuite body; leave it untouched.
        new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
        return node.with_changes(body=new_body), True

    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
        # `prefix` is empty past the guard above, so the class name alone qualifies its methods.
        class_name = node.name.value
        kept_methods: list[cst.BaseStatement] = []

        for stmt in node.body.body:
            if not isinstance(stmt, cst.FunctionDef):
                continue
            if f"{class_name}.{stmt.name.value}" not in target_functions:
                continue
            # Same guard as for top-level functions: only an IndentedBlock is handed to the docstring stripper.
            method_body = (
                remove_docstring_from_body(stmt.body) if isinstance(stmt.body, cst.IndentedBlock) else stmt.body
            )
            kept_methods.append(stmt.with_changes(body=method_body))

        # If no target functions found, remove the class entirely
        if not kept_methods:
            return None, False
        return node.with_changes(body=cst.IndentedBlock(kept_methods)), True

    # For other nodes, we preserve them only if they contain target functions in their children.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, child_found = prune_cst_for_code_hashing(child, target_functions, prefix)
                if filtered:
                    new_children.append(filtered)
                section_found_target |= child_found

            # Only rewrite a section that actually leads to a target.
            if section_found_target:
                found_any_target = True
                updates[section] = new_children
        elif original_content is not None:
            filtered, child_found = prune_cst_for_code_hashing(original_content, target_functions, prefix)
            if child_found:
                found_any_target = True
                if filtered:
                    updates[section] = filtered

    if not found_any_target:
        return None, False

    return (node.with_changes(**updates) if updates else node), True
680
+
681
+
586
682
def prune_cst_for_read_only_code ( # noqa: PLR0911
587
683
node : cst .CSTNode ,
588
684
target_functions : set [str ],
0 commit comments