
Commit 41f66c1

TroyGarden authored and facebook-github-bot committed
add emb cast in KTRegroupAsDict module (#3008)
Summary:
Pull Request resolved: #3008

# context
* add `emb_dtype` to the `KTRegroupAsDict` module, because in APS some customized regroup modules perform a casting operation
* to make the model IR-compatible with the "short-circuit" solution, we need to absorb this casting function inside the `KTRegroupAsDict` module

Reviewed By: malaybag

Differential Revision: D75326034

fbshipit-source-id: 1a32e7c1195b062d19f2b9107ac3af190ebdeb89
1 parent 22078ab commit 41f66c1
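For orientation, a minimal sketch of what the new argument does. The feature names, dims, and the `KeyedTensor` construction are illustrative, not from this commit; `KTRegroupAsDict` and `DataType` are the ones changed in the diffs below.

```python
import torch

from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.types import DataType

# One pooled-embedding KeyedTensor, as an EmbeddingBagCollection would emit:
# batch_size=2, feature "f1" with dim 3 and "f2" with dim 4.
kt = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[3, 4],
    values=torch.randn(2, 7),
)

# With emb_dtype set, the module casts the embeddings before regrouping,
# so every output tensor comes back in the requested dtype.
regroup = KTRegroupAsDict([["f1"], ["f2"]], ["a", "b"], emb_dtype=DataType.BF16)
out = regroup([kt])
assert out["a"].dtype == torch.bfloat16  # cast absorbed inside the module
```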

File tree

5 files changed: +129 −4 lines

* torchrec/ir/schema.py
* torchrec/ir/serializer.py
* torchrec/ir/tests/test_serializer.py
* torchrec/modules/regroup.py
* torchrec/modules/tests/test_regroup.py

torchrec/ir/schema.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -54,3 +54,4 @@ class PositionWeightedModuleCollectionMetadata:
 class KTRegroupAsDictMetadata:
     groups: List[List[str]]
     keys: List[str]
+    emb_dtype: Optional[str]
```

torchrec/ir/serializer.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -411,6 +411,11 @@ def serialize_to_dict(
             # pyre-fixme[6]: For 2nd argument expected `List[List[str]]` but got
             # `Union[Module, Tensor]`.
             groups=module._groups,
+            emb_dtype=(
+                module._emb_dtype.value  # pyre-ignore[16]
+                if module._emb_dtype is not None
+                else None
+            ),
         )
         return metadata.__dict__

@@ -425,6 +430,9 @@ def deserialize_from_dict(
         return KTRegroupAsDict(
             keys=metadata.keys,
             groups=metadata.groups,
+            emb_dtype=(
+                DataType(metadata.emb_dtype) if metadata.emb_dtype is not None else None
+            ),
         )
```
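The metadata stores the dtype as a plain string so it stays JSON-serializable. A quick sketch of the round trip, assuming `DataType` is a string-valued enum, which is what `DataType(metadata.emb_dtype)` above relies on:

```python
from torchrec.types import DataType

# serialize_to_dict stores the enum's string value in the metadata dict ...
serialized = DataType.BF16.value  # "BF16"

# ... and deserialize_from_dict reconstructs the enum from that string.
restored = DataType(serialized)
assert restored is DataType.BF16
```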
torchrec/ir/tests/test_serializer.py

Lines changed: 78 additions & 1 deletion
```diff
@@ -22,7 +22,6 @@
     encapsulate_ir_modules,
     mark_dynamic_kjt,
 )
-
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.feature_processor_ import (
@@ -32,6 +31,7 @@
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.types import DataType


 class CompoundModule(nn.Module):
@@ -747,3 +747,80 @@ def forward(
         deserialized_out = deserialized_model(id_list_features)
         for key in eager_out.keys():
             torch.testing.assert_close(deserialized_out[key], eager_out[key])
+
+    def test_cast_in_regroup(self) -> None:
+        class Model(nn.Module):
+            def __init__(self, ebc, fpebc, regroup):
+                super().__init__()
+                self.ebc = ebc
+                self.fpebc = fpebc
+                self.regroup = regroup
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                kt1 = self.ebc(features)
+                kt2 = self.fpebc(features)
+                return self.regroup([kt1, kt2])
+
+        tb1_config = EmbeddingBagConfig(
+            name="t1",
+            embedding_dim=3,
+            num_embeddings=10,
+            feature_names=["f1", "f2"],
+        )
+        tb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=4,
+            num_embeddings=10,
+            feature_names=["f3", "f4"],
+        )
+        tb3_config = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f5"],
+        )
+
+        ebc = EmbeddingBagCollection(
+            tables=[tb1_config, tb3_config],
+            is_weighted=False,
+        )
+        max_feature_lengths = {"f3": 100, "f4": 100}
+        fpebc = FeatureProcessedEmbeddingBagCollection(
+            EmbeddingBagCollection(
+                tables=[tb2_config],
+                is_weighted=True,
+            ),
+            PositionWeightedModuleCollection(
+                max_feature_lengths=max_feature_lengths,
+            ),
+        )
+        data_type = DataType.BF16
+
+        regroup = KTRegroupAsDict(
+            [["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"], emb_dtype=data_type
+        )
+        model = Model(ebc, fpebc, regroup)
+        self.assertEqual(model.regroup._emb_dtype, data_type)
+
+        id_list_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3", "f4", "f5"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
+        )
+        # Serialize EBC
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (id_list_features,),
+            {},
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+
+        unflatten_ep = torch.export.unflatten(ep)
+        deserialized = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        self.assertEqual(deserialized.regroup._emb_dtype, data_type)  # pyre-ignore[16]
```

torchrec/modules/regroup.py

Lines changed: 17 additions & 2 deletions
```diff
@@ -12,12 +12,13 @@
 from typing import Dict, List, Optional, Tuple, Union

 import torch
+from torchrec.modules.embedding_configs import data_type_to_dtype
 from torchrec.sparse.jagged_tensor import (
     _desugar_keyed_tensors,
     _kt_regroup_arguments,
     KeyedTensor,
 )
-from torchrec.types import CacheMixin
+from torchrec.types import CacheMixin, DataType


 @torch.fx.wrap
@@ -131,7 +132,12 @@ class KTRegroupAsDict(torch.nn.Module, CacheMixin):

     """

-    def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
+    def __init__(
+        self,
+        groups: List[List[str]],
+        keys: List[str],
+        emb_dtype: Optional[DataType] = None,
+    ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
         assert len(groups) == len(keys), "Groups and keys should have same length"
@@ -145,6 +151,7 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
         self._splits: List[int] = []
         self._idx_key_pairs: List[Tuple[int, str]] = []
         self._permute_pooled_embs_impl = PermuteMultiEmbedding(groups)
+        self._emb_dtype = emb_dtype

     def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
         self._use_fbgemm_regroup = True
@@ -190,18 +197,26 @@ def _init_regroup(self, kts: List[KeyedTensor]) -> None:
         self._splits = splits
         self._idx_key_pairs = idx_key_pairs

+    def embedding_cast(self, embs: List[torch.Tensor]) -> List[torch.Tensor]:
+        if self._emb_dtype is None:
+            return embs
+        dtype = data_type_to_dtype(self._emb_dtype)
+        return [emb.to(dtype=dtype) for emb in embs]
+
     def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
         if not self._is_inited:
             module_init(self, keyed_tensors)

         if self._use_fbgemm_regroup:
             values = _get_kts_values(keyed_tensors)
+            values = self.embedding_cast(values)
             permuted_values = self._permute_pooled_embs_impl(values)
             return _to_tensor_dict(self._keys, permuted_values)
         else:
             permuted_values = _permuted_values(
                 keyed_tensors, self._idx_key_pairs, self._dim
             )
+            permuted_values = self.embedding_cast([permuted_values])[0]
             splitted_values = torch.split(permuted_values, self._splits, dim=self._dim)
             return _to_tensor_dict(self._keys, splitted_values)
```
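To see where the cast lands on the non-fbgemm path of `forward` above, a small standalone sketch with illustrative dims: the concatenated permuted tensor is cast once via `embedding_cast`, then split per group.

```python
import torch

from torchrec.modules.embedding_configs import data_type_to_dtype
from torchrec.types import DataType

# Stand-in for the permuted, concatenated pooled embeddings (fp32).
permuted_values = torch.randn(2, 7)

# embedding_cast maps the torchrec DataType to a torch dtype and casts.
dtype = data_type_to_dtype(DataType.FP16)  # -> torch.float16
permuted_values = permuted_values.to(dtype=dtype)

# forward then splits per group; each split inherits the cast dtype.
splitted_values = torch.split(permuted_values, [3, 4], dim=1)
assert all(t.dtype is torch.float16 for t in splitted_values)
```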
torchrec/modules/tests/test_regroup.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -11,10 +11,12 @@

 import torch
 import torch.fx
-
+from hypothesis import given, settings, strategies as st, Verbosity
+from torchrec.modules.embedding_configs import data_type_to_dtype
 from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import _all_keys_used_once, KeyedTensor
 from torchrec.sparse.tests.utils import build_groups, build_kts
+from torchrec.types import DataType


 class KTRegroupAsDictTest(unittest.TestCase):
@@ -171,3 +173,25 @@ def test_fx_and_jit_regroup_skips_and_duplicates(self) -> None:
         eager_out = regroup_module(self.kts)
         for key in out.keys():
             torch.allclose(out[key], eager_out[key])
+
+    # pyre-ignore[56]
+    @given(data_type=st.sampled_from([DataType.BF16, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=20)
+    def test_regroup_cast(self, data_type: DataType) -> None:
+        dtype = data_type_to_dtype(data_type)
+        groups = build_groups(
+            kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
+        )
+        assert _all_keys_used_once(self.kts, groups) is False
+
+        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
+        cast_regroup = KTRegroupAsDict(
+            groups=groups, keys=self.keys, emb_dtype=data_type
+        )
+
+        eager_out = regroup_module(self.kts)
+        cast_out = cast_regroup(self.kts)
+
+        for key in eager_out.keys():
+            self.assertEqual(cast_out[key].dtype, dtype)
+            torch.allclose(cast_out[key], eager_out[key].to(dtype))
```
