|
| 1 | +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. |
| 2 | + |
| 3 | +# pyre-strict |
| 4 | + |
| 5 | +from typing import Dict, List, Optional, Tuple, Union |
| 6 | + |
| 7 | +import torch |
| 8 | +from torch import nn |
| 9 | + |
| 10 | +from torchrec import ( |
| 11 | + EmbeddingCollection, |
| 12 | + EmbeddingConfig, |
| 13 | + JaggedTensor, |
| 14 | + KeyedJaggedTensor, |
| 15 | + KeyedTensor, |
| 16 | +) |
| 17 | + |
| 18 | +# For MPZCH |
| 19 | +from torchrec.modules.hash_mc_evictions import ( |
| 20 | + HashZchEvictionConfig, |
| 21 | + HashZchEvictionPolicyName, |
| 22 | +) |
| 23 | + |
| 24 | +# For MPZCH |
| 25 | +from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule |
| 26 | +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection |
| 27 | + |
| 28 | +# For original MC |
| 29 | +from torchrec.modules.mc_modules import ( |
| 30 | + DistanceLFU_EvictionPolicy, |
| 31 | + ManagedCollisionCollection, |
| 32 | + MCHManagedCollisionModule, |
| 33 | +) |
| 34 | + |
| 35 | +""" |
| 36 | +Class SparseArch |
| 37 | +An example of SparseArch with 2 tables, each serving a single feature ("feature_0" and "feature_1"). |
| 38 | +It looks up the corresponding embedding for incoming KeyedJaggedTensors with 2 features |
| 39 | +and returns the corresponding embeddings. |
| 40 | +
|
| 41 | +Parameters: |
| 42 | + tables(List[EmbeddingConfig]): List of EmbeddingConfig that defines the embedding table |
| 43 | + device(torch.device): device on which the embedding table should be placed |
| 44 | + buckets(int): number of buckets for each table |
| 45 | + input_hash_size(int): input hash size for each table |
| 46 | + return_remapped(bool): whether to return remapped features, if so, the return will be |
| 47 | + a tuple of (Embedding(KeyedTensor), Remapped_ID(KeyedJaggedTensor)), otherwise, the return will be |
| 48 | + a tuple of (Embedding(KeyedTensor), None) |
| 49 | + is_inference(bool): whether to use inference mode. In inference mode, the module will not update the embedding table |
| 50 | + use_mpzch(bool): whether to use MPZCH or not. If true, the module will use MPZCH managed collision module, |
| 51 | + otherwise, it will use original MC managed collision module |
| 52 | +""" |
| 53 | + |
| 54 | + |
class SparseArch(nn.Module):
    """
    Example sparse arch built around a ManagedCollisionEmbeddingCollection.

    Two managed collision modules are created and registered under the keys
    "table_0" and "table_1"; they are sized from tables[0] and tables[1], so
    `tables` must hold at least two entries.

    Parameters:
        tables(List[EmbeddingConfig]): embedding table configs; passed to both the
            EmbeddingCollection and the ManagedCollisionCollection
        device(torch.device): device handed to the embedding collection and to every
            managed collision module
        buckets(int): forwarded as `total_num_buckets` to the MPZCH modules; not read
            when use_mpzch is False
        input_hash_size(int): forwarded as `input_hash_size` to every managed
            collision module
        return_remapped(bool): forwarded as `return_remapped_features` to the
            ManagedCollisionEmbeddingCollection
        is_inference(bool): forwarded to the MPZCH modules; not read when use_mpzch
            is False
        use_mpzch(bool): True builds HashZchManagedCollisionModule instances,
            False builds MCHManagedCollisionModule instances
    """

    def __init__(
        self,
        tables: List[EmbeddingConfig],
        device: torch.device,
        buckets: int = 4,
        input_hash_size: int = 4000,
        return_remapped: bool = False,
        is_inference: bool = False,
        use_mpzch: bool = False,
    ) -> None:
        super().__init__()
        self._return_remapped = return_remapped

        collision_modules = {}

        if not use_mpzch:
            # Original MC path: one MCHManagedCollisionModule per table. The two
            # tables keep their own eviction interval (2 for table_0, 1 for table_1).
            mch_specs = (("table_0", 2), ("table_1", 1))
            for position, (table_key, interval) in enumerate(mch_specs):
                collision_modules[table_key] = MCHManagedCollisionModule(
                    zch_size=tables[position].num_embeddings,
                    input_hash_size=input_hash_size,
                    device=device,
                    eviction_interval=interval,
                    eviction_policy=DistanceLFU_EvictionPolicy(),
                )
        else:
            # MPZCH path: one HashZchManagedCollisionModule per table, each paired
            # with the single feature that the table serves.
            mpzch_specs = (("table_0", "feature_0"), ("table_1", "feature_1"))
            for position, (table_key, feature_name) in enumerate(mpzch_specs):
                # zch_size is set to the table's own number of embeddings.
                local_table_size = tables[position].num_embeddings
                collision_modules[table_key] = HashZchManagedCollisionModule(
                    is_inference=is_inference,
                    zch_size=local_table_size,
                    input_hash_size=input_hash_size,
                    device=device,
                    total_num_buckets=buckets,
                    # Single-TTL eviction: an id becomes evictable once it has
                    # stayed in the table longer than its time to live.
                    eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                    # The TTL is given per feature; its unit is hours, so 1 means
                    # an id is evictable after more than one hour in the table.
                    eviction_config=HashZchEvictionConfig(
                        features=[feature_name],
                        single_ttl=1,
                    ),
                )

        # Build the embedding collection first, then the collision collection,
        # and finally tie both together in the managed collision wrapper.
        embedding_collection = EmbeddingCollection(
            tables=tables,
            device=device,
        )
        collision_collection = ManagedCollisionCollection(
            managed_collision_modules=collision_modules,
            embedding_configs=tables,
        )
        self._mc_ec: ManagedCollisionEmbeddingCollection = (
            ManagedCollisionEmbeddingCollection(
                embedding_collection,
                collision_collection,
                return_remapped_features=self._return_remapped,
            )
        )

    def forward(
        self, kjt: KeyedJaggedTensor
    ) -> Tuple[
        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
    ]:
        """
        Run the managed collision embedding collection on `kjt`.

        Returns whatever the wrapped ManagedCollisionEmbeddingCollection returns:
        per the return annotation, a pair of (embeddings, remapped ids or None).
        """
        lookup_result = self._mc_ec(kjt)
        return lookup_result
0 commit comments