#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
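
"""Benchmark for TorchRec train pipelines.

Each trainer rank builds a TestSparseNN, shards it with
DistributedModelParallel, and times the selected train pipeline
("base", "sparse", "semi", or "prefetch") via benchmark_func.

CLI flags are derived from the config dataclasses below by ``cmd_conf``
(flag spellings are assumed to follow the dataclass field names), e.g.::

    python <this_script> --world_size 2 --pipeline sparse
"""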
import copy
from dataclasses import dataclass
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import click
import torch
import torch.distributed as dist
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torch import nn, optim
from torch.optim import Optimizer

from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.test_utils.multi_process import (
    MultiProcessContext,
    run_multi_process_func,
)
from torchrec.distributed.test_utils.test_input import (
    ModelInput,
    TestSparseNNInputConfig,
)
from torchrec.distributed.test_utils.test_model import (
    TestEBCSharder,
    TestOverArchLarge,
    TestSparseNN,
)
from torchrec.distributed.train_pipeline import (
    TrainPipeline,
    TrainPipelineBase,
    TrainPipelineSparseDist,
)
from torchrec.distributed.train_pipeline.train_pipelines import (
    PrefetchTrainPipelineSparseDist,
    TrainPipelineSemiSync,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig


@dataclass
class RunOptions:
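    """Per-run options: process count, batch count, sharding, and profiling."""
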
    world_size: int = 4  # number of trainer processes (ranks)
    num_batches: int = 20  # batches generated per rank
    sharding_type: ShardingType = ShardingType.TABLE_WISE
    input_type: str = "kjt"  # sparse input representation (KeyedJaggedTensor)
    profile: str = ""  # passed to benchmark_func as profile_dir


@dataclass
class EmbeddingTablesConfig:
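    """Counts and dimensions for the unweighted and weighted embedding tables."""
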
    num_unweighted_features: int = 4
    num_weighted_features: int = 4
    embedding_feature_dim: int = 512

    def generate_tables(
        self,
    ) -> Tuple[
        List[EmbeddingBagConfig],
        List[EmbeddingBagConfig],
    ]:
        tables = [
            EmbeddingBagConfig(
                # at least 100K rows per table, growing with the feature index
                num_embeddings=max(i + 1, 100) * 1000,
                embedding_dim=self.embedding_feature_dim,
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(self.num_unweighted_features)
        ]
        weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=max(i + 1, 100) * 1000,
                embedding_dim=self.embedding_feature_dim,
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(self.num_weighted_features)
        ]
        return tables, weighted_tables


@dataclass
class PipelineConfig:
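    """Selects which train pipeline implementation to benchmark."""
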
    pipeline: str = "base"

    def generate_pipeline(
        self, model: nn.Module, opt: torch.optim.Optimizer, device: torch.device
    ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
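        """Instantiate the pipeline named by ``self.pipeline`` for the given model."""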
        _pipeline_cls: Dict[
            str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
        ] = {
            "base": TrainPipelineBase,
            "sparse": TrainPipelineSparseDist,
            "semi": TrainPipelineSemiSync,
            "prefetch": PrefetchTrainPipelineSparseDist,
        }

        if self.pipeline == "semi":
            # TrainPipelineSemiSync takes an extra start_batch argument,
            # so it is constructed explicitly rather than via the mapping.
            return TrainPipelineSemiSync(
                model=model, optimizer=opt, device=device, start_batch=0
            )
        elif self.pipeline in _pipeline_cls:
            pipeline_cls = _pipeline_cls[self.pipeline]
            return pipeline_cls(model=model, optimizer=opt, device=device)
        else:
            raise RuntimeError(f"unknown pipeline option {self.pipeline}")


@click.command()
@cmd_conf(RunOptions, EmbeddingTablesConfig, TestSparseNNInputConfig, PipelineConfig)
def main(
    run_option: RunOptions,
    table_config: EmbeddingTablesConfig,
    input_config: TestSparseNNInputConfig,
    pipeline_config: PipelineConfig,
) -> None:
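    """Entry point: builds the table configs and spawns one trainer per rank."""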
    # sparse table config is available on each trainer
    tables, weighted_tables = table_config.generate_tables()

    # launch trainers
    run_multi_process_func(
        func=runner,
        world_size=run_option.world_size,
        tables=tables,
        weighted_tables=weighted_tables,
        run_option=run_option,
        input_config=input_config,
        pipeline_config=pipeline_config,
    )


def _generate_data(
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    input_config: TestSparseNNInputConfig,
    num_batches: int,
) -> List[ModelInput]:
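    """Pre-generate num_batches inputs so data creation stays off the timed path."""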
    return [
        input_config.generate_model_input(
            tables=tables,
            weighted_tables=weighted_tables,
        )
        for _ in range(num_batches)
    ]


def _generate_model(
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    dense_device: torch.device,
) -> nn.Module:
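    """Build the unsharded TestSparseNN; sparse tables start on the meta device
    so they can be materialized directly onto their shards."""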
    return TestSparseNN(
        tables=tables,
        weighted_tables=weighted_tables,
        dense_device=dense_device,
        sparse_device=torch.device("meta"),
        over_arch_clazz=TestOverArchLarge,
    )


def _generate_sharded_model_and_optimizer(
    model: nn.Module,
    sharding_type: str,
    kernel_type: str,
    pg: dist.ProcessGroup,
    device: torch.device,
    fused_params: Optional[Dict[str, Any]] = None,
) -> Tuple[nn.Module, Optimizer]:
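    """Shard the model with DistributedModelParallel and build an SGD optimizer
    over the dense parameters (sparse parameters use the fused optimizer)."""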
    sharder = TestEBCSharder(
        sharding_type=sharding_type,
        kernel_type=kernel_type,
        fused_params=fused_params,
    )
    sharded_model = DistributedModelParallel(
        module=copy.deepcopy(model),
        env=ShardingEnv.from_process_group(pg),
        init_data_parallel=True,
        device=device,
        sharders=[
            cast(
                ModuleSharder[nn.Module],
                sharder,
            )
        ],
    ).to(device)
    # dense parameters only: the sparse (embedding) parameters are updated by
    # the fused optimizer configured through fused_params
    optimizer = optim.SGD(
        [
            param
            for name, param in sharded_model.named_parameters()
            if "sparse" not in name
        ],
        lr=0.1,
    )
    return sharded_model, optimizer


def runner(
    rank: int,
    world_size: int,
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    run_option: RunOptions,
    input_config: TestSparseNNInputConfig,
    pipeline_config: PipelineConfig,
) -> None:
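    """Per-rank benchmark body: shard the model, warm up the pipeline, time it."""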
    # NOTE: anomaly detection adds autograd overhead and will inflate the
    # measured times; disable it if absolute numbers matter.
    torch.autograd.set_detect_anomaly(True)
    with MultiProcessContext(
        rank=rank,
        world_size=world_size,
        backend="nccl",
        use_deterministic_algorithms=False,
    ) as ctx:
        unsharded_model = _generate_model(
            tables=tables,
            weighted_tables=weighted_tables,
            dense_device=ctx.device,
        )

        sharded_model, optimizer = _generate_sharded_model_and_optimizer(
            model=unsharded_model,
            sharding_type=run_option.sharding_type.value,
            kernel_type=EmbeddingComputeKernel.FUSED.value,
            # pyre-ignore
            pg=ctx.pg,
            device=ctx.device,
            fused_params={
                "optimizer": EmbOptimType.EXACT_ADAGRAD,
                "learning_rate": 0.1,
            },
        )
        bench_inputs = _generate_data(
            tables=tables,
            weighted_tables=weighted_tables,
            input_config=input_config,
            num_batches=run_option.num_batches,
        )
        pipeline = pipeline_config.generate_pipeline(
            sharded_model, optimizer, ctx.device
        )
        # warm-up pass so one-time initialization is excluded from the measurement
        pipeline.progress(iter(bench_inputs))

        def _func_to_benchmark(
            bench_inputs: List[ModelInput],
            model: nn.Module,
            pipeline: TrainPipeline,
        ) -> None:
            # drive the pipeline until the input iterator is exhausted
            dataloader = iter(bench_inputs)
            while True:
                try:
                    pipeline.progress(dataloader)
                except StopIteration:
                    break

        result = benchmark_func(
            name=type(pipeline).__name__,
            bench_inputs=bench_inputs,  # pyre-ignore
            prof_inputs=bench_inputs,  # pyre-ignore
            num_benchmarks=5,
            num_profiles=2,
            profile_dir=run_option.profile,
            world_size=run_option.world_size,
            func_to_benchmark=_func_to_benchmark,
            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
            rank=rank,
        )
        if rank == 0:
            print(result)


if __name__ == "__main__":
    main()