
Commit ac739f4

iamzainhuda authored and facebook-github-bot committed

add custom all reduce support for 2D parallel (#2758)

Summary:
Pull Request resolved: #2758

Add support for a user-defined all reduce function for embedding weight and optimizer sync. Users do not need to create a new process group; they can access the process group for the ranks participating in model replication through DMPCollection._replica_pg. Users will need to handle the stream synchronization.

Reviewed By: kausv

Differential Revision: D69990461

fbshipit-source-id: 7cea3f7c7e10bc198984bbe90fc177d6f3a3e769

1 parent: 7500a0f
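
A minimal usage sketch of the new option (not part of the commit itself): `module`, `plan`, `sharders`, and `sharding_group_size` are placeholders assumed to come from an existing 2D parallel training setup; only the hook wiring below uses what this commit adds (`custom_all_reduce` / `set_all_reduce_hook` and `_replica_pg`).

from typing import List

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DMPCollection

# Build the 2D parallel collection as usual; the objects passed in here are
# placeholders from the surrounding training script.
model = DMPCollection(
    module=module,
    sharding_group_size=sharding_group_size,
    world_size=dist.get_world_size(),
    global_pg=dist.group.WORLD,
    plan=plan,
    sharders=sharders,
    device=torch.device("cuda"),
)

def custom_all_reduce(tensors: List[torch.Tensor]) -> None:
    # The hook owns the collective call, the process group it runs on, and any
    # stream synchronization. Here it reuses the replica process group that
    # DMPCollection already created, as the summary suggests.
    opts = dist.AllreduceCoalescedOptions()
    opts.reduceOp = dist.ReduceOp.AVG
    handle = model._replica_pg.allreduce_coalesced(tensors, opts=opts)
    handle.wait()

# Either pass custom_all_reduce=... to the constructor, or register it afterwards:
model.set_all_reduce_hook(custom_all_reduce)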

File tree

4 files changed: +134 −21 lines changed

torchrec/distributed/model_parallel.py

Lines changed: 54 additions & 15 deletions

@@ -11,7 +11,7 @@
 import copy
 import logging as logger
 from collections import OrderedDict
-from typing import Any, cast, Dict, Iterator, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Set, Tuple, Type

 import torch
 import torch.distributed as dist
@@ -691,6 +691,7 @@ def __init__(
         init_parameters: bool = True,
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
         use_inter_host_allreduce: bool = False,
+        custom_all_reduce: Optional[Callable[[torch.Tensor], None]] = None,
     ) -> None:
         assert device.type == "cuda", "DMPCollection only supports CUDA"
         self._device = device
@@ -700,6 +701,9 @@ def __init__(
         self._sharding_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._replica_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._global_rank: int = dist.get_rank(global_pg)
+        self._custom_all_reduce: Optional[Callable[[torch.Tensor], None]] = (
+            custom_all_reduce
+        )

         self._device_mesh, self._sharding_pg, self._replica_pg = (
             self._create_process_groups(
@@ -744,33 +748,68 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights,
         which averages the weights across all processes in the inter-process group.

+        The default CUDA stream is used for the all-reduce operation, and the method does not return any value.
+
         Args:
             include_optimizer_state (bool): Flag to include optimizer state syncing upon call
         """
         assert self._replica_pg is not None, "replica_pg is not initialized!"
-        opts = dist.AllreduceCoalescedOptions()
-        opts.reduceOp = dist.ReduceOp.AVG
-        all_weights = [
+        all_weights: List[torch.Tensor] = [
             w
             for emb_kernel in self._modules_to_sync
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             for w in emb_kernel.split_embedding_weights()
         ]
-        handle = self._replica_pg.allreduce_coalesced(all_weights, opts=opts)
-        handle.wait()
+
+        opts = None
+        if self._custom_all_reduce is None:
+            opts = dist.AllreduceCoalescedOptions()
+            opts.reduceOp = dist.ReduceOp.AVG
+        self._allreduce_tensors(all_weights, opts)

         if include_optimizer_state:
-            # Sync accumulated square of grad of local optimizer shards
-            optim_list = []
+            optimizer_tensors = []
             for emb_kernel in self._modules_to_sync:
                 # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-                all_optimizer_states = emb_kernel.get_optimizer_state()
-                momentum1 = [optim["sum"] for optim in all_optimizer_states]
-                optim_list.extend(momentum1)
-            # Some optimizers do not have states to sync, we check if states exist before collective call
-            if optim_list:
-                handle = self._replica_pg.allreduce_coalesced(optim_list, opts=opts)
-                handle.wait()
+                optimizer_states = emb_kernel.get_optimizer_state()
+                optimizer_tensors.extend([state["sum"] for state in optimizer_states])
+            if optimizer_tensors:
+                self._allreduce_tensors(optimizer_tensors, opts)
+
+    def _allreduce_tensors(
+        self,
+        tensors: List[torch.Tensor],
+        opts: Optional[dist.AllreduceCoalescedOptions] = None,
+    ) -> None:
+        """
+        Helper to perform all reduce on given tensors, uses custom all reduce function if provided
+        """
+        if self._custom_all_reduce is not None:
+            # pyre-ignore[6]
+            self._custom_all_reduce(tensors)
+        else:
+            handle = self._replica_pg.allreduce_coalesced(tensors, opts=opts)
+            handle.wait()
+
+    def set_all_reduce_hook(
+        self,
+        reduce_hook: Callable[[torch.Tensor], None],
+    ) -> None:
+        """
+        Replace default all reduce with custom callable. Users can alternatively
+        pass in the custom all reduce function through the constructor. The hook
+        expects the user to handle distributed communication call, associated
+        process group, and stream synchronization.
+
+        Args:
+            reduce_hook (Callable[[torch.Tensor], torch.Tensor]): The custom all reduce function to use for
+                embedding weights and optimizer states
+        """
+        if self._custom_all_reduce is not None:
+            logger.warning(
+                "[TorchRec 2D Parallel] Custom all reduce function already defined, overriding with new callable"
+            )
+        self._custom_all_reduce = reduce_hook

     def _create_process_groups(
         self,
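
Since `set_all_reduce_hook` leaves the collective, its process group, and stream synchronization to the caller, here is a hedged sketch of one way a hook could manage its own CUDA stream. The dedicated stream and the `dmp` variable (an already constructed DMPCollection) are illustrative assumptions, not part of this diff; the hook receives the list of tensors that `_allreduce_tensors` passes through.

from typing import List

import torch
import torch.distributed as dist

_ar_stream = torch.cuda.Stream()

def stream_aware_hook(tensors: List[torch.Tensor]) -> None:
    # Run the coalesced all reduce on a dedicated stream, then make the default
    # stream wait on it so work launched after sync() sees the averaged tensors.
    _ar_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_ar_stream):
        opts = dist.AllreduceCoalescedOptions()
        opts.reduceOp = dist.ReduceOp.AVG
        handle = dmp._replica_pg.allreduce_coalesced(tensors, opts=opts)
        handle.wait()
    torch.cuda.current_stream().wait_stream(_ar_stream)

dmp.set_all_reduce_hook(stream_aware_hook)

Note that per the diff above, a hook registered this way replaces any callable passed to the constructor, and `sync()` skips building `AllreduceCoalescedOptions` entirely whenever a custom hook is set.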

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 2 additions & 0 deletions

@@ -151,6 +151,7 @@ def _test_sharding(
         data_type: DataType = DataType.FP32,
         use_inter_host_allreduce: bool = False,
         allow_zero_batch_size: bool = False,
+        custom_all_reduce: bool = False,
     ) -> None:
         self._build_tables_and_groups(data_type=data_type)
         self._run_multi_process_test(
@@ -174,6 +175,7 @@ def _test_sharding(
             global_constant_batch=global_constant_batch,
             use_inter_host_allreduce=use_inter_host_allreduce,
             allow_zero_batch_size=allow_zero_batch_size,
+            custom_all_reduce=custom_all_reduce,
         )

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 63 additions & 6 deletions

@@ -9,13 +9,24 @@

 import random
 from enum import Enum
-from typing import Any, cast, Dict, List, Optional, Protocol, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Protocol,
+    Tuple,
+    Type,
+    Union,
+)

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DeviceMesh, DTensor
 from torch.distributed.optim import (
     _apply_optimizer_in_backward as apply_optimizer_in_backward,
 )
@@ -314,11 +325,12 @@ def sharding_single_rank_test(
     feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
     variable_batch_per_feature: bool = False,  # VBE
     global_constant_batch: bool = False,
-    world_size_2D: Optional[int] = None,
-    node_group_size: Optional[int] = None,
-    use_inter_host_allreduce: bool = False,
+    world_size_2D: Optional[int] = None,  # 2D parallel
+    node_group_size: Optional[int] = None,  # 2D parallel
+    use_inter_host_allreduce: bool = False,  # 2D parallel
     input_type: str = "kjt",  # "kjt" or "td"
     allow_zero_batch_size: bool = False,
+    custom_all_reduce: bool = False,  # 2D parallel
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         batch_size = (
@@ -428,17 +440,37 @@ def sharding_single_rank_test(
         )

         assert ctx.pg is not None
+        hook_called: bool = False
         if world_size_2D is not None:
+            all_reduce_func = None
+            if custom_all_reduce:
+                all_reduce_pg: dist.ProcessGroup = create_device_mesh_for_2D(
+                    use_inter_host_allreduce,
+                    world_size=ctx.world_size,
+                    local_size=world_size_2D,
+                ).get_group(mesh_dim="replicate")
+
+                def _custom_hook(input: List[torch.Tensor]) -> None:
+                    nonlocal hook_called
+                    opts = dist.AllreduceCoalescedOptions()
+                    opts.reduceOp = dist.ReduceOp.AVG
+                    handle = all_reduce_pg.allreduce_coalesced(input, opts=opts)
+                    handle.wait()
+                    hook_called = True
+
+                all_reduce_func = _custom_hook
+
             local_model = DMPCollection(
                 module=local_model,
                 sharding_group_size=world_size_2D,
                 world_size=ctx.world_size,
-                global_pg=ctx.pg,
+                global_pg=ctx.pg,  # pyre-ignore[6]
                 node_group_size=node_group_size,
                 plan=plan,
                 sharders=sharders,
                 device=ctx.device,
                 use_inter_host_allreduce=use_inter_host_allreduce,
+                custom_all_reduce=all_reduce_func,  # pyre-ignore[6]
             )
         else:
             local_model = DistributedModelParallel(
@@ -469,6 +501,9 @@ def sharding_single_rank_test(
             local_input,
         )

+        if world_size_2D is not None and custom_all_reduce:
+            assert hook_called, "custom all reduce hook was not called"
+
         # TODO: support non-sharded forward with zero batch size KJT
         if not allow_zero_batch_size:
             all_local_pred = []
@@ -501,6 +536,28 @@ def sharding_single_rank_test(
        )


+def create_device_mesh_for_2D(
+    use_inter_host_allreduce: bool, world_size: int, local_size: int
+) -> DeviceMesh:
+    if use_inter_host_allreduce:
+        peer_matrix = [
+            list(range(i, i + local_size)) for i in range(0, world_size, local_size)
+        ]
+    else:
+        peer_matrix = []
+        step = world_size // local_size
+        for group_rank in range(world_size // local_size):
+            peer_matrix.append([step * r + group_rank for r in range(local_size)])
+
+    mesh = DeviceMesh(
+        device_type="cuda",
+        mesh=peer_matrix,
+        mesh_dim_names=("replicate", "shard"),
+    )
+
+    return mesh
+
+
 def gen_full_pred_after_one_step(
     model: nn.Module,
     opt: torch.optim.Optimizer,
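
To make the rank layout built by create_device_mesh_for_2D above concrete, the snippet below reproduces only the peer-matrix arithmetic (no process groups or CUDA needed). The helper name and the world_size=8 / local_size=4 values are made up for illustration and are not part of the commit.

from typing import List

def peer_matrix_for_2d(
    use_inter_host_allreduce: bool, world_size: int, local_size: int
) -> List[List[int]]:
    # Mirrors the layout logic in create_device_mesh_for_2D.
    if use_inter_host_allreduce:
        return [
            list(range(i, i + local_size)) for i in range(0, world_size, local_size)
        ]
    step = world_size // local_size
    return [[step * r + g for r in range(local_size)] for g in range(step)]

# 8 ranks with sharding groups of 4 (two model replicas). Each row runs along
# the "shard" mesh dim; reading down a column gives a "replicate" group, which
# is what the test fetches via get_group(mesh_dim="replicate") for the hook.
print(peer_matrix_for_2d(False, 8, 4))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(peer_matrix_for_2d(True, 8, 4))   # [[0, 1, 2, 3], [4, 5, 6, 7]]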

torchrec/distributed/tests/test_2d_sharding.py

Lines changed: 15 additions & 0 deletions

@@ -87,6 +87,7 @@ def setUp(self, backend: str = "nccl") -> None:
         ),
         pooling=st.sampled_from([PoolingType.SUM]),
         use_inter_host_allreduce=st.booleans(),
+        custom_all_reduce=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
     def test_sharding_cw_2D(
@@ -99,6 +100,7 @@ def test_sharding_cw_2D(
         ],
         pooling: PoolingType,
         use_inter_host_allreduce: bool,
+        custom_all_reduce: bool,
     ) -> None:
         if (
             self.device == torch.device("cpu")
@@ -133,6 +135,7 @@ def test_sharding_cw_2D(
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             pooling=pooling,
             use_inter_host_allreduce=use_inter_host_allreduce,
+            custom_all_reduce=custom_all_reduce,
         )

     @unittest.skipIf(
@@ -176,6 +179,7 @@ def test_sharding_cw_2D(
         ),
         pooling=st.sampled_from([PoolingType.SUM]),
         use_inter_host_allreduce=st.booleans(),
+        custom_all_reduce=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
     def test_sharding_tw_2D(
@@ -188,6 +192,7 @@ def test_sharding_tw_2D(
         ],
         pooling: PoolingType,
         use_inter_host_allreduce: bool,
+        custom_all_reduce: bool,
     ) -> None:
         if (
             self.device == torch.device("cpu")
@@ -223,6 +228,7 @@ def test_sharding_tw_2D(
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             pooling=pooling,
             use_inter_host_allreduce=use_inter_host_allreduce,
+            custom_all_reduce=custom_all_reduce,
         )

     @unittest.skipIf(
@@ -266,6 +272,7 @@ def test_sharding_tw_2D(
         ),
         pooling=st.sampled_from([PoolingType.SUM]),
         use_inter_host_allreduce=st.booleans(),
+        custom_all_reduce=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
     def test_sharding_grid_2D(
@@ -278,6 +285,7 @@ def test_sharding_grid_2D(
         ],
         pooling: PoolingType,
         use_inter_host_allreduce: bool,
+        custom_all_reduce: bool,
     ) -> None:
         if (
             self.device == torch.device("cpu")
@@ -335,6 +343,7 @@ def test_sharding_grid_2D(
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             pooling=pooling,
             use_inter_host_allreduce=use_inter_host_allreduce,
+            custom_all_reduce=custom_all_reduce,
         )

     @unittest.skipIf(
@@ -375,6 +384,7 @@ def test_sharding_grid_2D(
         variable_batch_size=st.booleans(),
         pooling=st.sampled_from([PoolingType.SUM]),
         use_inter_host_allreduce=st.booleans(),
+        custom_all_reduce=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
     def test_sharding_rw_2D(
@@ -388,6 +398,7 @@ def test_sharding_rw_2D(
         variable_batch_size: bool,
         pooling: PoolingType,
         use_inter_host_allreduce: bool,
+        custom_all_reduce: bool,
     ) -> None:
         if self.backend == "gloo":
             self.skipTest(
@@ -421,6 +432,7 @@ def test_sharding_rw_2D(
             variable_batch_size=variable_batch_size,
             pooling=pooling,
             use_inter_host_allreduce=use_inter_host_allreduce,
+            custom_all_reduce=custom_all_reduce,
         )

     @unittest.skipIf(
@@ -464,6 +476,7 @@ def test_sharding_rw_2D(
         ),
         pooling=st.sampled_from([PoolingType.SUM]),
         use_inter_host_allreduce=st.booleans(),
+        custom_all_reduce=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
     def test_sharding_twrw_2D(
@@ -476,6 +489,7 @@ def test_sharding_twrw_2D(
         ],
         pooling: PoolingType,
         use_inter_host_allreduce: bool,
+        custom_all_reduce: bool,
     ) -> None:
         if (
             self.device == torch.device("cpu")
@@ -511,6 +525,7 @@ def test_sharding_twrw_2D(
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             pooling=pooling,
             use_inter_host_allreduce=use_inter_host_allreduce,
+            custom_all_reduce=custom_all_reduce,
         )
