Skip to content

Commit ad42b3e

Browse files
TroyGarden authored and facebook-github-bot committed
parameterize torchrec components in e2e benchmark (#2867)
Summary: Pull Request resolved: #2867 # context * parameterize (with config dataclasses) the test and benchmark framework in TorchRec. * more specifically, use configs to generate the necessary TorchRec components, including 1) input data, 2) embedding tables, 3) test model, 4) module sharder, 5) pipeline, 6) env setup (e.g., world_size), etc. * the goal is to facilitate the development of test cases and benchmarks by consolidating the ad-hoc style of component generation. # details of this diff * decorator `cmd_conf` lets command-line arguments override the default configs * example benchmark "benchmark_train_sparsenn" with all necessary abstractions for component generation Reviewed By: aporialiao Differential Revision: D71703434 fbshipit-source-id: 4f5a05130f5479d89d78959f40dfb5923d6dabbc
1 parent f36d26d commit ad42b3e

File tree

4 files changed

+373
-3
lines changed

4 files changed

+373
-3
lines changed
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
#!/usr/bin/env python3
11+
12+
import copy
13+
14+
from dataclasses import dataclass
15+
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
16+
17+
import click
18+
19+
import torch
20+
import torch.distributed as dist
21+
from fbgemm_gpu.split_embedding_configs import EmbOptimType
22+
from torch import nn, optim
23+
from torch.optim import Optimizer
24+
from torchrec.distributed import DistributedModelParallel
25+
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
26+
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
27+
28+
from torchrec.distributed.test_utils.multi_process import (
29+
MultiProcessContext,
30+
run_multi_process_func,
31+
)
32+
from torchrec.distributed.test_utils.test_input import (
33+
ModelInput,
34+
TdModelInput,
35+
TestSparseNNInputConfig,
36+
)
37+
from torchrec.distributed.test_utils.test_model import (
38+
TestEBCSharder,
39+
TestOverArchLarge,
40+
TestSparseNN,
41+
)
42+
from torchrec.distributed.train_pipeline import (
43+
TrainPipeline,
44+
TrainPipelineBase,
45+
TrainPipelineSparseDist,
46+
)
47+
from torchrec.distributed.train_pipeline.train_pipelines import (
48+
PrefetchTrainPipelineSparseDist,
49+
TrainPipelineSemiSync,
50+
)
51+
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
52+
from torchrec.modules.embedding_configs import EmbeddingBagConfig
53+
54+
55+
@dataclass
class RunOptions:
    """Run-level settings for the benchmark, overridable from the command line."""

    # Number of trainer processes; passed to ``run_multi_process_func`` and
    # reported to ``benchmark_func``.
    world_size: int = 4
    # How many model-input batches are generated for the benchmark.
    num_batches: int = 20
    # Sharding type; its ``.value`` is handed to the test sharder.
    sharding_type: ShardingType = ShardingType.TABLE_WISE
    # NOTE(review): not read anywhere in this file -- presumably selects the
    # input format (e.g. KJT vs. TensorDict); confirm before relying on it.
    input_type: str = "kjt"
    # Forwarded to ``benchmark_func`` as ``profile_dir``.
    profile: str = ""
62+
63+
64+
@dataclass
class EmbeddingTablesConfig:
    """Sizing parameters for the embedding-bag tables used by the benchmark."""

    # Number of tables/features in the unweighted group.
    num_unweighted_features: int = 4
    # Number of tables/features in the weighted group.
    num_weighted_features: int = 4
    # Embedding dimension shared by every generated table.
    embedding_feature_dim: int = 512

    def generate_tables(
        self,
    ) -> Tuple[
        List[EmbeddingBagConfig],
        List[EmbeddingBagConfig],
    ]:
        """Create one single-feature table config per configured feature.

        Returns:
            A ``(tables, weighted_tables)`` pair. Table ``i`` of each group is
            named ``table_i`` / ``weighted_table_i`` and serves the feature
            ``feature_i`` / ``weighted_feature_i``.
        """

        def _build(
            table_prefix: str, feature_prefix: str, count: int
        ) -> List[EmbeddingBagConfig]:
            # Both groups share the same sizing rule; only the names differ.
            configs: List[EmbeddingBagConfig] = []
            for idx in range(count):
                configs.append(
                    EmbeddingBagConfig(
                        num_embeddings=max(idx + 1, 100) * 1000,
                        embedding_dim=self.embedding_feature_dim,
                        name=table_prefix + str(idx),
                        feature_names=[feature_prefix + str(idx)],
                    )
                )
            return configs

        unweighted = _build("table_", "feature_", self.num_unweighted_features)
        weighted = _build(
            "weighted_table_", "weighted_feature_", self.num_weighted_features
        )
        return unweighted, weighted
95+
96+
97+
@dataclass
class PipelineConfig:
    """Selects which train pipeline implementation the benchmark drives."""

    # One of "base", "sparse", "semi" or "prefetch".
    pipeline: str = "base"

    def generate_pipeline(
        self, model: nn.Module, opt: torch.optim.Optimizer, device: torch.device
    ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
        """Instantiate the pipeline named by ``self.pipeline``.

        Args:
            model: module the pipeline will train.
            opt: optimizer handed to the pipeline.
            device: device handed to the pipeline.

        Returns:
            The constructed pipeline instance.

        Raises:
            RuntimeError: if ``self.pipeline`` is not a known option.
        """
        # The semi-sync pipeline is the only one built with an extra
        # ``start_batch`` argument, so it is handled before the generic lookup.
        if self.pipeline == "semi":
            return TrainPipelineSemiSync(
                model=model, optimizer=opt, device=device, start_batch=0
            )

        _pipeline_cls: Dict[
            str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
        ] = {
            "base": TrainPipelineBase,
            "sparse": TrainPipelineSparseDist,
            "semi": TrainPipelineSemiSync,
            "prefetch": PrefetchTrainPipelineSparseDist,
        }
        chosen_cls = _pipeline_cls.get(self.pipeline)
        if chosen_cls is None:
            raise RuntimeError(f"unknown pipeline option {self.pipeline}")
        return chosen_cls(model=model, optimizer=opt, device=device)
122+
123+
124+
@click.command()
@cmd_conf(RunOptions, EmbeddingTablesConfig, TestSparseNNInputConfig, PipelineConfig)
def main(
    run_option: RunOptions,
    table_config: EmbeddingTablesConfig,
    input_config: TestSparseNNInputConfig,
    pipeline_config: PipelineConfig,
) -> None:
    """Build the table configs once, then launch one ``runner`` per rank.

    The four config objects are constructed by ``cmd_conf`` and passed in
    positionally, in the order the config classes are listed in the decorator.
    """
    # sparse table config is available on each trainer
    tables, weighted_tables = table_config.generate_tables()

    # everything below is forwarded as keyword arguments to ``runner``
    trainer_kwargs = {
        "tables": tables,
        "weighted_tables": weighted_tables,
        "run_option": run_option,
        "input_config": input_config,
        "pipeline_config": pipeline_config,
    }

    # launch trainers
    run_multi_process_func(
        func=runner,
        world_size=run_option.world_size,
        **trainer_kwargs,
    )
145+
146+
147+
def _generate_data(
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    input_config: TestSparseNNInputConfig,
    num_batches: int,
) -> List[ModelInput]:
    """Generate ``num_batches`` model inputs for the given table configs.

    Each batch comes from a separate ``input_config.generate_model_input``
    call made with the same ``tables`` / ``weighted_tables``.
    """
    batches: List[ModelInput] = []
    for _ in range(num_batches):
        batch = input_config.generate_model_input(
            tables=tables,
            weighted_tables=weighted_tables,
        )
        batches.append(batch)
    return batches
160+
161+
162+
def _generate_model(
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    dense_device: torch.device,
) -> nn.Module:
    """Build the unsharded ``TestSparseNN`` used by the benchmark.

    The sparse side is created on the ``meta`` device and the dense side on
    ``dense_device``; ``TestOverArchLarge`` is used as the over-arch class.
    """
    sparse_device = torch.device("meta")
    model = TestSparseNN(
        tables=tables,
        weighted_tables=weighted_tables,
        dense_device=dense_device,
        sparse_device=sparse_device,
        over_arch_clazz=TestOverArchLarge,
    )
    return model
174+
175+
176+
def _generate_sharded_model_and_optimizer(
    model: nn.Module,
    sharding_type: str,
    kernel_type: str,
    pg: dist.ProcessGroup,
    device: torch.device,
    fused_params: Optional[Dict[str, Any]] = None,
) -> Tuple[nn.Module, Optimizer]:
    """Shard a copy of ``model`` and create an SGD optimizer for it.

    Args:
        model: unsharded module; it is deep-copied, not modified.
        sharding_type: sharding type string given to ``TestEBCSharder``.
        kernel_type: compute kernel string given to ``TestEBCSharder``.
        pg: process group used to build the ``ShardingEnv``.
        device: device for the ``DistributedModelParallel`` wrapper.
        fused_params: optional fused parameters given to ``TestEBCSharder``.

    Returns:
        ``(sharded_model, optimizer)``. The optimizer is ``optim.SGD`` with
        ``lr=0.1`` over the parameters whose name does not contain ``"sparse"``.
    """
    sharder = TestEBCSharder(
        sharding_type=sharding_type,
        kernel_type=kernel_type,
        fused_params=fused_params,
    )
    model_copy = copy.deepcopy(model)
    sharding_env = ShardingEnv.from_process_group(pg)
    module_sharders = [cast(ModuleSharder[nn.Module], sharder)]

    sharded_model = DistributedModelParallel(
        module=model_copy,
        env=sharding_env,
        init_data_parallel=True,
        device=device,
        sharders=module_sharders,
    ).to(device)

    # Parameters with "sparse" in their name are left out of SGD.
    # NOTE(review): presumably they are updated by the fused optimizer that is
    # configured through ``fused_params`` -- confirm.
    non_sparse_params = []
    for param_name, param in sharded_model.named_parameters():
        if "sparse" in param_name:
            continue
        non_sparse_params.append(param)
    optimizer = optim.SGD(non_sparse_params, lr=0.1)

    return sharded_model, optimizer
210+
211+
212+
def runner(
    rank: int,
    world_size: int,
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    run_option: RunOptions,
    input_config: TestSparseNNInputConfig,
    pipeline_config: PipelineConfig,
) -> None:
    """Per-rank benchmark body, launched by ``run_multi_process_func``.

    Builds the model, shards it, generates the input batches, creates the
    configured train pipeline and times it with ``benchmark_func``. Rank 0
    prints the result.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        tables: unweighted embedding table configs.
        weighted_tables: weighted embedding table configs.
        run_option: run-level settings (batch count, sharding type, profile dir).
        input_config: generator config for the model inputs.
        pipeline_config: selects the train pipeline to benchmark.
    """
    # NOTE(review): anomaly detection is enabled for the whole run, benchmark
    # included -- confirm that this is intended for timing measurements.
    torch.autograd.set_detect_anomaly(True)

    process_context = MultiProcessContext(
        rank=rank,
        world_size=world_size,
        backend="nccl",
        use_deterministic_algorithms=False,
    )
    with process_context as ctx:
        unsharded_model = _generate_model(
            tables=tables,
            weighted_tables=weighted_tables,
            dense_device=ctx.device,
        )

        fused_params = {
            "optimizer": EmbOptimType.EXACT_ADAGRAD,
            "learning_rate": 0.1,
        }
        sharded_model, optimizer = _generate_sharded_model_and_optimizer(
            model=unsharded_model,
            sharding_type=run_option.sharding_type.value,
            kernel_type=EmbeddingComputeKernel.FUSED.value,
            # pyre-ignore
            pg=ctx.pg,
            device=ctx.device,
            fused_params=fused_params,
        )

        bench_inputs = _generate_data(
            tables=tables,
            weighted_tables=weighted_tables,
            input_config=input_config,
            num_batches=run_option.num_batches,
        )

        pipeline = pipeline_config.generate_pipeline(
            sharded_model, optimizer, ctx.device
        )
        # One progress call on a throwaway iterator before measuring.
        # NOTE(review): presumably a warm-up step -- confirm.
        pipeline.progress(iter(bench_inputs))

        def _func_to_benchmark(
            bench_inputs: List[ModelInput],
            model: nn.Module,
            pipeline: TrainPipeline,
        ) -> None:
            # Step the pipeline until it signals exhaustion of the batches by
            # raising StopIteration. ``model`` is accepted but not used here.
            batch_iter = iter(bench_inputs)
            while True:
                try:
                    pipeline.progress(batch_iter)
                except StopIteration:
                    return

        benchmark_kwargs = {"model": sharded_model, "pipeline": pipeline}
        result = benchmark_func(
            name=type(pipeline).__name__,
            bench_inputs=bench_inputs,  # pyre-ignore
            prof_inputs=bench_inputs,  # pyre-ignore
            num_benchmarks=5,
            num_profiles=2,
            profile_dir=run_option.profile,
            world_size=run_option.world_size,
            func_to_benchmark=_func_to_benchmark,
            benchmark_func_kwargs=benchmark_kwargs,
            rank=rank,
        )

        # Only rank 0 reports, so the result is printed once.
        if rank == 0:
            print(result)
284+
285+
286+
# Script entry point: ``main`` is a click command, so it parses the command
# line itself.
if __name__ == "__main__":
    main()

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import os
2020
import time
2121
import timeit
22-
from dataclasses import dataclass
23-
22+
from dataclasses import dataclass, fields, is_dataclass
2423
from enum import Enum
2524
from typing import (
2625
Any,
@@ -29,11 +28,14 @@
2928
Dict,
3029
List,
3130
Optional,
31+
Set,
3232
Tuple,
3333
TypeVar,
3434
Union,
3535
)
3636

37+
import click
38+
3739
import torch
3840
from torch import multiprocessing as mp
3941
from torch.autograd.profiler import record_function
@@ -463,6 +465,52 @@ def set_embedding_config(
463465
return embedding_configs, pooling_configs
464466

465467

468+
# pyre-ignore [24]
def cmd_conf(*configs: Any) -> Callable:
    """Decorator factory that exposes dataclass config fields as click options.

    Every field of every class in ``configs`` whose default value is an
    instance of a supported type (``int``, ``str``, ``bool``, ``float``,
    ``Enum``) becomes a ``--<field name>`` option. A ``--loglevel`` option is
    always added. When the decorated command runs, one instance of each config
    class is built from the parsed options and the wrapped function is called
    with those instances as positional arguments, in the order of ``configs``.

    If the same field name appears in more than one config class, the option
    is registered only once (a warning is logged) and its value is shared by
    all config classes that declare the field.

    Args:
        *configs: dataclass types to instantiate from the command line.

    Returns:
        A decorator that wraps a function taking one positional argument per
        config class.
    """
    # Local import keeps this function self-contained; MISSING is the sentinel
    # dataclasses use for "no plain default".
    from dataclasses import MISSING

    support_classes: List[Any] = [int, str, bool, float, Enum]  # pyre-ignore[33]
    supported_types = tuple(support_classes)

    # pyre-ignore [24]
    def wrapper(func: Callable) -> Callable:
        for config in configs:
            assert is_dataclass(config), f"{config} should be a dataclass"

        # pyre-ignore
        def rtf(**kwargs):
            loglevel = logging._nameToLevel[kwargs["loglevel"].upper()]
            # Force INFO while building the configs so that the resolved
            # values are always logged, then apply the requested level.
            logger.setLevel(logging.INFO)
            try:
                input_configs = []
                for config in configs:
                    params = {}
                    for field in fields(config):
                        if not field.init:
                            # not a constructor argument; passing it would raise
                            continue
                        if field.name in kwargs:
                            params[field.name] = kwargs[field.name]
                        elif field.default is not MISSING:
                            params[field.name] = field.default
                        # Otherwise leave the field out so the dataclass
                        # applies its default_factory (or raises a clear
                        # TypeError for a required field) instead of
                        # receiving the MISSING sentinel as a value.
                    conf = config(**params)
                    logger.info(conf)
                    input_configs.append(conf)
            finally:
                # restore the requested level even if a config fails to build
                logger.setLevel(loglevel)
            return func(*input_configs)

        names: Set[str] = set()
        for config in configs:
            for field in fields(config):
                # Only fields with a default of a supported type become options.
                if not isinstance(field.default, supported_types):
                    continue
                if field.name not in names:
                    names.add(field.name)
                else:
                    # Logger.warn is a deprecated alias; use Logger.warning.
                    logger.warning(f"WARNING: duplicate argument {field.name}")
                    continue
                rtf = click.option(
                    f"--{field.name}", type=field.type, default=field.default
                )(rtf)
        return click.option(
            "--loglevel",
            type=click.Choice(list(logging._nameToLevel.keys()), case_sensitive=False),
            default=logging._levelToName[logger.level],
        )(rtf)

    return wrapper
512+
513+
466514
def init_argparse_and_args() -> argparse.Namespace:
467515
parser = argparse.ArgumentParser()
468516

torchrec/distributed/test_utils/multi_process.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _run_multi_process_test_per_rank(
183183

184184
def run_multi_process_func(
185185
func: Callable[
186-
...,
186+
[int, int, ...], # rank, world_size, ...
187187
None,
188188
],
189189
multiprocessing_method: str = "spawn",
@@ -192,6 +192,7 @@ def run_multi_process_func(
192192
# pyre-ignore
193193
**kwargs,
194194
) -> None:
195+
""" """
195196
os.environ["MASTER_ADDR"] = str("localhost")
196197
os.environ["MASTER_PORT"] = str(get_free_port())
197198
os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"

0 commit comments

Comments
 (0)