Commit 57abf4e

aliafzal authored and facebook-github-bot committed

Renaming embeddings to states to account for optimizer states. (#3116)

Summary:
Pull Request resolved: #3116

The 'embeddings' field in the IndexedLookup and DeltaRows data classes was named for preserving embeddings as the model tracker output. Since this functionality is being expanded to also preserve optimizer state, the field is renamed from 'embeddings' to 'states'.

Reviewed By: TroyGarden

Differential Revision: D76867584

fbshipit-source-id: 2659c267714940a33da1ba8371e37345c601faa9

1 parent 8a4378f commit 57abf4e
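
The data class definitions themselves are not part of this diff; below is a minimal sketch of what IndexedLookup and DeltaRows plausibly look like after the rename. Field names and call sites are taken from the hunks that follow, but the exact annotations and comments are assumptions.

# Illustrative sketch only; the real definitions live in torchrec's model
# tracker types module and are not shown in this commit.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class IndexedLookup:
    # One recorded lookup: the ids accessed in a batch, plus the per-row
    # payload ("states"), which may hold embeddings or optimizer state.
    batch_idx: int
    ids: torch.Tensor
    states: Optional[torch.Tensor]  # formerly named "embeddings"


@dataclass
class DeltaRows:
    # Deduplicated ids and their corresponding rows of state.
    ids: torch.Tensor
    states: Optional[torch.Tensor]  # formerly named "embeddings"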

File tree

5 files changed: +99 -111 lines changed

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 17 additions & 19 deletions
@@ -21,29 +21,27 @@
 
 def _compute_unique_rows(
     ids: List[torch.Tensor],
-    embeddings: Optional[List[torch.Tensor]],
+    states: Optional[List[torch.Tensor]],
     mode: EmbdUpdateMode,
 ) -> DeltaRows:
     r"""
     To calculate unique ids and embeddings
     """
     if mode == EmbdUpdateMode.NONE:
-        assert (
-            embeddings is None
-        ), f"{mode=} == EmbdUpdateMode.NONE but received embeddings"
+        assert states is None, f"{mode=} == EmbdUpdateMode.NONE but received embeddings"
         unique_ids = torch.cat(ids).unique(return_inverse=False)
-        return DeltaRows(ids=unique_ids, embeddings=None)
+        return DeltaRows(ids=unique_ids, states=None)
     else:
         assert (
-            embeddings is not None
+            states is not None
         ), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings"
 
         cat_ids = torch.cat(ids)
-        cat_embeddings = torch.cat(embeddings)
+        cat_states = torch.cat(states)
 
         if mode == EmbdUpdateMode.LAST:
             cat_ids = cat_ids.flip(dims=[0])
-            cat_embeddings = cat_embeddings.flip(dims=[0])
+            cat_states = cat_states.flip(dims=[0])
 
         # Get unique ids and inverse mapping (each element's index in unique_ids).
         unique_ids, inverse = cat_ids.unique(sorted=False, return_inverse=True)
@@ -65,8 +63,8 @@ def _compute_unique_rows(
         )
 
         # Use first occurrence indices to select corresponding embedding row.
-        unique_embedings = cat_embeddings[first_occurrence]
-        return DeltaRows(ids=unique_ids, embeddings=unique_embedings)
+        unique_states = cat_states[first_occurrence]
+        return DeltaRows(ids=unique_ids, states=unique_states)
 
 
 class DeltaStore:
@@ -90,11 +88,11 @@ def append(
         batch_idx: int,
         table_fqn: str,
         ids: torch.Tensor,
-        embeddings: Optional[torch.Tensor],
+        states: Optional[torch.Tensor],
     ) -> None:
         table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, [])
         table_fqn_lookup.append(
-            IndexedLookup(batch_idx=batch_idx, ids=ids, embeddings=embeddings)
+            IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
         )
         self.per_fqn_lookups[table_fqn] = table_fqn_lookup
 
@@ -132,21 +130,21 @@ def compact(self, start_idx: int, end_idx: int) -> None:
                 new_per_fqn_lookups[table_fqn] = lookups
                 continue
             ids = [lookup.ids for lookup in lookups_to_compact]
-            embeddings = (
-                [none_throws(lookup.embeddings) for lookup in lookups_to_compact]
+            states = (
+                [none_throws(lookup.states) for lookup in lookups_to_compact]
                 if self.embdUpdateMode != EmbdUpdateMode.NONE
                 else None
             )
             delta_rows = _compute_unique_rows(
-                ids=ids, embeddings=embeddings, mode=self.embdUpdateMode
+                ids=ids, states=states, mode=self.embdUpdateMode
            )
            new_per_fqn_lookups[table_fqn] = (
                lookups[:index_l]
                + [
                    IndexedLookup(
                        batch_idx=start_idx,
                        ids=delta_rows.ids,
-                        embeddings=delta_rows.embeddings,
+                        states=delta_rows.states,
                    )
                ]
                + lookups[index_r:]
@@ -163,9 +161,9 @@ def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
             compact_ids = [
                 lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx
             ]
-            compact_embeddings = (
+            compact_states = (
                 [
-                    none_throws(lookup.embeddings)
+                    none_throws(lookup.states)
                     for lookup in lookups
                     if lookup.batch_idx >= from_idx
                 ]
@@ -174,6 +172,6 @@ def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
             )
 
             delta_per_table_fqn[table_fqn] = _compute_unique_rows(
-                ids=compact_ids, embeddings=compact_embeddings, mode=self.embdUpdateMode
+                ids=compact_ids, states=compact_states, mode=self.embdUpdateMode
             )
         return delta_per_table_fqn
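
The code that derives first_occurrence falls between the hunks above and is not shown. As a rough, self-contained illustration of the same keep-first/keep-last deduplication pattern in plain PyTorch (this is not the torchrec implementation; the helper name dedup_rows is made up):

# Standalone illustration of the dedup pattern used by _compute_unique_rows:
# keep one state row per unique id, either its first or its last occurrence
# across the concatenated batches.
from typing import List, Tuple

import torch


def dedup_rows(
    ids: List[torch.Tensor],
    states: List[torch.Tensor],
    keep_last: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    cat_ids = torch.cat(ids)
    cat_states = torch.cat(states)
    if keep_last:
        # Reversing the order turns "keep last occurrence" into "keep first".
        cat_ids = cat_ids.flip(dims=[0])
        cat_states = cat_states.flip(dims=[0])

    unique_ids, inverse = cat_ids.unique(sorted=False, return_inverse=True)

    # For each unique id, find the smallest position at which it occurs.
    positions = torch.arange(cat_ids.numel())
    first_occurrence = torch.full(
        (unique_ids.numel(),), cat_ids.numel(), dtype=torch.long
    )
    first_occurrence.scatter_reduce_(0, inverse, positions, reduce="amin")

    return unique_ids, cat_states[first_occurrence]


if __name__ == "__main__":
    ids = [torch.tensor([1, 2]), torch.tensor([2, 3])]
    states = [torch.tensor([[1.0], [2.0]]), torch.tensor([[20.0], [3.0]])]
    unique_ids, unique_states = dedup_rows(ids, states, keep_last=True)
    # id 2 appears twice; keep_last=True keeps its latest row, [20.0].
    print(unique_ids.tolist(), unique_states.tolist())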

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 2 additions & 2 deletions
@@ -183,7 +183,7 @@ def record_ids(self, kjt: KeyedJaggedTensor) -> None:
                 batch_idx=self.curr_batch_idx,
                 table_fqn=table_fqn,
                 ids=torch.cat(ids_list),
-                embeddings=None,
+                states=None,
             )
 
     def record_embeddings(
@@ -223,7 +223,7 @@ def record_embeddings(
                 batch_idx=self.curr_batch_idx,
                 table_fqn=table_fqn,
                 ids=torch.cat(ids_list),
-                embeddings=torch.cat(per_table_emb[table_fqn]),
+                states=torch.cat(per_table_emb[table_fqn]),
             )
 
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
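
The two call sites above differ only in what they pass as states: record_ids tracks ids alone (states=None), while record_embeddings passes the concatenated looked-up rows. A minimal stand-in sketch of that calling pattern follows; TinyDeltaStore and the table_fqn value are made up, and the real DeltaStore constructor is not shown in this commit.

# Stand-in for illustration only; the real DeltaStore lives in
# torchrec/distributed/model_tracker/delta_store.py and differs internally.
from typing import Dict, List, Optional

import torch


class TinyDeltaStore:
    def __init__(self) -> None:
        self.per_fqn_lookups: Dict[str, List[dict]] = {}

    def append(
        self,
        batch_idx: int,
        table_fqn: str,
        ids: torch.Tensor,
        states: Optional[torch.Tensor],
    ) -> None:
        # Record one lookup for the given table; states may be None.
        self.per_fqn_lookups.setdefault(table_fqn, []).append(
            {"batch_idx": batch_idx, "ids": ids, "states": states}
        )


store = TinyDeltaStore()

# record_ids-style call: only ids are tracked, so states is None.
store.append(
    batch_idx=0,
    table_fqn="sparse.ebc.table_0",  # placeholder table FQN
    ids=torch.tensor([1, 5, 7]),
    states=None,
)

# record_embeddings-style call: the looked-up rows are tracked as states.
ids_list = [torch.tensor([1, 5, 7])]
rows = [torch.randn(3, 4)]
store.append(
    batch_idx=0,
    table_fqn="sparse.ebc.table_0",  # placeholder table FQN
    ids=torch.cat(ids_list),
    states=torch.cat(rows),
)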
