Skip to content

Commit 3b6b537

Browse files
aliafzalfacebook-github-bot
authored andcommitted
ModelDeltaTracer implementation for tracking logic (#3060)
Summary: Pull Request resolved: #3060 ### Diff Summary This diff introduces implementation of tracking logic for ID and Embedding mode 1. **Record Functions** ```record_lookup():``` Handles recording of IDs and embeddings based on the tracking mode. ```record_ids():``` Records IDs from a KeyedJaggedTensor. ```record_embeddings():``` Records IDs along with embeddings, ensuring size compatibility between IDs and embeddings. 2. **Delta Retrieval** ```get_delta():``` Retrieves per FQN local IDs for each sparse feature. 3. **Tracked Modules Access** ```get_tracked_modules():``` Returns a dictionary of tracked modules. ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training Reviewed By: TroyGarden Differential Revision: D76094097 fbshipit-source-id: af06e7a8419e680770fd082e41bd09b649d52a59
1 parent a4cd73b commit 3b6b537

File tree

1 file changed

+183
-16
lines changed

1 file changed

+183
-16
lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 183 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
# pyre-strict
99
import logging as logger
1010
from collections import Counter, OrderedDict
11-
from typing import Dict, Iterable, List, Optional, Union
11+
from typing import Dict, Iterable, List, Optional
1212

1313
import torch
1414

1515
from torch import nn
1616
from torchrec.distributed.embedding import ShardedEmbeddingCollection
1717
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
18+
from torchrec.distributed.model_tracker.delta_store import DeltaStore
1819
from torchrec.distributed.model_tracker.types import (
1920
DeltaRows,
2021
EmbdUpdateMode,
@@ -41,15 +42,16 @@ class ModelDeltaTracker:
4142
ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support
4243
for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or
4344
unique rows for a given model, which can help compute topk or to stream updated embeddings from predictors to trainers during
44-
online training. Unique IDs or states can be retrieved by calling the get_unique() method.
45+
online training. Unique IDs or states can be retrieved by calling the get_delta() method.
4546
4647
Args:
4748
model (nn.Module): the model to track.
4849
consumers (List[str], optional): list of consumers to track. Each consumer will
49-
have its own batch offset index. Every get_unique_ids invocation will
50-
only return the new ids for the given consumer since last get_unique_ids
51-
call.
50+
have its own batch offset index. Every get_delta and get_delta_ids invocation will
51+
only return the new values for the given consumer since last call.
5252
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
53+
auto_compact (bool, optional): Trigger compaction automatically during communication at each train cycle.
54+
When set false, compaction is triggered at get_delta() call. Default: False.
5355
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
5456
fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None.
5557
@@ -62,36 +64,177 @@ def __init__(
6264
model: nn.Module,
6365
consumers: Optional[List[str]] = None,
6466
delete_on_read: bool = True,
67+
auto_compact: bool = False,
6568
mode: TrackingMode = TrackingMode.ID_ONLY,
6669
fqns_to_skip: Iterable[str] = (),
6770
) -> None:
6871
self._model = model
6972
self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
7073
self._delete_on_read = delete_on_read
74+
self._auto_compact = auto_compact
7175
self._mode = mode
7276
self._fqn_to_feature_map: Dict[str, List[str]] = {}
7377
self._fqns_to_skip: Iterable[str] = fqns_to_skip
78+
79+
# per_consumer_batch_idx is used to track the batch index for each consumer.
80+
# This is used to retrieve the delta values for a given consumer as well as
81+
# start_ids for compaction window.
82+
self.per_consumer_batch_idx: Dict[str, int] = {
83+
c: -1 for c in (consumers or [self.DEFAULT_CONSUMER])
84+
}
85+
self.curr_batch_idx: int = 0
86+
self.curr_compact_index: int = 0
87+
88+
self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
89+
90+
# from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
91+
self.tracked_modules: Dict[str, nn.Module] = {}
92+
self.feature_to_fqn: Dict[str, str] = {}
93+
# Generate the mapping from FQN to feature names.
7494
self.fqn_to_feature_names()
75-
pass
95+
# Validate the mode is supported for the given module
96+
self._validate_mode()
97+
98+
# Mapping feature name to corresponding FQNs. This is used for retrieving
99+
# the FQN associated with a given feature name in record_lookup().
100+
for fqn, feature_names in self._fqn_to_feature_map.items():
101+
for feature_name in feature_names:
102+
if feature_name in self.feature_to_fqn:
103+
logger.warn(f"Duplicate feature name: {feature_name} in fqn {fqn}")
104+
continue
105+
self.feature_to_fqn[feature_name] = fqn
106+
logger.info(f"feature_to_fqn: {self.feature_to_fqn}")
76107

77108
def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
78109
"""
79-
Record Ids from a given KeyedJaggedTensor and embeddings/ parameter states.
110+
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
111+
112+
This method is run post-lookup, after the embedding lookup has been performed,
113+
as it needs access to both the input IDs and the resulting embeddings.
114+
115+
This function processes the input KeyedJaggedTensor and records either just the IDs
116+
(in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).
117+
118+
Args:
119+
kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
120+
states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
121+
"""
122+
123+
# In ID_ONLY mode, we only track feature IDs received in the current batch.
124+
if self._mode == TrackingMode.ID_ONLY:
125+
self.record_ids(kjt)
126+
# In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
127+
elif self._mode == TrackingMode.EMBEDDING:
128+
self.record_embeddings(kjt, states)
129+
130+
else:
131+
raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
132+
133+
def record_ids(self, kjt: KeyedJaggedTensor) -> None:
134+
"""
135+
Record Ids from a given KeyedJaggedTensor.
136+
137+
Args:
138+
kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
139+
"""
140+
per_table_ids: Dict[str, List[torch.Tensor]] = {}
141+
for key in kjt.keys():
142+
table_fqn = self.feature_to_fqn[key]
143+
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
144+
ids_list.append(kjt[key].values())
145+
per_table_ids[table_fqn] = ids_list
146+
147+
for table_fqn, ids_list in per_table_ids.items():
148+
self.store.append(
149+
batch_idx=self.curr_batch_idx,
150+
table_fqn=table_fqn,
151+
ids=torch.cat(ids_list),
152+
embeddings=None,
153+
)
154+
155+
def record_embeddings(
156+
self, kjt: KeyedJaggedTensor, embeddings: torch.Tensor
157+
) -> None:
158+
"""
159+
Record Ids along with Embeddings from a given KeyedJaggedTensor and embeddings.
80160
81161
Args:
82162
kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
83-
states (torch.Tensor): the states to record.
163+
embeddings (torch.Tensor): the embeddings to record.
164+
"""
165+
per_table_ids: Dict[str, List[torch.Tensor]] = {}
166+
per_table_emb: Dict[str, List[torch.Tensor]] = {}
167+
assert embeddings.numel() % kjt.values().numel() == 0, (
168+
f"ids and embeddings size mismatch, expect [{kjt.values().numel()} * emb_dim], "
169+
f"but got {embeddings.numel()}"
170+
)
171+
embeddings_2d = embeddings.view(kjt.values().numel(), -1)
172+
173+
offset: int = 0
174+
for key in kjt.keys():
175+
table_fqn = self.feature_to_fqn[key]
176+
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
177+
emb_list: List[torch.Tensor] = per_table_emb.get(table_fqn, [])
178+
179+
ids = kjt[key].values()
180+
ids_list.append(ids)
181+
emb_list.append(embeddings_2d[offset : offset + ids.numel()])
182+
offset += ids.numel()
183+
184+
per_table_ids[table_fqn] = ids_list
185+
per_table_emb[table_fqn] = emb_list
186+
187+
for table_fqn, ids_list in per_table_ids.items():
188+
self.store.append(
189+
batch_idx=self.curr_batch_idx,
190+
table_fqn=table_fqn,
191+
ids=torch.cat(ids_list),
192+
embeddings=torch.cat(per_table_emb[table_fqn]),
193+
)
194+
195+
def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
196+
"""
197+
Return a dictionary of hit local IDs for each sparse feature. Ids are
198+
first keyed by submodule FQN.
199+
200+
Args:
201+
consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer.
84202
"""
85-
pass
203+
per_table_delta_rows = self.get_delta(consumer)
204+
return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}
86205

87206
def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
88207
"""
89-
Return a dictionary of hit local IDs for each sparse feature. The IDs are first keyed by submodule FQN.
208+
Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The Values are first keyed by submodule FQN.
90209
91210
Args:
92-
consumer (str, optional): The consumer to retrieve IDs for. If not specified, "default" is used as the default consumer.
211+
consumer (str, optional): The consumer to retrieve delta values for. If not specified, "default" is used as the default consumer.
93212
"""
94-
return {}
213+
consumer = consumer or self.DEFAULT_CONSUMER
214+
assert (
215+
consumer in self.per_consumer_batch_idx
216+
), f"consumer {consumer} not present in {self.per_consumer_batch_idx.values()}"
217+
218+
index_end: int = self.curr_batch_idx + 1
219+
index_start = max(self.per_consumer_batch_idx.values())
220+
221+
# In case of multiple consumers, it is possible that the previous consumer has already compact these indices
222+
# and index_start could be equal to index_end, in which case we should not compact again.
223+
if index_start < index_end:
224+
self.compact(index_start, index_end)
225+
tracker_rows = self.store.get_delta(
226+
from_idx=self.per_consumer_batch_idx[consumer]
227+
)
228+
self.per_consumer_batch_idx[consumer] = index_end
229+
if self._delete_on_read:
230+
self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
231+
return tracker_rows
232+
233+
def get_tracked_modules(self) -> Dict[str, nn.Module]:
234+
"""
235+
Returns a dictionary of tracked modules.
236+
"""
237+
return self.tracked_modules
95238

96239
def fqn_to_feature_names(self) -> Dict[str, List[str]]:
97240
"""
@@ -114,19 +257,19 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
114257
break
115258
if should_skip:
116259
continue
117-
118260
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
119261
if isinstance(named_module, SUPPORTED_MODULES):
120262
for table_name, config in named_module._table_name_to_config.items():
121263
logger.info(
122264
f"Found {table_name} for {fqn} with features {config.feature_names}"
123265
)
124266
table_to_feature_names[table_name] = config.feature_names
267+
self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module
125268
for table_name in table_to_feature_names:
126269
# Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
127270
# will incorrectly match fqn with all the table names that have the same prefix
128271
if table_name in split_fqn:
129-
embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
272+
embedding_fqn = self._clean_fqn_fn(fqn)
130273
if table_name in table_to_fqn:
131274
# Sanity check for validating that we don't have more then one table mapping to same fqn.
132275
logger.warning(
@@ -165,7 +308,19 @@ def clear(self, consumer: Optional[str] = None) -> None:
165308
Args:
166309
consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer.
167310
"""
168-
pass
311+
# 1. If consumer is None, delete globally.
312+
if consumer is None:
313+
self.store.delete()
314+
return
315+
316+
assert (
317+
consumer in self.per_consumer_batch_idx
318+
), f"consumer {consumer} not found in {self.per_consumer_batch_idx.values()}"
319+
320+
# 2. For single consumer, we can just delete all ids
321+
if len(self.per_consumer_batch_idx) == 1:
322+
self.store.delete()
323+
return
169324

170325
def compact(self, start_idx: int, end_idx: int) -> None:
171326
"""
@@ -175,4 +330,16 @@ def compact(self, start_idx: int, end_idx: int) -> None:
175330
start_idx (int): Starting index for compaction.
176331
end_idx (int): Ending index for compaction.
177332
"""
178-
pass
333+
self.store.compact(start_idx, end_idx)
334+
335+
def _clean_fqn_fn(self, fqn: str) -> str:
336+
# strip DMP internal module FQN prefix to match state dict FQN
337+
return fqn.replace("_dmp_wrapped_module.module.", "")
338+
339+
def _validate_mode(self) -> None:
340+
"To validate the mode is supported for the given module"
341+
for module in self.tracked_modules.values():
342+
assert not (
343+
isinstance(module, ShardedEmbeddingBagCollection)
344+
and self._mode == TrackingMode.EMBEDDING
345+
), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."

0 commit comments

Comments
 (0)