Skip to content

Commit a4cd73b

Browse files
aliafzal authored and
facebook-github-bot committed
Add logic for fqn_to_feature_names (#3059)
Summary: Pull Request resolved: #3059. This diff adds the implementation of the fqn_to_feature_names method, along with an initial testing framework and unit tests for fqn_to_feature_names. ModelDeltaTracker context: ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a model using TorchRec. It is particularly useful for: (1) identifying which embedding rows were accessed during model execution; (2) retrieving the latest delta or unique rows for a model; (3) computing top-k changed embeddings; (4) supporting streaming of updated embeddings between systems during online training. Reviewed By: kausv. Differential Revision: D75908963. fbshipit-source-id: 2abe7615fbe36d5adce6551fffc64a9202a19e86
1 parent 3b386f4 commit a4cd73b

File tree

3 files changed

+658
-8
lines changed

3 files changed

+658
-8
lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
from typing import Dict, List, Optional, Union
9+
import logging as logger
10+
from collections import Counter, OrderedDict
11+
from typing import Dict, Iterable, List, Optional, Union
1012

1113
import torch
1214

@@ -30,7 +32,7 @@
3032
}
3133

3234
# Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
33-
SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection]
35+
SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection)
3436

3537

3638
class ModelDeltaTracker:
@@ -49,6 +51,8 @@ class ModelDeltaTracker:
4951
call.
5052
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
5153
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
54+
fqns_to_skip (Iterable[str], optional): FQN components to skip; any module whose dot-separated FQN contains one of these components is not tracked. Default: () (nothing is skipped).
55+
5256
"""
5357

5458
DEFAULT_CONSUMER: str = "default"
@@ -59,11 +63,15 @@ def __init__(
5963
consumers: Optional[List[str]] = None,
6064
delete_on_read: bool = True,
6165
mode: TrackingMode = TrackingMode.ID_ONLY,
66+
fqns_to_skip: Iterable[str] = (),
6267
) -> None:
6368
self._model = model
6469
self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
6570
self._delete_on_read = delete_on_read
6671
self._mode = mode
72+
self._fqn_to_feature_map: Dict[str, List[str]] = {}
73+
self._fqns_to_skip: Iterable[str] = fqns_to_skip
74+
self.fqn_to_feature_names()
6775
pass
6876

6977
def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
@@ -85,14 +93,70 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
8593
"""
8694
return {}
8795

88-
def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
96+
def fqn_to_feature_names(self) -> Dict[str, List[str]]:
8997
"""
90-
Returns a mapping from FQN to feature names for a given module.
91-
92-
Args:
93-
module (nn.Module): the module to retrieve feature names for.
98+
Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model.
9499
"""
95-
return {}
100+
if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
101+
return self._fqn_to_feature_map
102+
103+
table_to_feature_names: Dict[str, List[str]] = OrderedDict()
104+
table_to_fqn: Dict[str, str] = OrderedDict()
105+
for fqn, named_module in self._model.named_modules():
106+
split_fqn = fqn.split(".")
107+
# Skipping partial FQNs present in fqns_to_skip
108+
# TODO: Validate if we need to support more complex patterns for skipping fqns
109+
should_skip = False
110+
for fqn_to_skip in self._fqns_to_skip:
111+
if fqn_to_skip in split_fqn:
112+
logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
113+
should_skip = True
114+
break
115+
if should_skip:
116+
continue
117+
118+
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
119+
if isinstance(named_module, SUPPORTED_MODULES):
120+
for table_name, config in named_module._table_name_to_config.items():
121+
logger.info(
122+
f"Found {table_name} for {fqn} with features {config.feature_names}"
123+
)
124+
table_to_feature_names[table_name] = config.feature_names
125+
for table_name in table_to_feature_names:
126+
# Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
127+
# will incorrectly match fqn with all the table names that have the same prefix
128+
if table_name in split_fqn:
129+
embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
130+
if table_name in table_to_fqn:
131+
# Sanity check: warn if the same table name maps to more than one FQN; the most recently seen FQN overrides the earlier one.
132+
logger.warning(
133+
f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
134+
)
135+
table_to_fqn[table_name] = embedding_fqn
136+
logger.info(f"Table to fqn: {table_to_fqn}")
137+
flatten_names = [
138+
name for names in table_to_feature_names.values() for name in names
139+
]
140+
# TODO: Validate if there is a better way to handle duplicate feature names.
141+
# Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case.
142+
if len(set(flatten_names)) != len(flatten_names):
143+
counts = Counter(flatten_names)
144+
duplicates = [item for item, count in counts.items() if count > 1]
145+
logger.warning(f"duplicate feature names found: {duplicates}")
146+
147+
fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
148+
for table_name in table_to_feature_names:
149+
if table_name not in table_to_fqn:
150+
# This is likely unexpected, where we can't locate the FQN associated with this table.
151+
logger.warning(
152+
f"Table {table_name} not found in {table_to_fqn}, skipping"
153+
)
154+
continue
155+
fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
156+
table_name
157+
]
158+
self._fqn_to_feature_map = fqn_to_feature_names
159+
return fqn_to_feature_names
96160

97161
def clear(self, consumer: Optional[str] = None) -> None:
98162
"""

0 commit comments

Comments
 (0)