Skip to content

Commit 8fe970c

Browse files
Merge pull request #307 from codeflash-ai/normalize-code-before-hashing
normalize code before hashing
2 parents a2e78e1 + 7167d2b commit 8fe970c

File tree

2 files changed: 28 additions and 59 deletions (+28 -59 lines changed)

codeflash/context/code_context_extractor.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import ast
34
import hashlib
45
import os
56
from collections import defaultdict
@@ -510,7 +511,10 @@ def parse_code_and_prune_cst(
510511
if not found_target:
511512
raise ValueError("No target functions found in the provided code")
512513
if filtered_node and isinstance(filtered_node, cst.Module):
513-
return str(filtered_node.code)
514+
code = str(filtered_node.code)
515+
if code_context_type == CodeContextType.HASHING:
516+
code = ast.unparse(ast.parse(code)) # Makes it standard
517+
return code
514518
return ""
515519

516520

tests/test_code_context_extractor.py

Lines changed: 23 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sys
34
import tempfile
45
from argparse import Namespace
56
from collections import defaultdict
@@ -114,11 +115,10 @@ class HelperClass:
114115
def helper_method(self):
115116
return self.name
116117
117-
118118
class MainClass:
119119
120120
def main_method(self):
121-
self.name = HelperClass.NestedClass("test").nested_method()
121+
self.name = HelperClass.NestedClass('test').nested_method()
122122
return HelperClass(self.name).helper_method()
123123
```
124124
"""
@@ -181,22 +181,17 @@ class Graph:
181181
182182
def topologicalSortUtil(self, v, visited, stack):
183183
visited[v] = True
184-
185184
for i in self.graph[v]:
186185
if visited[i] == False:
187186
self.topologicalSortUtil(i, visited, stack)
188-
189187
stack.insert(0, v)
190188
191189
def topologicalSort(self):
192190
visited = [False] * self.V
193191
stack = []
194-
195192
for i in range(self.V):
196193
if visited[i] == False:
197194
self.topologicalSortUtil(i, visited, stack)
198-
199-
# Print contents of stack
200195
return stack
201196
```
202197
"""
@@ -614,58 +609,37 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
614609
```python:{file_path.relative_to(opt.args.project_root)}
615610
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
616611
617-
def get_cache_or_call(
618-
self,
619-
*,
620-
func: Callable[_P, Any],
621-
args: tuple[Any, ...],
622-
kwargs: dict[str, Any],
623-
lifespan: datetime.timedelta,
624-
) -> Any: # noqa: ANN401
625-
if os.environ.get("NO_CACHE"):
612+
def get_cache_or_call(self, *, func: Callable[_P, Any], args: tuple[Any, ...], kwargs: dict[str, Any], lifespan: datetime.timedelta) -> Any:
613+
if os.environ.get('NO_CACHE'):
626614
return func(*args, **kwargs)
627-
628615
try:
629616
key = self.hash_key(func=func, args=args, kwargs=kwargs)
630-
except: # noqa: E722
631-
# If we can't create a cache key, we should just call the function.
632-
logging.warning("Failed to hash cache key for function: %s", func)
617+
except:
618+
logging.warning('Failed to hash cache key for function: %s', func)
633619
return func(*args, **kwargs)
634620
result_pair = self.get(key=key)
635-
636621
if result_pair is not None:
637-
cached_time, result = result_pair
638-
if not os.environ.get("RE_CACHE") and (
639-
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
640-
):
622+
{"cached_time, result = result_pair" if sys.version_info >= (3, 11) else "(cached_time, result) = result_pair"}
623+
if not os.environ.get('RE_CACHE') and datetime.datetime.now() < cached_time + lifespan:
641624
try:
642625
return self.decode(data=result)
643626
except CacheBackendDecodeError as e:
644-
logging.warning("Failed to decode cache data: %s", e)
645-
# If decoding fails we will treat this as a cache miss.
646-
# This might happens if underlying class definition of the data changes.
627+
logging.warning('Failed to decode cache data: %s', e)
647628
self.delete(key=key)
648629
result = func(*args, **kwargs)
649630
try:
650631
self.put(key=key, data=self.encode(data=result))
651632
except CacheBackendEncodeError as e:
652-
logging.warning("Failed to encode cache data: %s", e)
653-
# If encoding fails, we should still return the result.
633+
logging.warning('Failed to encode cache data: %s', e)
654634
return result
655635
656-
657636
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
658637
659638
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
660-
if "NO_CACHE" in os.environ:
639+
if 'NO_CACHE' in os.environ:
661640
return self.__wrapped__(*args, **kwargs)
662641
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
663-
return self.__backend__.get_cache_or_call(
664-
func=self.__wrapped__,
665-
args=args,
666-
kwargs=kwargs,
667-
lifespan=self.__duration__,
668-
)
642+
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
669643
```
670644
"""
671645
assert read_write_context.strip() == expected_read_write_context.strip()
@@ -749,10 +723,12 @@ def __repr__(self):
749723
expected_hashing_context = f"""
750724
```python:{file_path.relative_to(opt.args.project_root)}
751725
class MyClass:
726+
752727
def target_method(self):
753728
y = HelperClass().helper_method()
754729
755730
class HelperClass:
731+
756732
def helper_method(self):
757733
return self.x
758734
```
@@ -843,10 +819,12 @@ def __repr__(self):
843819
expected_hashing_context = f"""
844820
```python:{file_path.relative_to(opt.args.project_root)}
845821
class MyClass:
822+
846823
def target_method(self):
847824
y = HelperClass().helper_method()
848825
849826
class HelperClass:
827+
850828
def helper_method(self):
851829
return self.x
852830
```
@@ -927,10 +905,12 @@ def helper_method(self):
927905
expected_hashing_context = f"""
928906
```python:{file_path.relative_to(opt.args.project_root)}
929907
class MyClass:
908+
930909
def target_method(self):
931910
y = HelperClass().helper_method()
932911
933912
class HelperClass:
913+
934914
def helper_method(self):
935915
return self.x
936916
```
@@ -1116,22 +1096,17 @@ class DataProcessor:
11161096
def process_data(self, raw_data: str) -> str:
11171097
return raw_data.upper()
11181098
1119-
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
1099+
def add_prefix(self, data: str, prefix: str='PREFIX_') -> str:
11201100
return prefix + data
11211101
```
11221102
```python:{path_to_file.relative_to(project_root)}
11231103
def fetch_and_process_data():
1124-
# Use the global variable for the request
11251104
response = requests.get(API_URL)
11261105
response.raise_for_status()
1127-
11281106
raw_data = response.text
1129-
1130-
# Use code from another file (utils.py)
11311107
processor = DataProcessor()
11321108
processed = processor.process_data(raw_data)
11331109
processed = processor.add_prefix(processed)
1134-
11351110
return processed
11361111
```
11371112
"""
@@ -1225,16 +1200,11 @@ def transform_data(self, data: str) -> str:
12251200
```
12261201
```python:{path_to_file.relative_to(project_root)}
12271202
def fetch_and_transform_data():
1228-
# Use the global variable for the request
12291203
response = requests.get(API_URL)
1230-
12311204
raw_data = response.text
1232-
1233-
# Use code from another file (utils.py)
12341205
processor = DataProcessor()
12351206
processed = processor.process_data(raw_data)
12361207
transformed = processor.transform_data(processed)
1237-
12381208
return transformed
12391209
```
12401210
"""
@@ -1450,9 +1420,8 @@ def transform_data_all_same_file(self, data):
14501420
new_data = update_data(data)
14511421
return self.transform_using_own_method(new_data)
14521422
1453-
14541423
def update_data(data):
1455-
return data + " updated"
1424+
return data + ' updated'
14561425
```
14571426
"""
14581427

@@ -1591,6 +1560,7 @@ def outside_method():
15911560
expected_hashing_context = f"""
15921561
```python:{file_path.relative_to(opt.args.project_root)}
15931562
class MyClass:
1563+
15941564
def target_method(self):
15951565
return self.x + self.y
15961566
```
@@ -1640,16 +1610,11 @@ def transform_data(self, data: str) -> str:
16401610
expected_hashing_context = """
16411611
```python:main.py
16421612
def fetch_and_transform_data():
1643-
# Use the global variable for the request
16441613
response = requests.get(API_URL)
1645-
16461614
raw_data = response.text
1647-
1648-
# Use code from another file (utils.py)
16491615
processor = DataProcessor()
16501616
processed = processor.process_data(raw_data)
16511617
transformed = processor.transform_data(processed)
1652-
16531618
return transformed
16541619
```
16551620
```python:import_test.py
@@ -1915,9 +1880,9 @@ def subtract(self, a, b):
19151880
return a - b
19161881
19171882
def calculate(self, operation, x, y):
1918-
if operation == "add":
1883+
if operation == 'add':
19191884
return self.add(x, y)
1920-
elif operation == "subtract":
1885+
elif operation == 'subtract':
19211886
return self.subtract(x, y)
19221887
else:
19231888
return None

0 commit comments

Comments
 (0)