Skip to content

Commit 8eeb15b

Browse files
authored
wrap
1 parent c4e6458 commit 8eeb15b

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torc
222222
def project(self, x):
    """Project decoder features onto the vocabulary.

    Maps (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
    via the model's projection layer.
    """
    return self.projection_layer(x)
225+
226+
def forward(self, encoder_input, encoder_mask, decoder_input, decoder_mask):
    """Full forward pass: encode the source, decode the target, project to logits.

    Bug fix: the first parameter was misnamed `encoder_output` while the body
    referenced `encoder_input`, so any call raised NameError. Renamed the
    parameter to `encoder_input`; positional callers (e.g. the train loop's
    `model(encoder_input, encoder_mask, decoder_input, decoder_mask)`) are
    unaffected.

    Shapes (per the surrounding comments — B = batch, assumes mask shapes
    match what encode/decode expect; confirm against the class definition):
      encoder_input  -> encode  : (B, seq_len, d_model)
      decoder output -> project : (B, seq_len, vocab_size)
    """
    encoder_output = self.encode(encoder_input, encoder_mask)  # (B, seq_len, d_model)
    decoder_output = self.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (B, seq_len, d_model)
    return self.project(decoder_output)  # (B, seq_len, vocab_size)
225230

226231
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
227232
# Create the embedding layers

train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,12 @@ def train_model(config: ModelConfig):
261261
encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
262262
decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)
263263

264-
# Run the tensors through the encoder, decoder and the projection layer
265-
encoder_output = model.module.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
266-
decoder_output = model.module.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
267-
proj_output = model.module.project(decoder_output) # (B, seq_len, vocab_size)
268-
264+
# # Run the tensors through the encoder, decoder and the projection layer
265+
# encoder_output = model.module.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
266+
# decoder_output = model.module.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
267+
# proj_output = model.module.project(decoder_output) # (B, seq_len, vocab_size)
268+
proj_output = model(encoder_input, encoder_mask, decoder_input, decoder_mask)
269+
269270
# Compare the output with the label
270271
label = batch['label'].to(device) # (B, seq_len)
271272

0 commit comments

Comments
 (0)