@@ -26,7 +26,18 @@
 )
 from torchrec.distributed.fused_embedding import FusedEmbeddingCollectionSharder
 from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
-from torchrec.distributed.types import QuantizedCommCodecs
+from torchrec.distributed.mc_embedding_modules import (
+    BaseManagedCollisionEmbeddingCollectionSharder,
+)
+from torchrec.distributed.mc_embeddingbag import (
+    ShardedManagedCollisionEmbeddingBagCollection,
+)
+from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder
+from torchrec.distributed.types import (
+    ParameterSharding,
+    QuantizedCommCodecs,
+    ShardingEnv,
+)
 from torchrec.distributed.utils import CopyableMixin
 from torchrec.modules.activation import SwishLayerNorm
 from torchrec.modules.embedding_configs import (
@@ -39,6 +50,12 @@
 from torchrec.modules.feature_processor import PositionWeightedProcessor
 from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
+from torchrec.modules.mc_modules import (
+    DistanceLFU_EvictionPolicy,
+    ManagedCollisionCollection,
+    MCHManagedCollisionModule,
+)
 from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
 from torchrec.streamable import Pipelineable
@@ -1351,6 +1368,7 @@ def __init__(
         feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
         over_arch_clazz: Type[nn.Module] = TestOverArch,
         postproc_module: Optional[nn.Module] = None,
+        zch: bool = False,
     ) -> None:
         super().__init__(
             tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1362,12 +1380,20 @@ def __init__(
         if weighted_tables is None:
             weighted_tables = []
         self.dense = TestDenseArch(num_float_features, dense_device)
-        self.sparse = TestSparseArch(
-            tables,
-            weighted_tables,
-            sparse_device,
-            max_feature_lengths,
-        )
+        if zch:
+            self.sparse: nn.Module = TestSparseArchZCH(
+                tables,
+                weighted_tables,
+                torch.device("meta"),
+                return_remapped=True,
+            )
+        else:
+            self.sparse = TestSparseArch(
+                tables,
+                weighted_tables,
+                sparse_device,
+                max_feature_lengths,
+            )
 
         embedding_names = (
             list(embedding_groups.values())[0] if embedding_groups else None
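The new `zch` flag above is the only switch a test needs to flip to exercise the ZCH path. A minimal construction sketch, assuming the enclosing class is `TestSparseNN` and using a placeholder table config (neither is shown in this hunk):

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    # Placeholder table config for illustration only.
    tables = [
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=16,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]

    # zch=True swaps TestSparseArch for TestSparseArchZCH on the meta device;
    # zch=False (the default) preserves the existing behavior.
    model = TestSparseNN(tables=tables, num_float_features=10, zch=True)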
@@ -1687,6 +1713,64 @@ def compute_kernels(
         return [self._kernel_type]
 
 
+class TestMCSharder(ManagedCollisionCollectionSharder):
+    def __init__(
+        self,
+        sharding_type: str,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        self._sharding_type = sharding_type
+        super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [self._sharding_type]
+
+
+class TestEBCSharderMCH(
+    BaseManagedCollisionEmbeddingCollectionSharder[
+        ManagedCollisionEmbeddingBagCollection
+    ]
+):
+    def __init__(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        fused_params: Optional[Dict[str, Any]] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        super().__init__(
+            TestEBCSharder(
+                sharding_type, kernel_type, fused_params, qcomm_codecs_registry
+            ),
+            TestMCSharder(sharding_type, qcomm_codecs_registry),
+            qcomm_codecs_registry=qcomm_codecs_registry,
+        )
+
+    @property
+    def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
+        return ManagedCollisionEmbeddingBagCollection
+
+    def shard(
+        self,
+        module: ManagedCollisionEmbeddingBagCollection,
+        params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedManagedCollisionEmbeddingBagCollection:
+        if device is None:
+            device = torch.device("cuda")
+        return ShardedManagedCollisionEmbeddingBagCollection(
+            module,
+            params,
+            # pyre-ignore [6]
+            ebc_sharder=self._e_sharder,
+            mc_sharder=self._mc_sharder,
+            env=env,
+            device=device,
+        )
+
+
 class TestFusedEBCSharder(FusedEmbeddingBagCollectionSharder):
     def __init__(
         self,
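`TestEBCSharderMCH` composes the existing `TestEBCSharder` with the new `TestMCSharder`, so a test can pin both the sharding type and the compute kernel for a `ManagedCollisionEmbeddingBagCollection`. A hedged instantiation sketch (the `DistributedModelParallel` wiring is illustrative, not part of this diff):

    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.types import ShardingType

    # Pin table-wise sharding with the fused compute kernel.
    sharder = TestEBCSharderMCH(
        sharding_type=ShardingType.TABLE_WISE.value,
        kernel_type=EmbeddingComputeKernel.FUSED.value,
    )
    # A test would then hand it to DistributedModelParallel, e.g.:
    #     model = DistributedModelParallel(module, env=env, sharders=[sharder])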
@@ -2188,3 +2272,122 @@ def forward(self, input: ModelInput) -> ModelInput:
         modified_input = copy.deepcopy(input)
         modified_input.idlist_features = self.fp_proc(modified_input.idlist_features)
         return modified_input
+
+
+class TestSparseArchZCH(nn.Module):
+    """
+    Basic nn.Module for testing MCH EmbeddingBagCollection.
+
+    Args:
+        tables: configs for the unweighted embedding tables
+        weighted_tables: configs for the weighted embedding tables
+        device: device on which the embedding and MC modules are allocated
+        return_remapped: whether the MC EBCs also return remapped input features
+
+    Call Args:
+        features: unweighted input features
+        weighted_features: optional weighted input features
+        batch_size: optional batch size
+
+    Returns:
+        KeyedTensor
+
+    Example::
+
+        TestSparseArchZCH()
+    """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        device: torch.device,
+        return_remapped: bool = False,
+    ) -> None:
+        super().__init__()
+        self._return_remapped = return_remapped
+
+        mc_modules = {}
+        for table in tables:
+            mc_modules[table.name] = MCHManagedCollisionModule(
+                zch_size=table.num_embeddings,
+                input_hash_size=4000,
+                device=device,
+                # TODO: If the eviction interval is set to a low number
+                # (e.g. 2), the semi-sync pipeline test fails with an
+                # in-place modification error during loss.backward().
+                # During semi-sync training we run the embedding module
+                # forward after the autograd graph is constructed, so if
+                # an MCH eviction happens, a variable used by autograd
+                # will already have been modified.
+                eviction_interval=1000,
+                eviction_policy=DistanceLFU_EvictionPolicy(),
+            )
+
+        self.ebc: ManagedCollisionEmbeddingBagCollection = (
+            ManagedCollisionEmbeddingBagCollection(
+                EmbeddingBagCollection(
+                    tables=tables,
+                    device=device,
+                ),
+                ManagedCollisionCollection(
+                    managed_collision_modules=mc_modules,
+                    embedding_configs=tables,
+                ),
+                return_remapped_features=self._return_remapped,
+            )
+        )
+
+        self.weighted_ebc: Optional[ManagedCollisionEmbeddingBagCollection] = None
+        if weighted_tables:
+            weighted_mc_modules = {}
+            for table in weighted_tables:
+                weighted_mc_modules[table.name] = MCHManagedCollisionModule(
+                    zch_size=table.num_embeddings,
+                    input_hash_size=4000,
+                    device=device,
+                    # TODO: Support MCH evictions during semi-sync
+                    eviction_interval=1000,
+                    eviction_policy=DistanceLFU_EvictionPolicy(),
+                )
+            self.weighted_ebc = ManagedCollisionEmbeddingBagCollection(
+                EmbeddingBagCollection(
+                    tables=weighted_tables,
+                    device=device,
+                    is_weighted=True,
+                ),
+                ManagedCollisionCollection(
+                    managed_collision_modules=weighted_mc_modules,
+                    embedding_configs=weighted_tables,
+                ),
+                return_remapped_features=self._return_remapped,
+            )
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+        weighted_features: Optional[KeyedJaggedTensor] = None,
+        batch_size: Optional[int] = None,
+    ) -> KeyedTensor:
+        """
+        Runs the MC EBC and, optionally, the weighted MC EBC,
+        then merges the results into one KeyedTensor.
+
+        Args:
+            features: unweighted input features
+            weighted_features: optional weighted input features
+            batch_size: optional batch size
+
+        Returns:
+            KeyedTensor
+        """
+        ebc, _ = self.ebc(features)
+        ebc = _post_ebc_test_wrap_function(ebc)
+        w_ebc, _ = (
+            self.weighted_ebc(weighted_features)
+            if self.weighted_ebc is not None and weighted_features is not None
+            else (None, None)
+        )
+        result = _post_sparsenn_forward(ebc, None, w_ebc, batch_size)
+        return result
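To close the loop, a sketch of driving `TestSparseArchZCH` with a `KeyedJaggedTensor`; the table config and feature values are placeholders, and the output shape depends on the `_post_ebc_test_wrap_function` and `_post_sparsenn_forward` helpers defined elsewhere in this file:

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Placeholder config: one table feeding one feature.
    tables = [
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=16,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]
    arch = TestSparseArchZCH(tables, [], torch.device("cpu"), return_remapped=True)

    # One feature, batch size 2: lengths [2, 1] means sample 0 has two ids.
    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["feature_0"],
        values=torch.tensor([101, 202, 303]),
        lengths=torch.tensor([2, 1]),
    )
    out = arch(features)  # KeyedTensor of pooled, collision-managed embeddings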