-
Notifications
You must be signed in to change notification settings - Fork 205
Open
Description
大佬好,
我用gpt2_example.py推理gpt2,生成的第1个token没问题,但把生成的token拼接到前面的序列后继续推理生成的结果就不对了,
比如我的输入是:input_ids = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long).to(test_device)
推理代码:
for _ in range(32):
res = tt_model(input_ids) # sequence_output, pooled_output
gen_id = torch.argmax(res[0])
input_ids = torch.cat([input_ids, gen_id.unsqueeze(0).unsqueeze(1)], dim=-1)
生成结果:tensor([[12166, 10699, 16752, 4454, 477, 477, 477, .....]], device='cuda:0')
其中第1个477正确,后面看起来还是用的第1次的输入。
请问这是怎么回事?
Metadata
Metadata
Assignees
Labels
No labels