Skip to content

Commit 9602532

Browse files
authored
fix paddle.sum (#909)
1 parent 4075ea6 commit 9602532

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def forward(self, src_word):
125125

126126
mem_seq_lens = paddle.sum(paddle.cast(
127127
src_word != self.bos_id, dtype="int32"),
128+
dtype="int32",
128129
axis=1)
129130
ids = self.decoding(enc_output, mem_seq_lens)
130131

0 commit comments

Comments
 (0)