Skip to content

Commit 51c936f

Browse files
committed
normalize code before hashing
1 parent a2e78e1 commit 51c936f

File tree

2 files changed

+26
-58
lines changed

2 files changed

+26
-58
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import ast
34
import hashlib
45
import os
56
from collections import defaultdict
@@ -510,7 +511,10 @@ def parse_code_and_prune_cst(
510511
if not found_target:
511512
raise ValueError("No target functions found in the provided code")
512513
if filtered_node and isinstance(filtered_node, cst.Module):
513-
return str(filtered_node.code)
514+
code = str(filtered_node.code)
515+
if code_context_type == CodeContextType.HASHING:
516+
code = ast.unparse(ast.parse(code)) # Makes it standard
517+
return code
514518
return ""
515519

516520

tests/test_code_context_extractor.py

Lines changed: 21 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,10 @@ class HelperClass:
114114
def helper_method(self):
115115
return self.name
116116
117-
118117
class MainClass:
119118
120119
def main_method(self):
121-
self.name = HelperClass.NestedClass("test").nested_method()
120+
self.name = HelperClass.NestedClass('test').nested_method()
122121
return HelperClass(self.name).helper_method()
123122
```
124123
"""
@@ -181,22 +180,17 @@ class Graph:
181180
182181
def topologicalSortUtil(self, v, visited, stack):
183182
visited[v] = True
184-
185183
for i in self.graph[v]:
186184
if visited[i] == False:
187185
self.topologicalSortUtil(i, visited, stack)
188-
189186
stack.insert(0, v)
190187
191188
def topologicalSort(self):
192189
visited = [False] * self.V
193190
stack = []
194-
195191
for i in range(self.V):
196192
if visited[i] == False:
197193
self.topologicalSortUtil(i, visited, stack)
198-
199-
# Print contents of stack
200194
return stack
201195
```
202196
"""
@@ -614,58 +608,37 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
614608
```python:{file_path.relative_to(opt.args.project_root)}
615609
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
616610
617-
def get_cache_or_call(
618-
self,
619-
*,
620-
func: Callable[_P, Any],
621-
args: tuple[Any, ...],
622-
kwargs: dict[str, Any],
623-
lifespan: datetime.timedelta,
624-
) -> Any: # noqa: ANN401
625-
if os.environ.get("NO_CACHE"):
611+
def get_cache_or_call(self, *, func: Callable[_P, Any], args: tuple[Any, ...], kwargs: dict[str, Any], lifespan: datetime.timedelta) -> Any:
612+
if os.environ.get('NO_CACHE'):
626613
return func(*args, **kwargs)
627-
628614
try:
629615
key = self.hash_key(func=func, args=args, kwargs=kwargs)
630-
except: # noqa: E722
631-
# If we can't create a cache key, we should just call the function.
632-
logging.warning("Failed to hash cache key for function: %s", func)
616+
except:
617+
logging.warning('Failed to hash cache key for function: %s', func)
633618
return func(*args, **kwargs)
634619
result_pair = self.get(key=key)
635-
636620
if result_pair is not None:
637621
cached_time, result = result_pair
638-
if not os.environ.get("RE_CACHE") and (
639-
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
640-
):
622+
if not os.environ.get('RE_CACHE') and datetime.datetime.now() < cached_time + lifespan:
641623
try:
642624
return self.decode(data=result)
643625
except CacheBackendDecodeError as e:
644-
logging.warning("Failed to decode cache data: %s", e)
645-
# If decoding fails we will treat this as a cache miss.
646-
# This might happens if underlying class definition of the data changes.
626+
logging.warning('Failed to decode cache data: %s', e)
647627
self.delete(key=key)
648628
result = func(*args, **kwargs)
649629
try:
650630
self.put(key=key, data=self.encode(data=result))
651631
except CacheBackendEncodeError as e:
652-
logging.warning("Failed to encode cache data: %s", e)
653-
# If encoding fails, we should still return the result.
632+
logging.warning('Failed to encode cache data: %s', e)
654633
return result
655634
656-
657635
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
658636
659637
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
660-
if "NO_CACHE" in os.environ:
638+
if 'NO_CACHE' in os.environ:
661639
return self.__wrapped__(*args, **kwargs)
662640
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
663-
return self.__backend__.get_cache_or_call(
664-
func=self.__wrapped__,
665-
args=args,
666-
kwargs=kwargs,
667-
lifespan=self.__duration__,
668-
)
641+
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
669642
```
670643
"""
671644
assert read_write_context.strip() == expected_read_write_context.strip()
@@ -749,10 +722,12 @@ def __repr__(self):
749722
expected_hashing_context = f"""
750723
```python:{file_path.relative_to(opt.args.project_root)}
751724
class MyClass:
725+
752726
def target_method(self):
753727
y = HelperClass().helper_method()
754728
755729
class HelperClass:
730+
756731
def helper_method(self):
757732
return self.x
758733
```
@@ -843,10 +818,12 @@ def __repr__(self):
843818
expected_hashing_context = f"""
844819
```python:{file_path.relative_to(opt.args.project_root)}
845820
class MyClass:
821+
846822
def target_method(self):
847823
y = HelperClass().helper_method()
848824
849825
class HelperClass:
826+
850827
def helper_method(self):
851828
return self.x
852829
```
@@ -927,10 +904,12 @@ def helper_method(self):
927904
expected_hashing_context = f"""
928905
```python:{file_path.relative_to(opt.args.project_root)}
929906
class MyClass:
907+
930908
def target_method(self):
931909
y = HelperClass().helper_method()
932910
933911
class HelperClass:
912+
934913
def helper_method(self):
935914
return self.x
936915
```
@@ -1116,22 +1095,17 @@ class DataProcessor:
11161095
def process_data(self, raw_data: str) -> str:
11171096
return raw_data.upper()
11181097
1119-
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
1098+
def add_prefix(self, data: str, prefix: str='PREFIX_') -> str:
11201099
return prefix + data
11211100
```
11221101
```python:{path_to_file.relative_to(project_root)}
11231102
def fetch_and_process_data():
1124-
# Use the global variable for the request
11251103
response = requests.get(API_URL)
11261104
response.raise_for_status()
1127-
11281105
raw_data = response.text
1129-
1130-
# Use code from another file (utils.py)
11311106
processor = DataProcessor()
11321107
processed = processor.process_data(raw_data)
11331108
processed = processor.add_prefix(processed)
1134-
11351109
return processed
11361110
```
11371111
"""
@@ -1225,16 +1199,11 @@ def transform_data(self, data: str) -> str:
12251199
```
12261200
```python:{path_to_file.relative_to(project_root)}
12271201
def fetch_and_transform_data():
1228-
# Use the global variable for the request
12291202
response = requests.get(API_URL)
1230-
12311203
raw_data = response.text
1232-
1233-
# Use code from another file (utils.py)
12341204
processor = DataProcessor()
12351205
processed = processor.process_data(raw_data)
12361206
transformed = processor.transform_data(processed)
1237-
12381207
return transformed
12391208
```
12401209
"""
@@ -1450,9 +1419,8 @@ def transform_data_all_same_file(self, data):
14501419
new_data = update_data(data)
14511420
return self.transform_using_own_method(new_data)
14521421
1453-
14541422
def update_data(data):
1455-
return data + " updated"
1423+
return data + ' updated'
14561424
```
14571425
"""
14581426

@@ -1591,6 +1559,7 @@ def outside_method():
15911559
expected_hashing_context = f"""
15921560
```python:{file_path.relative_to(opt.args.project_root)}
15931561
class MyClass:
1562+
15941563
def target_method(self):
15951564
return self.x + self.y
15961565
```
@@ -1640,16 +1609,11 @@ def transform_data(self, data: str) -> str:
16401609
expected_hashing_context = """
16411610
```python:main.py
16421611
def fetch_and_transform_data():
1643-
# Use the global variable for the request
16441612
response = requests.get(API_URL)
1645-
16461613
raw_data = response.text
1647-
1648-
# Use code from another file (utils.py)
16491614
processor = DataProcessor()
16501615
processed = processor.process_data(raw_data)
16511616
transformed = processor.transform_data(processed)
1652-
16531617
return transformed
16541618
```
16551619
```python:import_test.py
@@ -1915,9 +1879,9 @@ def subtract(self, a, b):
19151879
return a - b
19161880
19171881
def calculate(self, operation, x, y):
1918-
if operation == "add":
1882+
if operation == 'add':
19191883
return self.add(x, y)
1920-
elif operation == "subtract":
1884+
elif operation == 'subtract':
19211885
return self.subtract(x, y)
19221886
else:
19231887
return None

0 commit comments

Comments
 (0)