@@ -1,13 +1,12 @@
- import os
import random
- from functools import wraps

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
+ from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef
from torch.distributed.rpc import TensorPipeRpcBackendOptions
@@ -19,30 +18,24 @@

class HybridModel(torch.nn.Module):
    r"""
-     The model consists of a sparse part and a dense part. The dense part is an
-     nn.Linear module that is replicated across all trainers using
-     DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
-     stored on the parameter server.
-     The model holds a Remote Reference to the embedding table on the parameter
-     server.
+     The model consists of a sparse part and a dense part.
+     1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
+     2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
+     This remote module can get a Remote Reference to the embedding table on the parameter server.
    """

-     def __init__(self, emb_rref, device):
+     def __init__(self, remote_emb_module, device):
        super(HybridModel, self).__init__()
-         self.emb_rref = emb_rref
+         self.remote_emb_module = remote_emb_module
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
-         emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
+         emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))


- def _retrieve_embedding_parameters(emb_rref):
-     return [RRef(p) for p in emb_rref.local_value().parameters()]
-
-
- def _run_trainer(emb_rref, rank):
+ def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
@@ -52,17 +45,18 @@ def _run_trainer(emb_rref, rank):
    """

    # Setup the model.
-     model = HybridModel(emb_rref, rank)
+     model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
-     model_parameter_rrefs = rpc.rpc_sync(
-         "ps", _retrieve_embedding_parameters, args=(emb_rref,)
-     )
+     model_parameter_rrefs = model.remote_emb_module.remote_parameters()

-     # model.parameters() only includes local parameters.
-     for param in model.parameters():
+     # model.fc.parameters() only includes local parameters.
+     # NOTE: Cannot call model.parameters() here,
+     # because this will call remote_emb_module.parameters(),
+     # which supports remote_parameters() but not parameters().
+     for param in model.fc.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Setup distributed optimizer
@@ -131,8 +125,7 @@ def run_worker(rank, world_size):
            rpc_backend_options=rpc_backend_options,
        )

-         # Build the embedding table on the ps.
-         emb_rref = rpc.remote(
+         remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
@@ -143,7 +136,9 @@ def run_worker(rank, world_size):
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
-             fut = rpc.rpc_async(trainer_name, _run_trainer, args=(emb_rref, rank))
+             fut = rpc.rpc_async(
+                 trainer_name, _run_trainer, args=(remote_emb_module, rank)
+             )
            futs.append(fut)

        # Wait for all training to finish.
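
For reference, here is a minimal, self-contained sketch of the RemoteModule + remote_parameters() pattern this change adopts: one trainer constructs an nn.EmbeddingBag on a parameter-server worker and optimizes its parameters together with a local linear layer through DistributedOptimizer inside a distributed autograd context. This sketch is not part of the commit; the worker names ("trainer", "ps"), the port, the tensor sizes, and the simplified single-trainer, CPU-only setup (no DistributedDataParallel, no CUDA) are illustrative assumptions.

# sketch.py -- simplified illustration of the RemoteModule pattern above
import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef


def _trainer():
    # The EmbeddingBag lives on the "ps" worker; forward() runs as a sync RPC.
    remote_emb = RemoteModule(
        "ps", torch.nn.EmbeddingBag, args=(100, 16), kwargs={"mode": "sum"}
    )
    fc = torch.nn.Linear(16, 8)  # local dense part (no DDP in this sketch)

    # RRefs to the remote parameters plus RRefs wrapping the local ones.
    param_rrefs = remote_emb.remote_parameters()
    param_rrefs.extend(RRef(p) for p in fc.parameters())
    opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)

    indices = torch.randint(0, 100, (32,))
    offsets = torch.arange(0, 32, 4)  # 8 bags of 4 indices each
    with dist_autograd.context() as context_id:
        out = fc(remote_emb.forward(indices, offsets))
        loss = out.sum()
        dist_autograd.backward(context_id, [loss])
        opt.step(context_id)


def _run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"  # assumed free port
    name = "trainer" if rank == 0 else "ps"
    rpc.init_rpc(name, rank=rank, world_size=world_size)
    if rank == 0:
        _trainer()
    rpc.shutdown()  # blocks until both workers are done


if __name__ == "__main__":
    mp.spawn(_run, args=(2,), nprocs=2, join=True)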