Commit 5cea44a

Ali-Tehrani authored and facebook-github-bot committed
Set intra_group_size from env var inside comm.py. (meta-pytorch#3697)
Summary:

Context
-------
TorchRec comms needs a way to obtain the pod size (the topology-domain multiple) and the total number of process groups in a topology group for TWRW/Grid sharding. We obtain the number of intra-node processes within a pod by reading `TOPOLOGY_DOMAIN_MULTIPLE` from the environment variables (see diff stack).

Implementation
--------------
- Created a `get_intra_group_size` function that obtains the intra-node group size from the environment variable, defaulting to the usual `get_local_size` when the variable is not set.
- Updated `intra_and_cross_node_pg` to use `get_intra_group_size` instead.

Differential Revision: D91617889
1 parent 302da75 commit 5cea44a
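
The arithmetic behind the change is small enough to show in isolation. The sketch below is not part of the commit; it only assumes a standard `torchrun`-style environment where `LOCAL_WORLD_SIZE` is set, and shows how `TOPOLOGY_DOMAIN_MULTIPLE` combines with the per-host size to give the topology-group size, falling back to the local world size when the variable is absent.

```python
# Sketch only: mirrors the fallback and multiplication performed by the new
# helpers in comm.py, using plain os.environ instead of torchrec internals.
import os

# Hypothetical 32-rank job: 8 GPUs per host, 2 hosts per pod/domain.
os.environ["LOCAL_WORLD_SIZE"] = "8"
os.environ["TOPOLOGY_DOMAIN_MULTIPLE"] = "2"

local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
domain_multiple = int(os.environ.get("TOPOLOGY_DOMAIN_MULTIPLE", "-1"))

# Fall back to the local world size when TOPOLOGY_DOMAIN_MULTIPLE is unset.
topology_group_size = (
    domain_multiple * local_world_size if domain_multiple > 0 else local_world_size
)

world_size = 32
assert world_size % topology_group_size == 0
print(topology_group_size)                # 16 ranks per topology group
print(world_size // topology_group_size)  # 2 topology groups
```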

File tree

4 files changed: +77 -6 lines changed

torchrec/distributed/comm.py

Lines changed: 62 additions & 4 deletions
@@ -115,12 +115,68 @@ def get_num_groups(world_size: Optional[int] = None) -> int:
     return world_size // get_local_size(world_size)
 
 
+def get_topology_domain_multiple() -> Optional[int]:
+    """The number of hosts/servers/nodes per pod/domain."""
+    topology_domain_multiple = _env2int(
+        [
+            "TOPOLOGY_DOMAIN_MULTIPLE",
+        ],
+        -1,
+    )
+    if topology_domain_multiple == -1:
+        return None
+    return topology_domain_multiple
+
+
+def get_topology_group_world_size(world_size: Optional[int] = None) -> int:
+    """
+    Gets the topology group world size, the total number of processes linked within a topology group.
+
+    This is the largest number of processes linked together by high-bandwidth communication.
+    If it isn't specified, it falls back to LOCAL_WORLD_SIZE.
+    """
+    topology_domain_multiple = get_topology_domain_multiple()
+    local_world_size = get_local_size(world_size)
+
+    if topology_domain_multiple is None:
+        logger.warn(
+            "Could not determine TOPOLOGY_DOMAIN_MULTIPLE from environment,"
+            " utilizing LOCAL_WORLD_SIZE instead."
+        )
+        return local_world_size
+
+    # Total number of GPUs in a domain = topology_domain_multiple * number of GPUs per node
+    total_numb_proc = topology_domain_multiple * local_world_size
+    world_size = dist.get_world_size()
+    if world_size % total_numb_proc != 0:
+        raise ValueError(
+            f"World size {world_size} is not a multiple of the topology group: {total_numb_proc}"
+        )
+    return total_numb_proc
+
+
 def intra_and_cross_node_pg(
     device: Optional[torch.device] = None,
     backend: Optional[str] = None,
 ) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
     """
     Creates sub process groups (intra and cross node)
+
+    e.g. world_size = 12 needs to be split into groups of size `local_size = 6`
+    process groups = [0, 1, 2, ... 11]
+
+    intra-group:
+        [0] -> [0, 1, .., 5]
+        [1] -> [0, 1, .., 5]
+        ...
+        [6] -> [6, ..., 11]
+        ...
+        [11] -> [6, ..., 11]
+    cross-group:
+        [0] -> [[0, 6]]
+        [1] -> [[1, 7]]
+        ...
+        [5] -> [[5, 11]]
     """
     if device is not None and device.type == "meta":
         return None, None
@@ -130,10 +186,12 @@ def intra_and_cross_node_pg(
 
     my_size = dist.get_world_size()
     my_rank = dist.get_rank()
-    my_local_rank = get_local_rank(my_size, my_rank)
-    local_size = get_local_size(my_size)
-    my_group_rank = get_group_rank(my_size, my_rank)
-    group_count = get_num_groups(my_size)
+    local_size = get_topology_group_world_size(my_size)
+    # TODO: Alireza look into incorporating topology group WS in
+    # get_group_rank, get_num_groups, get_local_rank
+    my_group_rank = my_rank // local_size
+    group_count = my_size // local_size
+    my_local_rank = my_rank % local_size  # Not the same as the actual local_rank
     if backend is None:
         backend = dist.get_backend()
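
To make the docstring example above concrete, here is a small standalone sketch (plain Python, no process groups are created) of the membership math the updated function relies on:

```python
# Reproduces the intra-/cross-group membership from the docstring example:
# world_size = 12 split into topology groups of local_size = 6.
world_size = 12
local_size = 6  # in the real code this comes from get_topology_group_world_size()
group_count = world_size // local_size

# Ranks within the same topology group communicate in the intra-node group.
intra_groups = [
    list(range(g * local_size, (g + 1) * local_size)) for g in range(group_count)
]
# Ranks with the same local index across groups form the cross-node group.
cross_groups = [
    [local_rank + g * local_size for g in range(group_count)]
    for local_rank in range(local_size)
]

print(intra_groups)  # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]
print(cross_groups)  # [[0, 6], [1, 7], [2, 8], [3, 9], [4, 10], [5, 11]]

# Per-rank indices computed by the new code path:
rank = 7
my_group_rank = rank // local_size  # 1
my_local_rank = rank % local_size   # 1 (not the same as the device-local rank)
```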

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 3 additions & 0 deletions
@@ -140,6 +140,7 @@ def _test_sharding(
         backend: str = "gloo",
         world_size: int = 2,
         local_size: Optional[int] = None,
+        pod_size: Optional[int] = None,
         world_size_2D: Optional[int] = None,
         node_group_size: Optional[int] = None,
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
@@ -173,6 +174,7 @@ def _test_sharding(
                 rank=0,
                 world_size=world_size,
                 local_size=local_size,
+                pod_size=pod_size,
                 world_size_2D=world_size_2D,
                 node_group_size=node_group_size,
                 model_class=model_class,  # pyre-ignore[6]
@@ -205,6 +207,7 @@ def _test_sharding(
             callable=sharding_single_rank_test,
             world_size=world_size,
             local_size=local_size,
+            pod_size=pod_size,
             world_size_2D=world_size_2D,
             node_group_size=node_group_size,
             model_class=model_class,

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 4 additions & 1 deletion
@@ -770,6 +770,7 @@ def sharding_single_rank_test_single_process(
     weighted_tables: Optional[List[EmbeddingTableConfig]] = None,
     constraints: Optional[Dict[str, ParameterConstraints]] = None,
     local_size: Optional[int] = None,
+    pod_size: Optional[int] = None,
     qcomms_config: Optional[QCommsConfig] = None,
     apply_optimizer_in_backward_config: Optional[
         Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
@@ -869,11 +870,11 @@ def sharding_single_rank_test_single_process(
             world_size=world_size_2D if world_size_2D else world_size,
             compute_device=device.type,
             local_world_size=node_group_size if node_group_size else local_size,
+            pod_size=pod_size,
         ),
         constraints=constraints,
     )
     plan: ShardingPlan = planner.collective_plan(local_model, sharders, pg)
-
     if submodule_configs is not None:
         # Dynamic 2D parallel, create a new plan for each submodule
         for config in submodule_configs:
@@ -1057,6 +1058,7 @@ def sharding_single_rank_test(
     weighted_tables: Optional[List[EmbeddingTableConfig]] = None,
     constraints: Optional[Dict[str, ParameterConstraints]] = None,
     local_size: Optional[int] = None,
+    pod_size: Optional[int] = None,
     qcomms_config: Optional[QCommsConfig] = None,
     apply_optimizer_in_backward_config: Optional[
         Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
@@ -1098,6 +1100,7 @@ def sharding_single_rank_test(
         weighted_tables=weighted_tables,
         constraints=constraints,
         local_size=local_size,
+        pod_size=pod_size,
         qcomms_config=qcomms_config,
         apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
         variable_batch_size=variable_batch_size,
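
The functional change in this file is threading `pod_size` into the planner's `Topology`. As a hedged, standalone sketch of that construction (the `pod_size` keyword on `Topology` is introduced elsewhere in this diff stack, so treat it as an assumption; the other values are hypothetical):

```python
from torchrec.distributed.planner.types import Topology

# Hypothetical configuration: 4 ranks, 2 GPUs per host, 1 host per pod.
topology = Topology(
    world_size=4,
    compute_device="cuda",
    local_world_size=2,
    pod_size=1,  # assumed keyword added elsewhere in this diff stack
)
```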

torchrec/distributed/tests/test_model_parallel_hierarchical.py

Lines changed: 8 additions & 1 deletion
@@ -60,6 +60,7 @@ class ModelParallelHierarchicalTest(ModelParallelTestShared):
                 EmbeddingComputeKernel.FUSED.value,
             ]
         ),
+        topology_domain=st.sampled_from([None, 1]),
         local_size=st.sampled_from([2]),
         qcomms_config=st.sampled_from(
             [
@@ -92,6 +93,7 @@ def test_sharding_nccl_twrw(
         sharder_type: str,
         sharding_type: str,
         kernel_type: str,
+        topology_domain: int,
         local_size: int,
         qcomms_config: Optional[QCommsConfig],
         apply_optimizer_in_backward_config: Optional[
@@ -111,6 +113,10 @@ def test_sharding_nccl_twrw(
         )
         # Make sure detail debug will work with non-even collective
         os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+        world_size = 4
+        if topology_domain:
+            # Need this to test topology group for TWRW
+            os.environ["TOPOLOGY_DOMAIN_MULTIPLE"] = str(topology_domain)
 
         self._test_sharding(
             # pyre-ignore[6]
@@ -123,8 +129,9 @@ def test_sharding_nccl_twrw(
                     device=torch.device("cuda"),
                 ),
             ],
+            pod_size=topology_domain,
             backend="nccl",
-            world_size=4,
+            world_size=world_size,
             local_size=local_size,
             qcomms_config=qcomms_config,
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
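
For reference, the parameter combinations this test samples map onto group sizes as follows (a sketch that reuses the fallback logic of `get_topology_group_world_size` shown above):

```python
# Sketch only: group sizing implied by the test's sampled parameters.
world_size = 4
local_size = 2
for topology_domain in (None, 1):  # values sampled via st.sampled_from([None, 1])
    topology_group_size = (
        topology_domain * local_size if topology_domain else local_size
    )
    assert world_size % topology_group_size == 0
    print(topology_domain, topology_group_size, world_size // topology_group_size)
# Both settings give topology groups of size 2 (two groups in total), so TWRW
# is exercised both with and without TOPOLOGY_DOMAIN_MULTIPLE in the environment.
```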
