21
21
22
22
def _compute_unique_rows (
23
23
ids : List [torch .Tensor ],
24
- embeddings : Optional [List [torch .Tensor ]],
24
+ states : Optional [List [torch .Tensor ]],
25
25
mode : EmbdUpdateMode ,
26
26
) -> DeltaRows :
27
27
r"""
28
28
To calculate unique ids and embeddings
29
29
"""
30
30
if mode == EmbdUpdateMode .NONE :
31
- assert (
32
- embeddings is None
33
- ), f"{ mode = } == EmbdUpdateMode.NONE but received embeddings"
31
+ assert states is None , f"{ mode = } == EmbdUpdateMode.NONE but received embeddings"
34
32
unique_ids = torch .cat (ids ).unique (return_inverse = False )
35
- return DeltaRows (ids = unique_ids , embeddings = None )
33
+ return DeltaRows (ids = unique_ids , states = None )
36
34
else :
37
35
assert (
38
- embeddings is not None
36
+ states is not None
39
37
), f"{ mode = } != EmbdUpdateMode.NONE but received no embeddings"
40
38
41
39
cat_ids = torch .cat (ids )
42
- cat_embeddings = torch .cat (embeddings )
40
+ cat_states = torch .cat (states )
43
41
44
42
if mode == EmbdUpdateMode .LAST :
45
43
cat_ids = cat_ids .flip (dims = [0 ])
46
- cat_embeddings = cat_embeddings .flip (dims = [0 ])
44
+ cat_states = cat_states .flip (dims = [0 ])
47
45
48
46
# Get unique ids and inverse mapping (each element's index in unique_ids).
49
47
unique_ids , inverse = cat_ids .unique (sorted = False , return_inverse = True )
@@ -65,8 +63,8 @@ def _compute_unique_rows(
65
63
)
66
64
67
65
# Use first occurrence indices to select corresponding embedding row.
68
- unique_embedings = cat_embeddings [first_occurrence ]
69
- return DeltaRows (ids = unique_ids , embeddings = unique_embedings )
66
+ unique_states = cat_states [first_occurrence ]
67
+ return DeltaRows (ids = unique_ids , states = unique_states )
70
68
71
69
72
70
class DeltaStore :
@@ -90,11 +88,11 @@ def append(
90
88
batch_idx : int ,
91
89
table_fqn : str ,
92
90
ids : torch .Tensor ,
93
- embeddings : Optional [torch .Tensor ],
91
+ states : Optional [torch .Tensor ],
94
92
) -> None :
95
93
table_fqn_lookup = self .per_fqn_lookups .get (table_fqn , [])
96
94
table_fqn_lookup .append (
97
- IndexedLookup (batch_idx = batch_idx , ids = ids , embeddings = embeddings )
95
+ IndexedLookup (batch_idx = batch_idx , ids = ids , states = states )
98
96
)
99
97
self .per_fqn_lookups [table_fqn ] = table_fqn_lookup
100
98
@@ -132,21 +130,21 @@ def compact(self, start_idx: int, end_idx: int) -> None:
132
130
new_per_fqn_lookups [table_fqn ] = lookups
133
131
continue
134
132
ids = [lookup .ids for lookup in lookups_to_compact ]
135
- embeddings = (
136
- [none_throws (lookup .embeddings ) for lookup in lookups_to_compact ]
133
+ states = (
134
+ [none_throws (lookup .states ) for lookup in lookups_to_compact ]
137
135
if self .embdUpdateMode != EmbdUpdateMode .NONE
138
136
else None
139
137
)
140
138
delta_rows = _compute_unique_rows (
141
- ids = ids , embeddings = embeddings , mode = self .embdUpdateMode
139
+ ids = ids , states = states , mode = self .embdUpdateMode
142
140
)
143
141
new_per_fqn_lookups [table_fqn ] = (
144
142
lookups [:index_l ]
145
143
+ [
146
144
IndexedLookup (
147
145
batch_idx = start_idx ,
148
146
ids = delta_rows .ids ,
149
- embeddings = delta_rows .embeddings ,
147
+ states = delta_rows .states ,
150
148
)
151
149
]
152
150
+ lookups [index_r :]
@@ -163,9 +161,9 @@ def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
163
161
compact_ids = [
164
162
lookup .ids for lookup in lookups if lookup .batch_idx >= from_idx
165
163
]
166
- compact_embeddings = (
164
+ compact_states = (
167
165
[
168
- none_throws (lookup .embeddings )
166
+ none_throws (lookup .states )
169
167
for lookup in lookups
170
168
if lookup .batch_idx >= from_idx
171
169
]
@@ -174,6 +172,6 @@ def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
174
172
)
175
173
176
174
delta_per_table_fqn [table_fqn ] = _compute_unique_rows (
177
- ids = compact_ids , embeddings = compact_embeddings , mode = self .embdUpdateMode
175
+ ids = compact_ids , states = compact_states , mode = self .embdUpdateMode
178
176
)
179
177
return delta_per_table_fqn
0 commit comments