Incorrect GPT-2 inference results #268

@cdj0311

Description

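Hi,
I am using gpt2_example.py to run GPT-2 inference. The first generated token is correct, but once I append that token to the input sequence and run inference again, the generated results are wrong.
For example, my input is: input_ids = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long).to(test_device)
Inference code:
for _ in range(32):
    res = tt_model(input_ids)  # sequence_output, pooled_output
    gen_id = torch.argmax(res[0])
    input_ids = torch.cat([input_ids, gen_id.unsqueeze(0).unsqueeze(1)], dim=-1)

Generated result: tensor([[12166, 10699, 16752, 4454, 477, 477, 477, .....]], device='cuda:0')
The first 477 is correct, but the later tokens look as if the model is still using the input from the first call.
What is going on here?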
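
For comparison, here is a minimal sketch of a greedy decoding loop in plain PyTorch. It is not a diagnosis of the behaviour above. It assumes the first output has shape [batch, seq_len, vocab_size] and holds logits, which may not match what tt_model actually returns (the comment in the loop above suggests sequence_output and pooled_output). toy_model, VOCAB_SIZE and greedy_generate are hypothetical names made up for this illustration. The one detail it is meant to show is that torch.argmax called without dim returns an index into the flattened tensor, whereas next-token selection normally takes the argmax over the vocabulary axis at the last position only.

```python
import torch

VOCAB_SIZE = 50  # hypothetical toy vocabulary size, for illustration only


def toy_model(input_ids: torch.Tensor) -> tuple:
    """Hypothetical stand-in for a language model.

    Returns a 1-tuple holding logits of shape [batch, seq_len, VOCAB_SIZE].
    At every position the highest logit is placed on (token + 1) % VOCAB_SIZE,
    so the expected continuation is easy to check by hand.
    """
    batch, seq_len = input_ids.shape
    logits = torch.zeros(batch, seq_len, VOCAB_SIZE)
    next_ids = (input_ids + 1) % VOCAB_SIZE
    logits.scatter_(2, next_ids.unsqueeze(-1), 1.0)
    return (logits,)


def greedy_generate(model, input_ids: torch.Tensor, steps: int) -> torch.Tensor:
    """Append `steps` greedily chosen tokens to `input_ids`."""
    for _ in range(steps):
        logits = model(input_ids)[0]  # assumed shape: [batch, seq_len, vocab]
        # Use only the last position and take the argmax over the vocab axis.
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [batch, 1]
        input_ids = torch.cat([input_ids, next_id], dim=-1)
    return input_ids


prompt = torch.tensor([[3, 7, 11]], dtype=torch.long)
generated = greedy_generate(toy_model, prompt, steps=4)
print(generated)
```

If the real model returns hidden states rather than logits, a language-model head would be needed before the argmax; the sketch leaves that out.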
