Skip to content

Commit 4075ea6

Browse files
authored
Transformer input id supports int32 & mv beam search v2 (#905)
* support int32
* mv beam search v2
1 parent 2ae7a87 commit 4075ea6

File tree

7 files changed

+240
-220
lines changed

7 files changed

+240
-220
lines changed

examples/machine_translation/transformer/configs/transformer.base.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ random_seed: None
1818
output_file: "predict.txt"
1919
# The <bos>, <eos> and <unk> tokens in the dictionary.
2020
special_token: ["<s>", "<e>", "<unk>"]
21+
# The data type of input ids.
22+
input_dtype: "int64"
2123

2224
# Device to use.
2325
device: "gpu"

examples/machine_translation/transformer/configs/transformer.big.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ random_seed: None
1818
output_file: "predict.txt"
1919
# The <bos>, <eos> and <unk> tokens in the dictionary.
2020
special_token: ["<s>", "<e>", "<unk>"]
21+
# The data type of input ids.
22+
input_dtype: "int64"
2123

2224
# Device to use.
2325
device: "gpu"

examples/machine_translation/transformer/reader.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def convert_samples(sample):
9595
bos_idx=args.bos_idx,
9696
eos_idx=args.eos_idx,
9797
pad_idx=args.bos_idx,
98-
pad_seq=args.pad_seq),
98+
pad_seq=args.pad_seq,
99+
dtype=args.input_dtype),
99100
num_workers=0)
100101
data_loaders[i] = (data_loader)
101102
return data_loaders
@@ -142,7 +143,8 @@ def convert_samples(sample):
142143
bos_idx=args.bos_idx,
143144
eos_idx=args.eos_idx,
144145
pad_idx=args.bos_idx,
145-
pad_seq=args.pad_seq),
146+
pad_seq=args.pad_seq,
147+
dtype=args.input_dtype),
146148
num_workers=0,
147149
return_list=True)
148150
return data_loader, trg_vocab.to_tokens
@@ -163,11 +165,16 @@ def adapt_vocab_size(args):
163165
args.trg_vocab_size = padding_vocab(len(trg_vocab))
164166

165167

166-
def prepare_train_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
168+
def prepare_train_input(insts,
169+
bos_idx,
170+
eos_idx,
171+
pad_idx,
172+
pad_seq=1,
173+
dtype="int64"):
167174
"""
168175
Put all padded data needed by training into a list.
169176
"""
170-
word_pad = Pad(pad_idx, dtype="int64")
177+
word_pad = Pad(pad_idx, dtype=dtype)
171178
src_max_len = (
172179
max([len(inst[0]) for inst in insts]) + pad_seq) // pad_seq * pad_seq
173180
trg_max_len = (
@@ -190,11 +197,16 @@ def prepare_train_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
190197
return data_inputs
191198

192199

193-
def prepare_infer_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
200+
def prepare_infer_input(insts,
201+
bos_idx,
202+
eos_idx,
203+
pad_idx,
204+
pad_seq=1,
205+
dtype="int64"):
194206
"""
195207
Put all padded data needed by beam search decoder into a list.
196208
"""
197-
word_pad = Pad(pad_idx, dtype="int64")
209+
word_pad = Pad(pad_idx, dtype=dtype)
198210
src_max_len = (
199211
max([len(inst[0]) for inst in insts]) + pad_seq) // pad_seq * pad_seq
200212
src_word = word_pad([

examples/machine_translation/transformer/static/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def do_predict(args):
8585
startup_program = paddle.static.Program()
8686
with paddle.static.program_guard(test_program, startup_program):
8787
src_word = paddle.static.data(
88-
name="src_word", shape=[None, None], dtype="int64")
88+
name="src_word", shape=[None, None], dtype=args.input_dtype)
8989

9090
# Define model
9191
transformer = InferTransformerModel(

examples/machine_translation/transformer/static/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ def do_train(args):
9292
startup_program = paddle.static.Program()
9393
with paddle.static.program_guard(train_program, startup_program):
9494
src_word = paddle.static.data(
95-
name="src_word", shape=[None, None], dtype="int64")
95+
name="src_word", shape=[None, None], dtype=args.input_dtype)
9696
trg_word = paddle.static.data(
97-
name="trg_word", shape=[None, None], dtype="int64")
97+
name="trg_word", shape=[None, None], dtype=args.input_dtype)
9898
lbl_word = paddle.static.data(
99-
name="lbl_word", shape=[None, None, 1], dtype="int64")
99+
name="lbl_word", shape=[None, None, 1], dtype=args.input_dtype)
100100

101101
# Define model
102102
transformer = TransformerModel(

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def forward(self, src_word):
108108
src_word == self.bos_id,
109109
dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
110110
src_pos = paddle.cast(
111-
src_word != self.bos_id, dtype="int64") * paddle.arange(
111+
src_word != self.bos_id, dtype=src_word.dtype) * paddle.arange(
112112
start=0, end=src_max_len)
113113

114114
# Run encoder

0 commit comments

Comments (0)