10
10
#!/usr/bin/env python3
11
11
12
12
import abc
13
- from collections import defaultdict
14
13
from typing import Callable , Dict , List , NamedTuple , Optional , Tuple , Union
15
14
16
15
import torch
30
29
31
30
@torch.fx.wrap
def apply_mc_method_to_jt_dict(
    mc_module: nn.Module,
    method: str,
    features_dict: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
    """
    Call the method named ``method`` on ``mc_module`` with ``features_dict``.

    The function is registered with ``torch.fx.wrap``, so the dynamic
    ``getattr`` dispatch below is kept inside a single wrapped call.

    Args:
        mc_module: object that owns the method to invoke.
        method: attribute name of the method to look up on ``mc_module``.
        features_dict: dictionary of JaggedTensors passed as the only argument.

    Returns:
        Whatever the looked-up method returns; no merging or reordering is
        done here.

    Raises:
        AttributeError: if ``mc_module`` has no attribute named ``method``.
    """
    bound_method = getattr(mc_module, method)
    return bound_method(features_dict)
50
41
51
42
52
43
@torch .no_grad ()
@@ -153,6 +144,14 @@ def evict(self) -> Optional[torch.Tensor]:
153
144
"""
154
145
pass
155
146
147
@abc.abstractmethod
def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
    """
    Abstract hook; concrete subclasses must override it.

    Args:
        features: dictionary of JaggedTensors keyed by name.

    Returns:
        A dictionary of JaggedTensors.

    NOTE(review): ManagedCollisionCollection.forward calls preprocess,
    profile and remap in that order on each module, feeding each result to
    the next step; presumably those modules subclass this class. Confirm.
    """
    pass
150
+
151
@abc.abstractmethod
def profile(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
    """
    Abstract hook; concrete subclasses must override it.

    Args:
        features: dictionary of JaggedTensors keyed by name.

    Returns:
        A dictionary of JaggedTensors.

    NOTE(review): ManagedCollisionCollection.forward calls this between
    preprocess and remap on each module; presumably those modules subclass
    this class. Confirm.
    """
    pass
154
+
156
155
@abc .abstractmethod
157
156
def forward (
158
157
self ,
@@ -203,6 +202,8 @@ class ManagedCollisionCollection(nn.Module):
203
202
embedding_configs (List[BaseEmbeddingConfig]): List of embedding configs, for each table with a managed collision module
204
203
"""
205
204
205
+ _table_to_features : Dict [str , List [str ]]
206
+
206
207
def __init__ (
207
208
self ,
208
209
managed_collision_modules : Dict [str , ManagedCollisionModule ],
@@ -216,10 +217,13 @@ def __init__(
216
217
for config in embedding_configs
217
218
for feature in config .feature_names
218
219
}
219
- self ._table_to_features : Dict [ str , List [ str ]] = defaultdict ( list )
220
+ self ._table_to_features = {}
220
221
221
222
self ._compute_jt_dict_to_kjt = ComputeJTDictToKJT ()
222
223
for feature , table in self ._feature_to_table .items ():
224
+ if table not in self ._table_to_features :
225
+ self ._table_to_features [table ] = []
226
+
223
227
self ._table_to_features [table ].append (feature )
224
228
225
229
table_to_config = {config .name : config for config in embedding_configs }
@@ -243,25 +247,18 @@ def forward(
243
247
self ,
244
248
features : KeyedJaggedTensor ,
245
249
) -> KeyedJaggedTensor :
246
- features_dict = apply_mc_method_to_jt_dict (
247
- "preprocess" ,
248
- features_dict = features .to_dict (),
249
- table_to_features = self ._table_to_features ,
250
- managed_collisions = self ._managed_collision_modules ,
251
- )
252
- features_dict = apply_mc_method_to_jt_dict (
253
- "profile" ,
254
- features_dict = features_dict ,
255
- table_to_features = self ._table_to_features ,
256
- managed_collisions = self ._managed_collision_modules ,
257
- )
258
- features_dict = apply_mc_method_to_jt_dict (
259
- "remap" ,
260
- features_dict = features_dict ,
261
- table_to_features = self ._table_to_features ,
262
- managed_collisions = self ._managed_collision_modules ,
263
- )
264
- return self ._compute_jt_dict_to_kjt (features_dict )
250
+ features_dict = features .to_dict ()
251
+ output : Dict [str , JaggedTensor ] = features_dict .copy ()
252
+ for table , mc_module in self ._managed_collision_modules .items ():
253
+ feature_list : List [str ] = self ._table_to_features [table ]
254
+ mc_input : Dict [str , JaggedTensor ] = {}
255
+ for feature in feature_list :
256
+ mc_input [feature ] = features_dict [feature ]
257
+ mc_input = mc_module .preprocess (mc_input )
258
+ mc_input = mc_module .profile (mc_input )
259
+ mc_input = mc_module .remap (mc_input )
260
+ output .update (mc_input )
261
+ return self ._compute_jt_dict_to_kjt (output )
265
262
266
263
def evict (self ) -> Dict [str , Optional [torch .Tensor ]]:
267
264
evictions : Dict [str , Optional [torch .Tensor ]] = {}
@@ -933,7 +930,17 @@ def _init_history_buffers(self, features: Dict[str, JaggedTensor]) -> None:
933
930
self ._history_metadata [metadata_name ] = getattr (self , buffer_name )
934
931
935
932
@torch.no_grad()
def preprocess(
    self,
    features: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
    """
    Public entry point for preprocessing; runs with gradients disabled.

    The work is done by ``self._preprocess``. It is invoked by name through
    ``apply_mc_method_to_jt_dict`` (decorated with ``torch.fx.wrap``) rather
    than called directly.

    Args:
        features: dictionary of JaggedTensors to preprocess.

    Returns:
        The dictionary returned by ``self._preprocess(features)``.
    """
    return apply_mc_method_to_jt_dict(
        self,
        "_preprocess",
        features,
    )
942
+
943
+ def _preprocess (self , features : Dict [str , JaggedTensor ]) -> Dict [str , JaggedTensor ]:
937
944
if self ._input_hash_func is None :
938
945
return features
939
946
preprocessed_features : Dict [str , JaggedTensor ] = {}
@@ -1070,6 +1077,16 @@ def _coalesce_history(self) -> None:
1070
1077
def profile(
    self,
    features: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
    """
    Public entry point for profiling.

    The work is done by ``self._profile``. It is invoked by name through
    ``apply_mc_method_to_jt_dict`` (decorated with ``torch.fx.wrap``) rather
    than called directly.

    Args:
        features: dictionary of JaggedTensors to profile.

    Returns:
        The dictionary returned by ``self._profile(features)``.
    """
    return apply_mc_method_to_jt_dict(
        self,
        "_profile",
        features,
    )
1086
+
1087
+ def _profile (
1088
+ self ,
1089
+ features : Dict [str , JaggedTensor ],
1073
1090
) -> Dict [str , JaggedTensor ]:
1074
1091
if not self .training :
1075
1092
return features
@@ -1115,7 +1132,17 @@ def profile(
1115
1132
return features
1116
1133
1117
1134
@torch.no_grad()
def remap(
    self,
    features: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
    """
    Public entry point for remapping; runs with gradients disabled.

    The work is done by ``self._remap``. It is invoked by name through
    ``apply_mc_method_to_jt_dict`` (decorated with ``torch.fx.wrap``) rather
    than called directly.

    Args:
        features: dictionary of JaggedTensors to remap.

    Returns:
        The dictionary returned by ``self._remap(features)``.
    """
    return apply_mc_method_to_jt_dict(
        self,
        "_remap",
        features,
    )
1144
+
1145
+ def _remap (self , features : Dict [str , JaggedTensor ]) -> Dict [str , JaggedTensor ]:
1119
1146
1120
1147
remapped_features : Dict [str , JaggedTensor ] = {}
1121
1148
for name , feature in features .items ():
0 commit comments