Commit 5b528f5

Merge pull request #904 from SciPioneer/patch-1
Clang-format
2 parents b0649dc + 41a7004 commit 5b528f5

File tree

1 file changed: +34 -29 lines changed

distributed/rpc/ddp_rpc/main.py

Lines changed: 34 additions & 29 deletions
@@ -1,28 +1,28 @@
-from functools import wraps
 import os
 import random
+from functools import wraps
 
 import torch
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
-from torch.distributed.optim import DistributedOptimizer
 import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
+import torch.optim as optim
+from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
 from torch.distributed.rpc import TensorPipeRpcBackendOptions
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-import torch.optim as optim
 
 NUM_EMBEDDINGS = 100
 EMBEDDING_DIM = 16
 
+
 class HybridModel(torch.nn.Module):
     r"""
     The model consists of a sparse part and a dense part. The dense part is an
     nn.Linear module that is replicated across all trainers using
     DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
     stored on the parameter server.
-
     The model holds a Remote Reference to the embedding table on the parameter
     server.
     """
@@ -37,6 +37,7 @@ def forward(self, indices, offsets):
         emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
         return self.fc(emb_lookup.cuda(self.device))
 
+
 def _retrieve_embedding_parameters(emb_rref):
     return [RRef(p) for p in emb_rref.local_value().parameters()]
 
@@ -57,7 +58,8 @@ def _run_trainer(emb_rref, rank):
 
     # Retrieve parameters for embedding table.
     model_parameter_rrefs = rpc.rpc_sync(
-        "ps", _retrieve_embedding_parameters, args=(emb_rref,))
+        "ps", _retrieve_embedding_parameters, args=(emb_rref,)
+    )
 
     # model.parameters() only includes local parameters.
     for param in model.parameters():
@@ -118,29 +120,30 @@ def run_worker(rank, world_size):
     # We need to use different port numbers in TCP init_method for init_rpc and
     # init_process_group to avoid port conflicts.
     rpc_backend_options = TensorPipeRpcBackendOptions()
-    rpc_backend_options.init_method='tcp://localhost:29501'
+    rpc_backend_options.init_method = "tcp://localhost:29501"
 
     # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
     if rank == 2:
         rpc.init_rpc(
-                "master",
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            "master",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
 
         # Build the embedding table on the ps.
         emb_rref = rpc.remote(
-                "ps",
-                torch.nn.EmbeddingBag,
-                args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
-                kwargs={"mode": "sum"})
+            "ps",
+            torch.nn.EmbeddingBag,
+            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
+            kwargs={"mode": "sum"},
+        )
 
         # Run the training loop on trainers.
         futs = []
        for trainer_rank in [0, 1]:
             trainer_name = "trainer{}".format(trainer_rank)
-            fut = rpc.rpc_async(
-                trainer_name, _run_trainer, args=(emb_rref, rank))
+            fut = rpc.rpc_async(trainer_name, _run_trainer, args=(emb_rref, rank))
             futs.append(fut)
 
         # Wait for all training to finish.
@@ -149,32 +152,34 @@ def run_worker(rank, world_size):
     elif rank <= 1:
         # Initialize process group for Distributed DataParallel on trainers.
         dist.init_process_group(
-            backend="gloo", rank=rank, world_size=2,
-            init_method='tcp://localhost:29500')
+            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
+        )
 
         # Initialize RPC.
         trainer_name = "trainer{}".format(rank)
         rpc.init_rpc(
-                trainer_name,
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            trainer_name,
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
 
         # Trainer just waits for RPCs from master.
     else:
         rpc.init_rpc(
-                "ps",
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            "ps",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
         # parameter server do nothing
         pass
 
     # block until all rpcs finish
     rpc.shutdown()
 
 
-if __name__=="__main__":
+if __name__ == "__main__":
     # 2 trainers, 1 parameter server, 1 master.
     world_size = 4
-    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
+    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
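For context, the reformatted rpc_sync call in the third hunk sits inside _run_trainer, where the parameter RRefs fetched from the parameter server are combined with the trainer's local DDP parameters and driven by a single DistributedOptimizer. Below is a minimal sketch of that surrounding trainer loop, written against the imports shown at the top of this diff; the SGD settings, loss function, and get_next_batch helper are illustrative assumptions, not part of this change.

# Sketch only: HybridModel, _retrieve_embedding_parameters, emb_rref, rank, and
# model_parameter_rrefs appear in main.py as shown in the diff; the optimizer
# settings, loss, and get_next_batch helper are assumed for illustration.
def _run_trainer_sketch(emb_rref, rank):
    model = HybridModel(emb_rref, rank)  # dense part wrapped in DDP, sparse part remote

    # Remote embedding parameters plus the trainer's local DDP parameters.
    model_parameter_rrefs = rpc.rpc_sync(
        "ps", _retrieve_embedding_parameters, args=(emb_rref,)
    )
    for param in model.parameters():
        model_parameter_rrefs.append(RRef(param))

    # One DistributedOptimizer updates both the local and the remote parameters.
    opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)
    criterion = torch.nn.CrossEntropyLoss()

    for indices, offsets, target in get_next_batch(rank):  # assumed data helper
        with dist_autograd.context() as context_id:
            output = model(indices, offsets)
            loss = criterion(output, target.cuda(rank))
            # Backward pass and optimizer step run inside the distributed
            # autograd context so gradients reach the remote embedding table.
            dist_autograd.backward(context_id, [loss])
            opt.step(context_id)

This mirrors the pattern the class docstring describes: embedding lookups travel to the parameter server over RPC, while gradients for both halves flow through distributed autograd so a single optimizer step covers the whole hybrid model.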
