
Commit d962cd0

Fixed error: replace -100s in predictions with the pad token (intel#353)

Signed-off-by: Cheng, Penghui <[email protected]>
Parent: 0410a33

2 files changed: +4 −2 lines

examples/huggingface/pytorch/summarization/quantization/run_summarization.py

Lines changed: 2 additions & 0 deletions

@@ -651,6 +651,8 @@ def compute_metrics(eval_preds):
         preds, labels = eval_preds
         if isinstance(preds, tuple):
             preds = preds[0]
+        # Replace -100s used for padding as we can't decode them
+        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
         if data_args.ignore_pad_token_for_loss:
             # Replace -100 in the labels as we can't decode them.

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 2 additions & 2 deletions

@@ -69,7 +69,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         if load_in_8bit or load_in_4bit or quantization_config is not None:
             from intel_extension_for_transformers.llm.quantization.utils import convert_to_quantized_model
             torch_dtype = kwargs.pop("torch_dtype", torch.float32)
-
+
         if load_in_4bit:
             if quantization_config is None:
                 quantization_config = WeightOnlyQuantConfig(compute_dtype=torch_dtype, weight_dtype="nf4")
@@ -204,6 +204,7 @@ def default_calib_func(model):
         )
         return model
 
+
 class AutoModelForCausalLM(_BaseQBitsAutoModelClass):
     ORIG_MODEL = transformers.AutoModelForCausalLM
 
@@ -214,4 +215,3 @@ class AutoModel(_BaseQBitsAutoModelClass):
 
 class AutoModelForSeq2SeqLM(_BaseQBitsAutoModelClass):
     ORIG_MODEL = transformers.AutoModelForSeq2SeqLM
-
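The hunks in this file are whitespace-only, but they sit in the weight-only quantization path of from_pretrained. A hypothetical usage sketch, assuming the class is imported from the module path shown above, that load_in_4bit is accepted as a keyword argument (as the surrounding code suggests), and a placeholder checkpoint name; per the context lines, omitting quantization_config with load_in_4bit=True falls back to WeightOnlyQuantConfig(compute_dtype=torch_dtype, weight_dtype="nf4"):

import torch
from intel_extension_for_transformers.transformers.modeling.modeling_auto import AutoModelForCausalLM

# Placeholder checkpoint identifier, not taken from this commit.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    load_in_4bit=True,          # assumed keyword; with no quantization_config,
                                # the code above defaults to nf4 weight-only quantization
    torch_dtype=torch.float32,  # popped inside from_pretrained, as shown in the first hunk
)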
