
Commit 0f7297b

lizhouyu authored and facebook-github-bot committed
Create an example for MPZCH (#3063)
Summary:
Pull Request resolved: #3063

### Major changes
- Create a `mpzch` folder under the `torchrec/github/examples` folder
- Implement a simple SparseArch module with a flag to switch between the original and MPZCH managed collision modules
- Profile the running time and QPS for model training (GPU) / inference (CPU)
- Create a notebook tutorial for ZCH basics and the use of ZCH modules in TorchRec

### ToDos for OSS
When the internal torchrec MPZCH module is OSS:
- Remove the `BUCK` file
- Replace all the `from torchrec.fb.modules` imports in `sparse_arch.py` with `from torchrec.modules`

### Potential improvements
- Add a hash collision counter
- Show profiling results in the Readme file
- Add multi-batch profiling

Reviewed By: aporialiao

Differential Revision: D75570684
1 parent 71db31d commit 0f7297b

File tree

7 files changed: +801 −4 lines changed

examples/zch/Readme.md

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Managed Collision Hash Example

This example demonstrates the usage of the managed collision hash feature in TorchRec, which is designed to efficiently handle hash collisions in embedding tables. We include two implementations of the feature: sorted managed collision hash (MCH) and multi-probe zero collision hash (MPZCH).

## Folder Structure

```
zch/
├── Readme.md                          # This documentation file
├── __init__.py                        # Python package marker
├── main.py                            # Main script to run the benchmark
├── sparse_arch.py                     # Implementation of the sparse architecture with managed collision
└── zero_collision_hash_tutorial.ipynb # Jupyter notebook on the motivation for zero collision hash and the use of zero collision hash modules in TorchRec
```

### Introduction to MPZCH

Multi-probe Zero Collision Hash (MPZCH) is a technique for reducing the collision rate of embedding table lookups. For the concept of hash collisions and why they need to be managed, please refer to the [zero collision hash tutorial](zero_collision_hash_tutorial.ipynb).

An MPZCH module contains two essential tables: the identity table and the metadata table.
The identity table records the mapping from input hash values to remapped IDs. The value stored in each identity table slot is an input hash value, and that hash value's remapped ID is the index of the slot. For example, if hash value 42 is stored in slot 7, its remapped ID is 7.
The metadata table shares the same length as the identity table. The time when a hash value is inserted into an identity table slot is recorded in the same-indexed slot of the metadata table.

Specifically, MPZCH performs the following two steps (see the sketch after this list):
1. **First Probe**: Check if there are available or evictable slots in the identity table.
2. **Second Probe**: Check if the slot indexed by the input hash value is occupied. If not, directly insert the input hash value into that slot. Otherwise, perform a linear probe to find the next available slot. If all the slots are occupied, find the next evictable slot whose value has stayed in the table longer than a threshold, and replace the expired hash value with the input one.
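
To make the probing concrete, here is a minimal illustrative sketch of the insertion logic described above. This is not the TorchRec implementation; the names (`identity`, `metadata`, `ttl_hours`, `EMPTY`) and the plain-list data layout are made up for this example:

```python
import time

EMPTY = -1  # sentinel for an unoccupied identity-table slot


def remap(hash_value: int, identity: list, metadata: list, ttl_hours: int = 1):
    """Return the remapped ID (slot index) for `hash_value`, inserting it if needed."""
    size = len(identity)
    start = hash_value % size  # the slot indexed by the input hash value
    now = time.time()
    # linear probe for the value itself or an available slot
    for step in range(size):
        slot = (start + step) % size
        if identity[slot] == hash_value:  # already inserted: reuse its slot
            return slot
        if identity[slot] == EMPTY:  # available slot: insert directly
            identity[slot] = hash_value
            metadata[slot] = now  # record the insertion time
            return slot
    # every slot is occupied: probe again for an evictable (expired) slot
    for step in range(size):
        slot = (start + step) % size
        if now - metadata[slot] > ttl_hours * 3600:  # stayed longer than the TTL
            identity[slot] = hash_value
            metadata[slot] = now
            return slot
    return None  # no available or evictable slot
```
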
The use of the MPZCH module `HashZchManagedCollisionModule` is introduced with detailed comments in the [sparse_arch.py](sparse_arch.py) file.

The module can be configured to use different eviction policies and parameters.

The detailed function calls are shown in the diagram below.

![MPZCH Module Data Flow](docs/mpzch_module_dataflow.png)

#### Relationship among Important Parameters

The `HashZchManagedCollisionModule` module has three important parameters for initialization:
- `num_embeddings`: the number of embeddings in the embedding table
- `num_buckets`: the number of buckets in the hash table
- `input_hash_size`: the size of the input ID space

The `num_buckets` is used as the minimal sharding unit for the embedding table. Because MPZCH performs linear probing, when resharding the embedding table we want to avoid separating the remapped index of an input feature ID from its hash value across different ranks. So we make sure they are in the same bucket, and move whole buckets during resharding, as illustrated in the sketch below.
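
For illustration, a hypothetical sketch of how bucket-aligned sharding keeps a hash value and its linearly probed slot together. The sizes and the round-robin rank assignment are made-up example values, not the TorchRec sharding logic:

```python
num_embeddings = 1024  # total slots in the embedding table
num_buckets = 8        # minimal sharding unit
bucket_size = num_embeddings // num_buckets  # 128 slots per bucket


def bucket_of(slot_index: int) -> int:
    # all slots in [b * bucket_size, (b + 1) * bucket_size) belong to bucket b
    return slot_index // bucket_size


def rank_of(slot_index: int, num_ranks: int = 2) -> int:
    # a hypothetical round-robin placement of buckets onto ranks: whole buckets
    # (never individual slots) move during resharding, so an input hash value
    # and its remapped slot stay on the same rank
    return bucket_of(slot_index) % num_ranks
```
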
## Usage

We also prepare a profiling example of a SparseArch module implemented with different ZCH techniques.
To run the profiling example with sorted ZCH:

```bash
python main.py
```

To run the profiling example with MPZCH:

```bash
python main.py --use_mpzch
```

You can also specify the `batch_size`, `num_iters`, and `num_embeddings_per_table`:

```bash
python main.py --use_mpzch --batch_size 8 --num_iters 100 --num_embeddings_per_table 1000
```

The example allows you to compare the QPS of embedding operations with sorted ZCH and MPZCH. QPS here is the number of requests (`batch_size * num_iters`) divided by the total duration; see [main.py](main.py). On our server with an A100 GPU, the initial QPS benchmark results with `batch_size=8`, `num_iters=100`, and `num_embeddings_per_table=1000` are presented in the table below:

| ZCH module | QPS |
| --- | --- |
| sorted ZCH | 1371.69 |
| MPZCH | 2750.44 |

And with `batch_size=1024`, `num_iters=1000`, and `num_embeddings_per_table=1000`:

| ZCH module | QPS |
| --- | --- |
| sorted ZCH | 263827.55 |
| MPZCH | 551306.97 |

examples/zch/__init__.py

Whitespace-only changes.
examples/zch/docs/mpzch_module_dataflow.png

Binary file added (911 KB); image not rendered here.

examples/zch/main.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import argparse
import time

import torch

from torchrec import EmbeddingConfig, KeyedJaggedTensor
from tqdm import tqdm

from .sparse_arch import SparseArch


def main(args: argparse.Namespace) -> None:
    """
    Test the performance of a Sparse module with or without the MPZCH feature.

    Arguments:
        args: parsed command-line arguments; args.use_mpzch controls whether MPZCH is enabled
    Prints:
        qps: the number of requests processed per second
        duration: time for the forward passes of the Sparse module with or without MPZCH enabled
    """
    print(f"Using MPZCH: {args.use_mpzch}")

    # check available devices
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    # create an embedding configuration: two tables, one feature each
    embedding_config = [
        EmbeddingConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=8,
            num_embeddings=args.num_embeddings_per_table,
        ),
        EmbeddingConfig(
            name="table_1",
            feature_names=["feature_1"],
            embedding_dim=8,
            num_embeddings=args.num_embeddings_per_table,
        ),
    ]

    # generate one KJT input per iteration
    input_kjt_list = []
    for _ in range(args.num_iters):
        input_kjt_single = KeyedJaggedTensor.from_lengths_sync(
            keys=["feature_0", "feature_1"],
            # draw 3 * batch_size random IDs from [0, num_embeddings_per_table)
            values=torch.randint(
                0, args.num_embeddings_per_table, (3 * args.batch_size,)
            ),
            # feature_0 has one value per sample, feature_1 has two
            lengths=torch.LongTensor([1] * args.batch_size + [2] * args.batch_size),
            weights=None,
        )
        input_kjt_single = input_kjt_single.to(device)
        input_kjt_list.append(input_kjt_single)

    num_requests = args.num_iters * args.batch_size

    # make the model
    model = SparseArch(
        tables=embedding_config,
        device=device,
        return_remapped=True,
        use_mpzch=args.use_mpzch,
        buckets=1,
    )

    # do the forward passes and time them
    if device.type == "cuda":
        torch.cuda.synchronize()
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

        # record the start time
        starter.record()
        for it_idx in tqdm(range(args.num_iters)):
            input_kjt = input_kjt_list[it_idx].to(device)
            ec_out, remapped_ids_out = model(input_kjt)
        # record the end time
        ender.record()
        # wait for the end time to be recorded
        torch.cuda.synchronize()
        duration = starter.elapsed_time(ender) / 1000.0  # convert milliseconds to seconds
    else:
        # in CPU mode, MPZCH can only run in inference mode, so we profile the model in eval mode
        model.eval()
        if args.use_mpzch:
            # when using MPZCH modules, we need to manually set the modules to inference mode
            # pyre-ignore
            model._mc_ec._managed_collision_collection._managed_collision_modules[
                "table_0"
            ].reset_inference_mode()
            # pyre-ignore
            model._mc_ec._managed_collision_collection._managed_collision_modules[
                "table_1"
            ].reset_inference_mode()

        start_time = time.time()
        for it_idx in tqdm(range(args.num_iters)):
            input_kjt = input_kjt_list[it_idx].to(device)
            ec_out, remapped_ids_out = model(input_kjt)
        end_time = time.time()
        duration = end_time - start_time
    # compute and print the QPS and duration
    qps = num_requests / duration
    print(f"qps: {qps}")
    print(f"duration: {duration} seconds")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_mpzch", action="store_true", default=False)
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_embeddings_per_table", type=int, default=1000)
    args: argparse.Namespace = parser.parse_args()
    main(args)

examples/zch/sparse_arch.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from torchrec import (
    EmbeddingCollection,
    EmbeddingConfig,
    JaggedTensor,
    KeyedJaggedTensor,
    KeyedTensor,
)

# For MPZCH
from torchrec.modules.hash_mc_evictions import (
    HashZchEvictionConfig,
    HashZchEvictionPolicyName,
)

# For MPZCH
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection

# For original MC
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)


"""
36+
Class SparseArch
37+
An example of SparseArch with 2 tables, each with 2 features.
38+
It looks up the corresponding embedding for incoming KeyedJaggedTensors with 2 features
39+
and returns the corresponding embeddings.
40+
41+
Parameters:
42+
tables(List[EmbeddingConfig]): List of EmbeddingConfig that defines the embedding table
43+
device(torch.device): device on which the embedding table should be placed
44+
buckets(int): number of buckets for each table
45+
input_hash_size(int): input hash size for each table
46+
return_remapped(bool): whether to return remapped features, if so, the return will be
47+
a tuple of (Embedding(KeyedTensor), Remapped_ID(KeyedJaggedTensor)), otherwise, the return will be
48+
a tuple of (Embedding(KeyedTensor), None)
49+
is_inference(bool): whether to use inference mode. In inference mode, the module will not update the embedding table
50+
use_mpzch(bool): whether to use MPZCH or not. If true, the module will use MPZCH managed collision module,
51+
otherwise, it will use original MC managed collision module
52+
"""
53+
54+
55+
class SparseArch(nn.Module):
56+
    def __init__(
        self,
        tables: List[EmbeddingConfig],
        device: torch.device,
        buckets: int = 4,
        input_hash_size: int = 4000,
        return_remapped: bool = False,
        is_inference: bool = False,
        use_mpzch: bool = False,
    ) -> None:
        super().__init__()
        self._return_remapped = return_remapped

        mc_modules = {}

        # if using the MPZCH module, create a HashZchManagedCollisionModule for each table
        if use_mpzch:
            mc_modules["table_0"] = HashZchManagedCollisionModule(
                is_inference=is_inference,
                # the ZCH size, i.e. the size of the local embedding table, should
                # match the size of the embedding table
                zch_size=tables[0].num_embeddings,
                # the input hash size, i.e. the size of the input ID space
                input_hash_size=input_hash_size,
                # the device on which the embedding table should be placed
                device=device,
                # the number of buckets; see the Readme for how buckets are used
                total_num_buckets=buckets,
                # the eviction policy; this example uses single-TTL eviction, which
                # treats an ID as evictable once it has stayed in the table longer
                # than the TTL (time to live)
                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                # the eviction config specifies the TTL for each feature, i.e. how
                # long an ID can stay in the table before it becomes evictable
                eviction_config=HashZchEvictionConfig(
                    # this table only serves "feature_0", so we only specify its TTL
                    features=["feature_0"],
                    # the TTL unit is hours; single_ttl=1 means an ID is evictable
                    # after it has been in the table for more than one hour
                    single_ttl=1,
                ),
            )
            mc_modules["table_1"] = HashZchManagedCollisionModule(
                is_inference=is_inference,
                zch_size=tables[1].num_embeddings,
                device=device,
                input_hash_size=input_hash_size,
                total_num_buckets=buckets,
                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                eviction_config=HashZchEvictionConfig(
                    features=["feature_1"],
                    single_ttl=1,
                ),
            )
        else:
            # if not using the MPZCH module, create an MCHManagedCollisionModule for each table
            mc_modules["table_0"] = MCHManagedCollisionModule(
                zch_size=tables[0].num_embeddings,
                input_hash_size=input_hash_size,
                device=device,
                eviction_interval=2,
                eviction_policy=DistanceLFU_EvictionPolicy(),
            )
            mc_modules["table_1"] = MCHManagedCollisionModule(
                zch_size=tables[1].num_embeddings,
                device=device,
                input_hash_size=input_hash_size,
                eviction_interval=1,
                eviction_policy=DistanceLFU_EvictionPolicy(),
            )

        self._mc_ec: ManagedCollisionEmbeddingCollection = (
            ManagedCollisionEmbeddingCollection(
                EmbeddingCollection(
                    tables=tables,
                    device=device,
                ),
                ManagedCollisionCollection(
                    managed_collision_modules=mc_modules,
                    embedding_configs=tables,
                ),
                return_remapped_features=self._return_remapped,
            )
        )

    def forward(
        self, kjt: KeyedJaggedTensor
    ) -> Tuple[
        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
    ]:
        return self._mc_ec(kjt)
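
For reference, a minimal usage sketch mirroring what main.py does: it builds a two-table SparseArch with MPZCH enabled and runs one forward pass. The table sizes and batch size are arbitrary example values:

```python
import torch
from torchrec import EmbeddingConfig, KeyedJaggedTensor

from .sparse_arch import SparseArch

tables = [
    EmbeddingConfig(
        name="table_0", feature_names=["feature_0"], embedding_dim=8, num_embeddings=1000
    ),
    EmbeddingConfig(
        name="table_1", feature_names=["feature_1"], embedding_dim=8, num_embeddings=1000
    ),
]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SparseArch(
    tables=tables,
    device=device,
    return_remapped=True,
    use_mpzch=True,
    buckets=1,
)

# one KJT with 8 samples: feature_0 has one value per sample, feature_1 has two
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.randint(0, 1000, (24,)),
    lengths=torch.LongTensor([1] * 8 + [2] * 8),
).to(device)

embeddings, remapped_ids = model(kjt)  # remapped_ids is the remapped KeyedJaggedTensor
```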
