
Commit a67438b

Authored by pritamdamania87 and brianjo
Fix bugs in pipeline tutorial with respect to batch size. (pytorch#1459)
* Fix bugs in pipeline tutorial with respect to batch size.

  Summary: As described in pytorch/pytorch#55036, certain modules were not handling batch size correctly.

* Rebase

Co-authored-by: pritam <[email protected]>
Co-authored-by: Brian Johnson <[email protected]>
1 parent f9b5840 commit a67438b
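For context on the batch handling (a sketch for illustration, not part of the commit): `torch.distributed.pipeline.sync.Pipe` splits each input along the first dimension into `chunks` micro-batches, so tensors fed into the pipeline need to be batch-first, even though the tutorial's embedding and positional encoding operate on `(S, N)` data. The sizes below are made up for the example.

```python
import torch

# Illustrative sizes only; the tutorial's bptt/batch_size/chunks may differ.
bptt, batch_size, chunks = 35, 20, 8

data = torch.randint(0, 100, (bptt, batch_size))     # (S, N), as get_batch produced before the fix
pipeline_input = data.t()                            # (N, S): batch-first, which Pipe expects
micro_batches = pipeline_input.chunk(chunks, dim=0)  # Pipe splits along dim 0 into micro-batches
print([tuple(mb.shape) for mb in micro_batches])     # each micro-batch keeps the full sequence length
```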

File tree

1 file changed: 8 additions, 14 deletions


intermediate_source/pipeline_tutorial.py

Lines changed: 8 additions & 14 deletions
@@ -56,7 +56,6 @@
 class Encoder(nn.Module):
     def __init__(self, ntoken, ninp, dropout=0.5):
         super(Encoder, self).__init__()
-        self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
         self.encoder = nn.Embedding(ntoken, ninp)
         self.ninp = ninp
@@ -66,17 +65,9 @@ def init_weights(self):
         initrange = 0.1
         self.encoder.weight.data.uniform_(-initrange, initrange)
 
-    def _generate_square_subsequent_mask(self, sz):
-        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
-        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-        return mask
-
     def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
-            device = src.device
-            mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
-            self.src_mask = mask
-
+        # Need (S, N) format for encoder.
+        src = src.t()
         src = self.encoder(src) * math.sqrt(self.ninp)
         return self.pos_encoder(src)
 
@@ -92,7 +83,8 @@ def init_weights(self):
         self.decoder.weight.data.uniform_(-initrange, initrange)
 
     def forward(self, inp):
-        return self.decoder(inp)
+        # Need batch dimension first for output of pipeline.
+        return self.decoder(inp).permute(1, 0, 2)
 
 
 ######################################################################
@@ -221,7 +213,8 @@ def get_batch(source, i):
     seq_len = min(bptt, len(source) - 1 - i)
     data = source[i:i+seq_len]
     target = source[i+1:i+1+seq_len].view(-1)
-    return data, target
+    # Need batch dimension first for pipeline parallelism.
+    return data.t(), target
 
 ######################################################################
 # Model scale and Pipe initialization
@@ -297,7 +290,8 @@ def get_batch(source, i):
 from torch.distributed.pipeline.sync import Pipe
 
 # Build the pipeline.
-model = Pipe(torch.nn.Sequential(*module_list), chunks = 8)
+chunks = 8
+model = Pipe(torch.nn.Sequential(*module_list), chunks = chunks)
 
 
 def get_total_params(module: torch.nn.Module):
