@@ -1,28 +1,28 @@
-from functools import wraps
 import os
 import random
+from functools import wraps

 import torch
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
-from torch.distributed.optim import DistributedOptimizer
 import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
+import torch.optim as optim
+from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
 from torch.distributed.rpc import TensorPipeRpcBackendOptions
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-import torch.optim as optim

 NUM_EMBEDDINGS = 100
 EMBEDDING_DIM = 16

+
 class HybridModel(torch.nn.Module):
     r"""
     The model consists of a sparse part and a dense part. The dense part is an
     nn.Linear module that is replicated across all trainers using
     DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
     stored on the parameter server.
-
     The model holds a Remote Reference to the embedding table on the parameter
     server.
     """
@@ -37,6 +37,7 @@ def forward(self, indices, offsets):
         emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
         return self.fc(emb_lookup.cuda(self.device))

+
 def _retrieve_embedding_parameters(emb_rref):
     return [RRef(p) for p in emb_rref.local_value().parameters()]

@@ -57,7 +58,8 @@ def _run_trainer(emb_rref, rank):

     # Retrieve parameters for embedding table.
     model_parameter_rrefs = rpc.rpc_sync(
-        "ps", _retrieve_embedding_parameters, args=(emb_rref,))
+        "ps", _retrieve_embedding_parameters, args=(emb_rref,)
+    )

     # model.parameters() only includes local parameters.
     for param in model.parameters():
@@ -118,29 +120,30 @@ def run_worker(rank, world_size):
     # We need to use different port numbers in TCP init_method for init_rpc and
     # init_process_group to avoid port conflicts.
     rpc_backend_options = TensorPipeRpcBackendOptions()
-    rpc_backend_options.init_method = 'tcp://localhost:29501'
+    rpc_backend_options.init_method = "tcp://localhost:29501"

     # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
     if rank == 2:
         rpc.init_rpc(
-                "master",
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            "master",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )

         # Build the embedding table on the ps.
         emb_rref = rpc.remote(
-                "ps",
-                torch.nn.EmbeddingBag,
-                args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
-                kwargs={"mode": "sum"})
+            "ps",
+            torch.nn.EmbeddingBag,
+            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
+            kwargs={"mode": "sum"},
+        )

         # Run the training loop on trainers.
         futs = []
         for trainer_rank in [0, 1]:
             trainer_name = "trainer{}".format(trainer_rank)
-            fut = rpc.rpc_async(
-                trainer_name, _run_trainer, args=(emb_rref, rank))
+            fut = rpc.rpc_async(trainer_name, _run_trainer, args=(emb_rref, rank))
             futs.append(fut)

         # Wait for all training to finish.
@@ -149,32 +152,34 @@ def run_worker(rank, world_size):
     elif rank <= 1:
         # Initialize process group for Distributed DataParallel on trainers.
         dist.init_process_group(
-                backend="gloo", rank=rank, world_size=2,
-                init_method='tcp://localhost:29500')
+            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
+        )

         # Initialize RPC.
         trainer_name = "trainer{}".format(rank)
         rpc.init_rpc(
-                trainer_name,
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            trainer_name,
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )

         # Trainer just waits for RPCs from master.
     else:
         rpc.init_rpc(
-                "ps",
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            "ps",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
         # parameter server do nothing
         pass

     # block until all rpcs finish
     rpc.shutdown()


-if __name__=="__main__":
+if __name__ == "__main__":
     # 2 trainers, 1 parameter server, 1 master.
     world_size = 4
-    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
+    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)