66# LICENSE file in the root directory of this source tree.
77
88# pyre-strict
9- from typing import Dict , List , Optional , Union
9+ import logging as logger
10+ from collections import Counter , OrderedDict
11+ from typing import Dict , Iterable , List , Optional , Union
1012
1113import torch
1214
@@ -49,6 +51,8 @@ class ModelDeltaTracker:
4951 call.
5052 delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
5153 mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
54+ fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None.
55+
5256 """
5357
5458 DEFAULT_CONSUMER : str = "default"
@@ -59,11 +63,15 @@ def __init__(
5963 consumers : Optional [List [str ]] = None ,
6064 delete_on_read : bool = True ,
6165 mode : TrackingMode = TrackingMode .ID_ONLY ,
66+ fqns_to_skip : Iterable [str ] = (),
6267 ) -> None :
6368 self ._model = model
6469 self ._consumers : List [str ] = consumers or [self .DEFAULT_CONSUMER ]
6570 self ._delete_on_read = delete_on_read
6671 self ._mode = mode
72+ self ._fqn_to_feature_map : Dict [str , List [str ]] = {}
73+ self ._fqns_to_skip : Iterable [str ] = fqns_to_skip
74+ self .fqn_to_feature_names ()
6775 pass
6876
6977 def record_lookup (self , kjt : KeyedJaggedTensor , states : torch .Tensor ) -> None :
@@ -85,14 +93,70 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
8593 """
8694 return {}
8795
88- def fqn_to_feature_names (self , module : nn . Module ) -> Dict [str , List [str ]]:
96+ def fqn_to_feature_names (self ) -> Dict [str , List [str ]]:
8997 """
90- Returns a mapping from FQN to feature names for a given module.
91-
92- Args:
93- module (nn.Module): the module to retrieve feature names for.
98+ Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model.
9499 """
95- return {}
100+ if (self ._fqn_to_feature_map is not None ) and len (self ._fqn_to_feature_map ) > 0 :
101+ return self ._fqn_to_feature_map
102+
103+ table_to_feature_names : Dict [str , List [str ]] = OrderedDict ()
104+ table_to_fqn : Dict [str , str ] = OrderedDict ()
105+ for fqn , named_module in self ._model .named_modules ():
106+ split_fqn = fqn .split ("." )
107+ # Skipping partial FQNs present in fqns_to_skip
108+ # TODO: Validate if we need to support more complex patterns for skipping fqns
109+ should_skip = False
110+ for fqn_to_skip in self ._fqns_to_skip :
111+ if fqn_to_skip in split_fqn :
112+ logger .info (f"Skipping { fqn } because it is part of fqns_to_skip" )
113+ should_skip = True
114+ break
115+ if should_skip :
116+ continue
117+
118+ # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
119+ if isinstance (named_module , SUPPORTED_MODULES ):
120+ for table_name , config in named_module ._table_name_to_config .items ():
121+ logger .info (
122+ f"Found { table_name } for { fqn } with features { config .feature_names } "
123+ )
124+ table_to_feature_names [table_name ] = config .feature_names
125+ for table_name in table_to_feature_names :
126+ # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
127+ # will incorrectly match fqn with all the table names that have the same prefix
128+ if table_name in split_fqn :
129+ embedding_fqn = fqn .replace ("_dmp_wrapped_module.module." , "" )
130+ if table_name in table_to_fqn :
131+ # Sanity check for validating that we don't have more then one table mapping to same fqn.
132+ logger .warning (
133+ f"Override { table_to_fqn [table_name ]} with { embedding_fqn } for entry { table_name } "
134+ )
135+ table_to_fqn [table_name ] = embedding_fqn
136+ logger .info (f"Table to fqn: { table_to_fqn } " )
137+ flatten_names = [
138+ name for names in table_to_feature_names .values () for name in names
139+ ]
140+ # TODO: Validate if there is a better way to handle duplicate feature names.
141+ # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case.
142+ if len (set (flatten_names )) != len (flatten_names ):
143+ counts = Counter (flatten_names )
144+ duplicates = [item for item , count in counts .items () if count > 1 ]
145+ logger .warning (f"duplicate feature names found: { duplicates } " )
146+
147+ fqn_to_feature_names : Dict [str , List [str ]] = OrderedDict ()
148+ for table_name in table_to_feature_names :
149+ if table_name not in table_to_fqn :
150+ # This is likely unexpected, where we can't locate the FQN associated with this table.
151+ logger .warning (
152+ f"Table { table_name } not found in { table_to_fqn } , skipping"
153+ )
154+ continue
155+ fqn_to_feature_names [table_to_fqn [table_name ]] = table_to_feature_names [
156+ table_name
157+ ]
158+ self ._fqn_to_feature_map = fqn_to_feature_names
159+ return fqn_to_feature_names
96160
97161 def clear (self , consumer : Optional [str ] = None ) -> None :
98162 """
0 commit comments