Skip to content

Commit 055119e

Browse files
joshuuuasu authored and facebook-github-bot committed
Fix empty EmbeddingCollection/EmbeddingBagCollection edge cases (#2823)
Summary: Pull Request resolved: #2823 This is an edge-case fix for D68991644. Some models include empty EmbeddingCollections or EmbeddingBagCollections by mistake (?), such as T217919588, thus causing the feature order caching logic broken. Reviewed By: sidt-meta Differential Revision: D71218553 fbshipit-source-id: 204661b4327cd792e3e98f482126d9634f93b011
1 parent 4c203eb commit 055119e

File tree

1 file changed

+32
-28
lines changed

1 file changed

+32
-28
lines changed

torchrec/quant/embedding_modules.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def __init__(
358358
self._feature_names: List[str] = []
359359
self._feature_splits: List[int] = []
360360
self._length_per_key: List[int] = []
361-
self._features_order: List[int] = []
361+
self._features_order: Optional[List[int]] = None
362362
# Registering in a List instead of ModuleList because we want don't want them to be auto-registered.
363363
# Their states will be modified via self.embedding_bags
364364
self._emb_modules: List[nn.Module] = []
@@ -502,20 +502,22 @@ def forward(
502502
kjt_keys = _get_kjt_keys(features)
503503
# Cache the features order since the features will always have the same order of keys in inference.
504504
if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
505-
if self._features_order == []:
506-
for k in self._feature_names:
507-
self._features_order.append(kjt_keys.index(k))
508-
self.register_buffer(
509-
"_features_order_tensor",
510-
torch.tensor(
511-
data=self._features_order,
512-
device=features.device(),
513-
dtype=torch.int32,
514-
),
515-
persistent=False,
516-
)
505+
if self._features_order is None:
506+
self._features_order = [kjt_keys.index(k) for k in self._feature_names]
507+
if self._features_order:
508+
self.register_buffer(
509+
"_features_order_tensor",
510+
torch.tensor(
511+
data=self._features_order,
512+
device=features.device(),
513+
dtype=torch.int32,
514+
),
515+
persistent=False,
516+
)
517517
kjt_permute = _permute_kjt(
518-
features, self._features_order, self._features_order_tensor
518+
features,
519+
self._features_order,
520+
getattr(self, "_features_order_tensor", None),
519521
)
520522
else:
521523
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
@@ -752,7 +754,7 @@ def __init__( # noqa C901
752754
self.row_alignment = row_alignment
753755
self._key_to_tables: Dict[DataType, List[EmbeddingConfig]] = defaultdict(list)
754756
self._feature_names: List[str] = []
755-
self._features_order: List[int] = []
757+
self._features_order: Optional[List[int]] = None
756758

757759
self._table_name_to_quantized_weights: Optional[
758760
Dict[str, Tuple[Tensor, Tensor]]
@@ -881,20 +883,22 @@ def forward(
881883
kjt_keys = _get_kjt_keys(features)
882884
# Cache the features order since the features will always have the same order of keys in inference.
883885
if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
884-
if self._features_order == []:
885-
for k in self._feature_names:
886-
self._features_order.append(kjt_keys.index(k))
887-
self.register_buffer(
888-
"_features_order_tensor",
889-
torch.tensor(
890-
data=self._features_order,
891-
device=features.device(),
892-
dtype=torch.int32,
893-
),
894-
persistent=False,
895-
)
886+
if self._features_order is None:
887+
self._features_order = [kjt_keys.index(k) for k in self._feature_names]
888+
if self._features_order:
889+
self.register_buffer(
890+
"_features_order_tensor",
891+
torch.tensor(
892+
data=self._features_order,
893+
device=features.device(),
894+
dtype=torch.int32,
895+
),
896+
persistent=False,
897+
)
896898
kjt_permute = _permute_kjt(
897-
features, self._features_order, self._features_order_tensor
899+
features,
900+
self._features_order,
901+
getattr(self, "_features_order_tensor", None),
898902
)
899903
else:
900904
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]

0 commit comments

Comments (0)