Skip to content

Commit 6e83811

Browse files
authored
typo
1 parent 8eeb15b commit 6e83811

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def project(self, x):
223223
# (batch, seq_len, vocab_size)
224224
return self.projection_layer(x)
225225

226-
def forward(self, encoder_output, encoder_mask, decoder_input, decoder_mask):
226+
def forward(self, encoder_input, encoder_mask, decoder_input, decoder_mask):
227227
encoder_output = self.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
228228
decoder_output = self.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
229229
return self.project(decoder_output) # (B, seq_len, vocab_size))

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ def train_model(config: ModelConfig):
265265
# encoder_output = model.module.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
266266
# decoder_output = model.module.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
267267
# proj_output = model.module.project(decoder_output) # (B, seq_len, vocab_size)
268-
proj_output = model(encoder_output, encoder_mask, decoder_input, decoder_mask)
269-
268+
proj_output = model(encoder_input, encoder_mask, decoder_input, decoder_mask)
269+
270270
# Compare the output with the label
271271
label = batch['label'].to(device) # (B, seq_len)
272272

0 commit comments

Comments
 (0)