Commit 01539f9

Merge pull request #908 from SciPioneer/patch-3
Replace RRef with RemoteModule
2 parents 5b528f5 + fba20d5


distributed/rpc/ddp_rpc/main.py

Lines changed: 20 additions & 25 deletions
@@ -1,13 +1,12 @@
-import os
 import random
-from functools import wraps

 import torch
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.optim as optim
+from torch.distributed.nn import RemoteModule
 from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
 from torch.distributed.rpc import TensorPipeRpcBackendOptions
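The import swap above is the core of this change: the hand-written RRef plumbing is replaced by torch.distributed.nn.RemoteModule, while RRef stays imported because the local DDP parameters are still wrapped in RRefs for DistributedOptimizer. For context, here is a minimal sketch of how the TensorPipeRpcBackendOptions import is typically wired into RPC initialization; the worker name, rank, world size, and address below are illustrative placeholders, not values from this file:

# Sketch only: standard RPC bootstrap using the imports above.
# "trainer0", rank, world_size, and the TCP address are placeholders.
rpc_backend_options = TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = "tcp://localhost:29500"
rpc.init_rpc(
    "trainer0",
    rank=0,
    world_size=4,
    rpc_backend_options=rpc_backend_options,
)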
@@ -19,30 +18,24 @@

 class HybridModel(torch.nn.Module):
     r"""
-    The model consists of a sparse part and a dense part. The dense part is an
-    nn.Linear module that is replicated across all trainers using
-    DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
-    stored on the parameter server.
-    The model holds a Remote Reference to the embedding table on the parameter
-    server.
+    The model consists of a sparse part and a dense part.
+    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
+    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
+    This remote model can get a Remote Reference to the embedding table on the parameter server.
     """

-    def __init__(self, emb_rref, device):
+    def __init__(self, remote_emb_module, device):
         super(HybridModel, self).__init__()
-        self.emb_rref = emb_rref
+        self.remote_emb_module = remote_emb_module
         self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
         self.device = device

     def forward(self, indices, offsets):
-        emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
+        emb_lookup = self.remote_emb_module.forward(indices, offsets)
         return self.fc(emb_lookup.cuda(self.device))


-def _retrieve_embedding_parameters(emb_rref):
-    return [RRef(p) for p in emb_rref.local_value().parameters()]
-
-
-def _run_trainer(emb_rref, rank):
+def _run_trainer(remote_emb_module, rank):
     r"""
     Each trainer runs a forward pass which involves an embedding lookup on the
     parameter server and running nn.Linear locally. During the backward pass,
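The rewritten forward above calls the remote embedding lookup through the stub that RemoteModule generates, which issues a blocking RPC to the parameter server. RemoteModule also generates a non-blocking variant, forward_async, which returns a torch.futures.Future; the following sketch shows what the same method would look like with it (not part of this change):

# Sketch only: an async variant of HybridModel.forward using the
# forward_async stub that RemoteModule generates alongside forward.
def forward(self, indices, offsets):
    fut = self.remote_emb_module.forward_async(indices, offsets)
    emb_lookup = fut.wait()  # block only once the result is actually needed
    return self.fc(emb_lookup.cuda(self.device))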
@@ -52,17 +45,18 @@ def _run_trainer(emb_rref, rank):
     """

     # Setup the model.
-    model = HybridModel(emb_rref, rank)
+    model = HybridModel(remote_emb_module, rank)

     # Retrieve all model parameters as rrefs for DistributedOptimizer.

     # Retrieve parameters for embedding table.
-    model_parameter_rrefs = rpc.rpc_sync(
-        "ps", _retrieve_embedding_parameters, args=(emb_rref,)
-    )
+    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

-    # model.parameters() only includes local parameters.
-    for param in model.parameters():
+    # model.fc.parameters() only includes local parameters.
+    # NOTE: Cannot call model.parameters() here,
+    # because this will call remote_emb_module.parameters(),
+    # which supports remote_parameters() but not parameters().
+    for param in model.fc.parameters():
         model_parameter_rrefs.append(RRef(param))

     # Setup distributed optimizer
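After this hunk, model_parameter_rrefs holds RRefs from two places: the embedding table's parameters on the parameter server, obtained via remote_parameters(), and the local DDP-wrapped fc parameters wrapped in RRef by hand. The hunk stops at the optimizer comment; below is a sketch of the DistributedOptimizer pattern such a list typically feeds, with the learning rate, loss, and input tensors shown as placeholders rather than values from this file:

# Sketch only: consuming a list of parameter RRefs with DistributedOptimizer
# inside a distributed autograd context. Hyperparameters and tensors are placeholders.
opt = DistributedOptimizer(
    optim.SGD,
    model_parameter_rrefs,
    lr=0.05,
)

criterion = torch.nn.CrossEntropyLoss()
indices = torch.LongTensor([1, 2, 3, 4])   # placeholder embedding indices
offsets = torch.LongTensor([0, 2])         # placeholder bag offsets
target = torch.LongTensor([0, 1])          # placeholder labels

with dist_autograd.context() as context_id:
    output = model(indices, offsets)
    loss = criterion(output, target)
    dist_autograd.backward(context_id, [loss])
    opt.step(context_id)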
@@ -131,8 +125,7 @@ def run_worker(rank, world_size):
            rpc_backend_options=rpc_backend_options,
        )

-        # Build the embedding table on the ps.
-        emb_rref = rpc.remote(
+        remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
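Side by side, the construction change reads roughly as follows; the embedding sizes and the indices/offsets tensors are illustrative, and the real call above continues past the truncated hunk:

# Sketch only: old RRef-based pattern vs. the RemoteModule wrapper.
indices = torch.LongTensor([1, 2, 3, 4])   # placeholder lookup indices
offsets = torch.LongTensor([0, 2])         # placeholder bag offsets

# Before: hold a bare RRef and go through rpc_sync() for every call.
emb_rref = rpc.remote("ps", torch.nn.EmbeddingBag, args=(100, 16))
emb = emb_rref.rpc_sync().forward(indices, offsets)

# After: RemoteModule wraps the same remote module behind a module-like API.
remote_emb_module = RemoteModule("ps", torch.nn.EmbeddingBag, args=(100, 16))
emb = remote_emb_module.forward(indices, offsets)
param_rrefs = remote_emb_module.remote_parameters()  # ready for DistributedOptimizer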
@@ -143,7 +136,9 @@ def run_worker(rank, world_size):
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
-            fut = rpc.rpc_async(trainer_name, _run_trainer, args=(emb_rref, rank))
+            fut = rpc.rpc_async(
+                trainer_name, _run_trainer, args=(remote_emb_module, rank)
+            )
            futs.append(fut)

        # Wait for all training to finish.
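The async dispatch above returns one future per trainer, and the master then blocks until both finish. A minimal sketch of the wait the trailing comment refers to (the exact loop lies outside this hunk):

# Sketch only: block until every trainer's _run_trainer call has returned.
for fut in futs:
    fut.wait()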
