Skip to content

Commit 816a69d

Browse files
TroyGarden and facebook-github-bot
authored and committed
use qualname for the serialized nn.Module (meta-pytorch#3021)
Summary: Pull Request resolved: meta-pytorch#3021

# context
* originally, the IR serializer used `type(m).__name__` as the key for the `module_to_serializer_cls` mapping
* this becomes a problem when there are multiple modules with the same name
* this change uses the full qualname of a class as the key
* example
```
# previously
ebc = EmbeddingBagCollection(...)
type(ebc).__name__ == 'EmbeddingBagCollection'
# now
ebc = EmbeddingBagCollection(...)
qualname(ebc) == 'torchrec.modules.embedding_modules.EmbeddingBagCollection'
```

# debug print
> {'torchrec.modules.embedding_modules.EmbeddingBagCollection': <class 'torchrec.ir.serializer.EBCJsonSerializer'>, 'torchrec.modules.feature_processor_.PositionWeightedModule': <class 'torchrec.ir.serializer.PWMJsonSerializer'>, 'torchrec.modules.feature_processor_.PositionWeightedModuleCollection': <class 'torchrec.ir.serializer.PWMCJsonSerializer'>, 'torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection': <class 'torchrec.ir.serializer.FPEBCJsonSerializer'>, 'torchrec.modules.regroup.KTRegroupAsDict': <class 'torchrec.ir.serializer.KTRegroupAsDictJsonSerializer'>}

# backward compatibility issue
* a previously serialized model (module) won't be correctly deserialized
* the workaround is to keep the old key name in the `module_to_serializer_cls` map for a while, until the old method is fully deprecated

Differential Revision: D75727469
1 parent 151aa02 commit 816a69d

File tree

3 files changed

+27
-15
lines changed

3 files changed

+27
-15
lines changed

torchrec/ir/serializer.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323

2424
from torchrec.ir.types import SerializerInterface
25-
from torchrec.ir.utils import logging
25+
from torchrec.ir.utils import logging, qualname
2626
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
2727
from torchrec.modules.embedding_modules import EmbeddingBagCollection
2828
from torchrec.modules.feature_processor_ import (
@@ -174,7 +174,7 @@ def swap_meta_forward(cls, module: nn.Module) -> None:
174174

175175
@classmethod
176176
def encapsulate_module(cls, module: nn.Module) -> List[str]:
177-
typename = type(module).__name__
177+
typename = qualname(module)
178178
serializer = cls.module_to_serializer_cls.get(typename)
179179
if serializer is None:
180180
raise ValueError(
@@ -184,8 +184,8 @@ def encapsulate_module(cls, module: nn.Module) -> List[str]:
184184
assert serializer._module_cls is not None
185185
if not isinstance(module, serializer._module_cls):
186186
raise ValueError(
187-
f"Expected module to be of type {serializer._module_cls.__name__}, "
188-
f"got {type(module)}"
187+
f"Expected module to be of type {qualname(serializer._module_cls)}, "
188+
f"got {qualname(module)}"
189189
)
190190
metadata_dict = serializer.serialize_to_dict(module)
191191
raw_dict = {"typename": typename, "metadata_dict": metadata_dict}
@@ -218,7 +218,7 @@ def decapsulate_module(
218218
)
219219
if not isinstance(module, serializer._module_cls):
220220
raise ValueError(
221-
f"Expected module to be of type {serializer._module_cls.__name__}, got {type(module)}"
221+
f"Expected module to be of type {qualname(serializer._module_cls)}, got {qualname(module)}"
222222
)
223223
return module
224224

@@ -275,7 +275,9 @@ def deserialize_from_dict(
275275
)
276276

277277

278-
JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer
278+
JsonSerializer.module_to_serializer_cls[qualname(EmbeddingBagCollection)] = (
279+
EBCJsonSerializer
280+
)
279281

280282

281283
class PWMJsonSerializer(JsonSerializer):
@@ -299,7 +301,9 @@ def deserialize_from_dict(
299301
return PositionWeightedModule(metadata.max_feature_length, device)
300302

301303

302-
JsonSerializer.module_to_serializer_cls["PositionWeightedModule"] = PWMJsonSerializer
304+
JsonSerializer.module_to_serializer_cls[qualname(PositionWeightedModule)] = (
305+
PWMJsonSerializer
306+
)
303307

304308

305309
class PWMCJsonSerializer(JsonSerializer):
@@ -331,7 +335,7 @@ def deserialize_from_dict(
331335
return PositionWeightedModuleCollection(max_feature_lengths, device)
332336

333337

334-
JsonSerializer.module_to_serializer_cls["PositionWeightedModuleCollection"] = (
338+
JsonSerializer.module_to_serializer_cls[qualname(PositionWeightedModuleCollection)] = (
335339
PWMCJsonSerializer
336340
)
337341

@@ -385,9 +389,9 @@ def deserialize_from_dict(
385389
)
386390

387391

388-
JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = (
389-
FPEBCJsonSerializer
390-
)
392+
JsonSerializer.module_to_serializer_cls[
393+
qualname(FeatureProcessedEmbeddingBagCollection)
394+
] = FPEBCJsonSerializer
391395

392396

393397
class KTRegroupAsDictJsonSerializer(JsonSerializer):
@@ -436,6 +440,6 @@ def deserialize_from_dict(
436440
)
437441

438442

439-
JsonSerializer.module_to_serializer_cls["KTRegroupAsDict"] = (
443+
JsonSerializer.module_to_serializer_cls[qualname(KTRegroupAsDict)] = (
440444
KTRegroupAsDictJsonSerializer
441445
)

torchrec/ir/tests/test_serializer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
decapsulate_ir_modules,
2222
encapsulate_ir_modules,
2323
mark_dynamic_kjt,
24+
qualname,
2425
)
2526
from torchrec.modules.embedding_configs import EmbeddingBagConfig
2627
from torchrec.modules.embedding_modules import EmbeddingBagCollection
@@ -524,7 +525,7 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
524525

525526
eager_out = model(id_list_features)
526527

527-
JsonSerializer.module_to_serializer_cls["CompoundModule"] = (
528+
JsonSerializer.module_to_serializer_cls[qualname(CompoundModule)] = (
528529
CompoundModuleSerializer
529530
)
530531
# Serialize

torchrec/ir/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import logging
1313
import operator
1414
from collections import defaultdict
15-
from typing import Dict, List, Optional, Tuple, Type
15+
from typing import Dict, List, Optional, Tuple, Type, Union
1616

1717
import torch
1818

@@ -34,6 +34,13 @@
3434
logger: logging.Logger = logging.getLogger(__name__)
3535

3636

37+
def qualname(m: Union[nn.Module, Type[nn.Module]]) -> str:
    """Return the fully qualified class name of a module instance or module class.

    The result is ``<defining module path>.<class __qualname__>``, e.g.
    ``torchrec.modules.embedding_modules.EmbeddingBagCollection``. It is used
    as the lookup key of the serializer registry, so an instance and its class
    must map to the same string.

    Args:
        m: either a module instance or a module class.

    Returns:
        The dotted, fully qualified name of the class.
    """
    # ``__qualname__`` exists only on classes, never on instances. Dispatching
    # on "is it a class?" (rather than "is it an nn.Module instance?") keeps
    # this helper usable when building error messages for unexpected objects.
    cls = m if isinstance(m, type) else type(m)
    return f"{cls.__module__}.{cls.__qualname__}"
42+
43+
3744
def get_device(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.device]:
3845
"""
3946
Returns the device of the first non-None tensor in the list.
@@ -115,7 +122,7 @@ def encapsulate_ir_modules(
115122
preserve_fqns: List[str] = [] # fqns of the serialized modules
116123
children: List[str] = [] # fqns of the children that need further serialization
117124
# handle current module, and find the children which need further serialization
118-
if type(module).__name__ in serializer.module_to_serializer_cls:
125+
if qualname(module) in serializer.module_to_serializer_cls:
119126
children = serializer.encapsulate_module(module)
120127
preserve_fqns.append(fqn)
121128
else:

0 commit comments

Comments
 (0)