6
6
# LICENSE file in the root directory of this source tree.
7
7
8
8
# pyre-strict
9
- from typing import Dict , List , Optional , Union
9
+ import logging as logger
10
+ from collections import Counter , OrderedDict
11
+ from typing import Dict , Iterable , List , Optional , Union
10
12
11
13
import torch
12
14
30
32
}
31
33
32
34
# Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
33
- SUPPORTED_MODULES = Union [ ShardedEmbeddingCollection , ShardedEmbeddingBagCollection ]
35
+ SUPPORTED_MODULES = ( ShardedEmbeddingCollection , ShardedEmbeddingBagCollection )
34
36
35
37
36
38
class ModelDeltaTracker :
@@ -49,6 +51,8 @@ class ModelDeltaTracker:
49
51
call.
50
52
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
51
53
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
54
+ fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None.
55
+
52
56
"""
53
57
54
58
DEFAULT_CONSUMER : str = "default"
@@ -59,11 +63,15 @@ def __init__(
59
63
consumers : Optional [List [str ]] = None ,
60
64
delete_on_read : bool = True ,
61
65
mode : TrackingMode = TrackingMode .ID_ONLY ,
66
+ fqns_to_skip : Iterable [str ] = (),
62
67
) -> None :
63
68
self ._model = model
64
69
self ._consumers : List [str ] = consumers or [self .DEFAULT_CONSUMER ]
65
70
self ._delete_on_read = delete_on_read
66
71
self ._mode = mode
72
+ self ._fqn_to_feature_map : Dict [str , List [str ]] = {}
73
+ self ._fqns_to_skip : Iterable [str ] = fqns_to_skip
74
+ self .fqn_to_feature_names ()
67
75
pass
68
76
69
77
def record_lookup (self , kjt : KeyedJaggedTensor , states : torch .Tensor ) -> None :
@@ -85,14 +93,70 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
85
93
"""
86
94
return {}
87
95
88
- def fqn_to_feature_names (self , module : nn . Module ) -> Dict [str , List [str ]]:
96
+ def fqn_to_feature_names (self ) -> Dict [str , List [str ]]:
89
97
"""
90
- Returns a mapping from FQN to feature names for a given module.
91
-
92
- Args:
93
- module (nn.Module): the module to retrieve feature names for.
98
+ Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model.
94
99
"""
95
- return {}
100
+ if (self ._fqn_to_feature_map is not None ) and len (self ._fqn_to_feature_map ) > 0 :
101
+ return self ._fqn_to_feature_map
102
+
103
+ table_to_feature_names : Dict [str , List [str ]] = OrderedDict ()
104
+ table_to_fqn : Dict [str , str ] = OrderedDict ()
105
+ for fqn , named_module in self ._model .named_modules ():
106
+ split_fqn = fqn .split ("." )
107
+ # Skipping partial FQNs present in fqns_to_skip
108
+ # TODO: Validate if we need to support more complex patterns for skipping fqns
109
+ should_skip = False
110
+ for fqn_to_skip in self ._fqns_to_skip :
111
+ if fqn_to_skip in split_fqn :
112
+ logger .info (f"Skipping { fqn } because it is part of fqns_to_skip" )
113
+ should_skip = True
114
+ break
115
+ if should_skip :
116
+ continue
117
+
118
+ # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
119
+ if isinstance (named_module , SUPPORTED_MODULES ):
120
+ for table_name , config in named_module ._table_name_to_config .items ():
121
+ logger .info (
122
+ f"Found { table_name } for { fqn } with features { config .feature_names } "
123
+ )
124
+ table_to_feature_names [table_name ] = config .feature_names
125
+ for table_name in table_to_feature_names :
126
+ # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
127
+ # will incorrectly match fqn with all the table names that have the same prefix
128
+ if table_name in split_fqn :
129
+ embedding_fqn = fqn .replace ("_dmp_wrapped_module.module." , "" )
130
+ if table_name in table_to_fqn :
131
+ # Sanity check for validating that we don't have more then one table mapping to same fqn.
132
+ logger .warning (
133
+ f"Override { table_to_fqn [table_name ]} with { embedding_fqn } for entry { table_name } "
134
+ )
135
+ table_to_fqn [table_name ] = embedding_fqn
136
+ logger .info (f"Table to fqn: { table_to_fqn } " )
137
+ flatten_names = [
138
+ name for names in table_to_feature_names .values () for name in names
139
+ ]
140
+ # TODO: Validate if there is a better way to handle duplicate feature names.
141
+ # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case.
142
+ if len (set (flatten_names )) != len (flatten_names ):
143
+ counts = Counter (flatten_names )
144
+ duplicates = [item for item , count in counts .items () if count > 1 ]
145
+ logger .warning (f"duplicate feature names found: { duplicates } " )
146
+
147
+ fqn_to_feature_names : Dict [str , List [str ]] = OrderedDict ()
148
+ for table_name in table_to_feature_names :
149
+ if table_name not in table_to_fqn :
150
+ # This is likely unexpected, where we can't locate the FQN associated with this table.
151
+ logger .warning (
152
+ f"Table { table_name } not found in { table_to_fqn } , skipping"
153
+ )
154
+ continue
155
+ fqn_to_feature_names [table_to_fqn [table_name ]] = table_to_feature_names [
156
+ table_name
157
+ ]
158
+ self ._fqn_to_feature_map = fqn_to_feature_names
159
+ return fqn_to_feature_names
96
160
97
161
def clear (self , consumer : Optional [str ] = None ) -> None :
98
162
"""
0 commit comments