
Commit 7fbc766

aliafzal authored and facebook-github-bot committed
Delta tracker DMP integration (pytorch#3064)
Summary:
Pull Request resolved: pytorch#3064

## This Diff

Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:

**ModelTrackerConfig Integration**:
* Added a ModelTrackerConfig parameter to the DMP constructor
* When provided, automatically initializes ModelDeltaTracker
* Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added a custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides ID/state tracking natively in TorchRec without registering any nn.Module-specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added a `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API gives the flexibility to integrate the model tracker into the required components directly, without needing to access the dmp_module.
* Added a `get_delta()` method as a convenience API to retrieve delta rows from the dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callables
* Added callable registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context

ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a model built with TorchRec. It is particularly useful for:
1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming of updated embeddings between systems during online training

For more details see diff D75853147 or PR pytorch#3057

Differential Revision: D76202371
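As a rough usage sketch of the flow described above (hedged: `build_model()` and `batch` are hypothetical placeholders, only the config fields named in this summary are shown, and their defaults may differ):

```python
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, TrackingMode

# `build_model()` and `batch` are hypothetical placeholders for a real TorchRec
# model (containing e.g. an EmbeddingBagCollection) and an input batch.
model: nn.Module = build_model()

dmp = DistributedModelParallel(
    module=model,
    model_tracker_config=ModelTrackerConfig(
        tracking_mode=TrackingMode.ID_ONLY,  # or TrackingMode.EMBEDDING
        delete_on_read=True,
        auto_compact=True,
    ),
)

out = dmp(batch)  # lookups in this step are recorded via the registered callables

delta_rows = dmp.get_delta()       # Dict[str, DeltaRows] keyed by module FQN
tracker = dmp.get_model_tracker()  # direct handle on the ModelDeltaTracker
```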
1 parent be4e6d7 commit 7fbc766

File tree

8 files changed (+895, -49 lines)

torchrec/distributed/embedding.py

Lines changed: 4 additions & 0 deletions
@@ -1514,13 +1514,17 @@ def compute_and_output_dist(
                 EmbeddingEvent.LOOKUP, self._module_fqn, sharding_type
             ):
                 embs = lookup(features)
+                if self.post_lookup_tracker_fn is not None:
+                    self.post_lookup_tracker_fn(features, embs)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
             ):
                 awaitables_per_sharding.append(
                     odist(embs.view(-1, embedding_dim), sharding_ctx)
                 )
+                if self.post_odist_tracker_fn is not None:
+                    self.post_odist_tracker_fn()
 
             features_before_all2all_per_sharding.append(
                 # pyre-fixme[6]: For 1st argument expected `KeyedJaggedTensor` but

torchrec/distributed/embedding_types.py

Lines changed: 41 additions & 0 deletions
@@ -9,10 +9,12 @@
 
 import abc
 import copy
+import logging as logger
 from dataclasses import dataclass
 from enum import Enum, unique
 from typing import (
     Any,
+    Callable,
     Dict,
     Generic,
     Iterable,
@@ -370,6 +372,10 @@ def __init__(
         self._input_dists: List[nn.Module] = []
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
+        self.post_lookup_tracker_fn: Optional[
+            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+        ] = None
+        self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
 
     def prefetch(
         self,
@@ -418,6 +424,41 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
 
         return self
 
+    def register_post_lookup_tracker_fn(
+        self,
+        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a function to be called after lookup is done. This is used for
+        tracking the lookup results and optimizer states.
+
+        Args:
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        if self.post_lookup_tracker_fn is not None:
+            logger.warning(
+                "[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
+            )
+        self.post_lookup_tracker_fn = record_fn
+
+    def register_post_odist_tracker_fn(
+        self,
+        record_fn: Callable[..., None],
+    ) -> None:
+        """
+        Register a function to be called after the odist awaitable is registered.
+
+        Args:
+            record_fn (Callable[..., None]): A custom record function to be called after odist is registered.
+
+        """
+        if self.post_odist_tracker_fn is not None:
+            logger.warning(
+                "[ModelDeltaTracker] Compaction function already defined, overriding with new callable"
+            )
+        self.post_odist_tracker_fn = record_fn
+
     @property
     def unsharded_module_type(self) -> Type[nn.Module]:
         """

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -1458,13 +1458,17 @@ def compute_and_output_dist(
                 sharding_type,
             ):
                 embs = lookup(features)
+                if self.post_lookup_tracker_fn is not None:
+                    self.post_lookup_tracker_fn(features, embs)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST,
                 self._module_fqn,
                 sharding_type,
             ):
                 awaitables.append(dist(embs, sharding_context))
+                if self.post_odist_tracker_fn is not None:
+                    self.post_odist_tracker_fn()
 
             if sharding_context:
                 batch_size_per_feature_pre_a2a.extend(

torchrec/distributed/model_parallel.py

Lines changed: 43 additions & 0 deletions
@@ -29,6 +29,8 @@
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.comm import get_local_size
+from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
+from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig
 
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import get_default_sharders
@@ -208,6 +210,7 @@ class DistributedModelParallel(nn.Module, FusedOptimizerModule):
         init_parameters (bool): initialize parameters for modules still on meta device.
         data_parallel_wrapper (Optional[DataParallelWrapper]): custom wrapper for data
             parallel modules.
+        model_tracker_config (Optional[ModelTrackerConfig]): config for model tracker.
 
     Example::
 
@@ -234,6 +237,7 @@ def __init__(
        init_data_parallel: bool = True,
        init_parameters: bool = True,
        data_parallel_wrapper: Optional[DataParallelWrapper] = None,
+       model_tracker_config: Optional[ModelTrackerConfig] = None,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
@@ -242,6 +246,8 @@
 
        self._ddp_wrapped: bool = False
 
+       self.has_model_tracker: bool = model_tracker_config is not None
+
        if env is None:
            pg = dist.GroupMember.WORLD
            assert pg is not None, "Process group is not initialized"
@@ -286,6 +292,11 @@
        if init_data_parallel:
            self.init_data_parallel()
 
+       if model_tracker_config is not None:
+           self.model_delta_tracker: ModelDeltaTracker = self._init_delta_tracker(
+               model_tracker_config, self._dmp_wrapped_module
+           )
+
    @property
    def module(self) -> nn.Module:
        """
@@ -344,6 +355,19 @@ def copy(
    def _init_dmp(self, module: nn.Module) -> nn.Module:
        return self._shard_modules_impl(module)
 
+   def _init_delta_tracker(
+       self, model_tracker_config: ModelTrackerConfig, module: nn.Module
+   ) -> ModelDeltaTracker:
+       # Init delta tracker if config is provided
+       return ModelDeltaTracker(
+           model=module,
+           consumers=model_tracker_config.consumers,
+           delete_on_read=model_tracker_config.delete_on_read,
+           auto_compact=model_tracker_config.auto_compact,
+           mode=model_tracker_config.tracking_mode,
+           fqns_to_skip=model_tracker_config.fqns_to_skip,
+       )
+
    def _init_optim(self, module: nn.Module) -> CombinedOptimizer:
        # pyre-ignore [6]
        return CombinedOptimizer(self._fused_optim_impl(module, []))
@@ -421,6 +445,25 @@ def init_parameters(module: nn.Module) -> None:
 
        module.apply(init_parameters)
 
+   def get_model_tracker(self) -> ModelDeltaTracker:
+       """
+       Returns the model tracker if it exists.
+       """
+
+       assert (
+           self.has_model_tracker
+       ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
+       return self.model_delta_tracker
+
+   def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+       """
+       Returns the delta rows for the given consumer.
+       """
+       assert (
+           self.has_model_tracker
+       ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
+       return self.model_delta_tracker.get_delta(consumer)
+
    def sparse_grad_parameter_names(
        self, destination: Optional[List[str]] = None, prefix: str = ""
    ) -> List[str]:

torchrec/distributed/model_tracker/__init__.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+"""Torchrec Model Tracker
+
+The model tracker module provides functionality to track and retrieve unique IDs and
+embeddings for supported modules during training. This is useful for identifying and
+retrieving the latest delta or unique rows for a model, which can help compute top-k
+or stream updated embeddings from predictors to trainers during online training.
+
+Key features include:
+- Tracking unique IDs and embeddings for supported modules
+- Support for multiple consumers with independent tracking
+- Configurable tracking modes (ID_ONLY, EMBEDDING)
+- Compaction of tracked data to reduce memory usage
+"""
+
+from torchrec.distributed.model_tracker.delta_store import DeltaStore  # noqa
+from torchrec.distributed.model_tracker.model_delta_tracker import (
+    ModelDeltaTracker,  # noqa
+    SUPPORTED_MODULES,  # noqa
+)
+from torchrec.distributed.model_tracker.types import (
+    DeltaRows,  # noqa
+    EmbdUpdateMode,  # noqa
+    IndexedLookup,  # noqa
+    ModelTrackerConfig,  # noqa
+    TrackingMode,  # noqa
+)
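To illustrate the multi-consumer support mentioned in the module docstring, a hedged sketch; the consumer names are made up, and the exact type and defaults of the `consumers` field are assumptions here:

```python
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, TrackingMode

# Two independent consumers, each with its own read cursor into the tracked deltas.
config = ModelTrackerConfig(
    consumers=["publisher", "evaluator"],
    tracking_mode=TrackingMode.ID_ONLY,
    delete_on_read=False,  # keep rows around so both consumers can read them
    auto_compact=True,
)

# With a DMP built from this config, each consumer would read its own delta:
# publisher_rows = dmp.get_delta(consumer="publisher")
# evaluator_rows = dmp.get_delta(consumer="evaluator")
```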

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 55 additions & 8 deletions
@@ -85,15 +85,15 @@ def __init__(
         self.curr_batch_idx: int = 0
         self.curr_compact_index: int = 0
 
-        self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
-
         # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
         self.tracked_modules: Dict[str, nn.Module] = {}
         self.feature_to_fqn: Dict[str, str] = {}
         # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
-        # Validate the mode is supported for the given module
-        self._validate_mode()
+        # Validate that the mode is supported for the given module and initialize tracker functions
+        self._validate_and_init_tracker_fns()
+
+        self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
 
         # Mapping feature name to corresponding FQNs. This is used for retrieving
         # the FQN associated with a given feature name in record_lookup().
@@ -105,6 +105,38 @@ def __init__(
             self.feature_to_fqn[feature_name] = fqn
         logger.info(f"feature_to_fqn: {self.feature_to_fqn}")
 
+    def increment_batch_idx(self) -> None:
+        self.curr_batch_idx += 1
+
+    def trigger_compaction(self) -> None:
+        if self.curr_compact_index >= self.curr_batch_idx:
+            # only trigger compaction once per iteration
+            return
+
+        # TODO: May need to revisit the compaction logic with multiple consumers.
+        # At present we take the max per_consumer_batch_idx to ensure we only compact
+        # newly received lookups
+
+        # The trigger_compaction() function is expected to overlap with comms to hide
+        # compaction compute overhead. Currently, we overlap compaction with odist
+        # because ID tracking occurs during local embedding lookup, which takes place
+        # before odist. This way, auto_compact always merges all past IDs tensors since
+        # the last get_delta call into a single IDs tensor per FQN.
+        #
+        # For delete_on_read=True, get_delta() should delete up to per_consumer_batch_idx
+        # (exclusive). So the compaction should start from per_consumer_batch_idx.
+        #
+        # For delete_on_read=False, get_delta() won't delete tensors, but it does advance
+        # per_consumer_batch_idx accordingly, where all ids prior to per_consumer_batch_idx (exclusive)
+        # should have been compacted into one tensor regardless of auto_compact=True/False.
+        # Therefore, all future compactions should start from per_consumer_batch_idx.
+        start_idx = max(self.per_consumer_batch_idx.values())
+        end_idx = self.curr_batch_idx
+        # Update the current compact index to the end index to avoid duplicate compaction.
+        self.curr_compact_index = end_idx
+        if start_idx < end_idx:
+            self.compact(start_idx=start_idx, end_idx=end_idx)
+
     def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
@@ -120,6 +152,11 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
             states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
         """
 
+        # FIXME: Validate if record_lookup is always the first entry point for each batch and mode.
+        # Need to re-adjust this approach in future as it may not scale if record_lookup is called
+        # multiple times per batch or if we add more entry points.
+        self.increment_batch_idx()
+
         # In ID_ONLY mode, we only track feature IDs received in the current batch.
         if self._mode == TrackingMode.ID_ONLY:
             self.record_ids(kjt)
@@ -333,13 +370,23 @@ def compact(self, start_idx: int, end_idx: int) -> None:
         self.store.compact(start_idx, end_idx)
 
     def _clean_fqn_fn(self, fqn: str) -> str:
-        # strip DMP internal module FQN prefix to match state dict FQN
-        return fqn.replace("_dmp_wrapped_module.module.", "")
-
-    def _validate_mode(self) -> None:
+        # strip FQN prefixes added by DMP and other TorchRec operations to match state dict FQN
+        # handles both "_dmp_wrapped_module.module." and "module." prefixes
+        prefixes_to_strip = ["_dmp_wrapped_module.module.", "module."]
+        for prefix in prefixes_to_strip:
+            if fqn.startswith(prefix):
+                return fqn[len(prefix) :]
+        return fqn
+
+    def _validate_and_init_tracker_fns(self) -> None:
         "To validate the mode is supported for the given module"
         for module in self.tracked_modules.values():
             assert not (
                 isinstance(module, ShardedEmbeddingBagCollection)
                 and self._mode == TrackingMode.EMBEDDING
             ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
+            # register post lookup function
+            module.register_post_lookup_tracker_fn(self.record_lookup)
+            # register auto compaction function at odist
+            if self._auto_compact:
+                module.register_post_odist_tracker_fn(self.trigger_compaction)
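The index arithmetic behind `trigger_compaction()` can be shown standalone; `compaction_window` below is an explanatory sketch (not part of this diff) whose parameter names mirror the tracker's attributes:

```python
from typing import Dict, Optional, Tuple


def compaction_window(
    per_consumer_batch_idx: Dict[str, int],
    curr_batch_idx: int,
    curr_compact_index: int,
) -> Optional[Tuple[int, int]]:
    # Compact at most once per iteration.
    if curr_compact_index >= curr_batch_idx:
        return None
    # Everything before the furthest consumer cursor has already been compacted
    # (or deleted) by get_delta, so a new compaction starts from there.
    start_idx = max(per_consumer_batch_idx.values())
    end_idx = curr_batch_idx
    return (start_idx, end_idx) if start_idx < end_idx else None


# Example: a consumer last read at batch 3 and the tracker is now at batch 7,
# so the lookups recorded between those indices get merged into one IDs tensor per FQN.
print(compaction_window({"c0": 3}, curr_batch_idx=7, curr_compact_index=3))  # -> (3, 7)
```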
