
Commit c3ad585

Added QLoRA support in NeuralChat finetuning and refined NeuralChat optimization API. (intel#174)
1 parent 3b8878e commit c3ad585
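
Since the commit message centers on QLoRA, here is a minimal sketch of the technique for reference, written against the Hugging Face transformers/peft/bitsandbytes stack rather than NeuralChat's own finetuning API; the base checkpoint, target modules, and LoRA hyperparameters below are illustrative assumptions, not values taken from this commit.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

base_model = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint, not from this commit

# Load the base weights in 4-bit NF4 precision (the "Q" in QLoRA).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=bnb_config)
model = prepare_model_for_kbit_training(model)

# Attach small trainable low-rank adapters; the quantized base stays frozen.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed attention projection names
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable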

File tree

12 files changed: +572 additions, -281 deletions

.github/workflows/unit-test-engine.yml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ jobs:
       include:
         - test_branch: ${{ github.ref }}
           test_name: "PR-test"
-        - test_branch: "ut_parallal"
+        - test_branch: "main"
           test_name: "baseline"
     steps:
       - name: Docker Clean Up

.github/workflows/unit-test-optimize.yml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ jobs:
       include:
         - test_branch: ${{ github.ref }}
           test_name: "PR-test"
-        - test_branch: "ut_parallal"
+        - test_branch: "main"
           test_name: "baseline"
     steps:
       - name: Docker Clean Up

intel_extension_for_transformers/llm/finetuning/data_utils.py

Lines changed: 228 additions & 183 deletions
Large diffs are not rendered by default.
Lines changed: 94 additions & 0 deletions (new file)
@@ -0,0 +1,94 @@

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import evaluate
import nltk
import numpy as np
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def compute_rouge_metric(model, tokenizer, eval_dataset, training_args, gen_kwargs):
    model.eval()
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    # Metric
    metric = evaluate.load("rouge")

    def collate_fn(batch):
        # Pad each field of the batch to a common length.
        input_ids = [torch.tensor(ins["decoder_input_ids"]) for ins in batch]
        labels = [torch.tensor(ins["decoder_labels"]) for ins in batch]
        attention_mask = [torch.tensor(ins["decoder_attention_mask"]) for ins in batch]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=tokenizer.eos_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

    # TODO: support batch_size > 1
    eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=1)

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects a newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    for step, batch in enumerate(eval_dataloader):
        preds = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            **gen_kwargs,
        )
        labels = batch["labels"]
        labels = labels.cpu().numpy()
        preds = preds.cpu().numpy()

        # Replace -100s used for padding as we can't decode them
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id).tolist()
        # Keep only the generated continuation (drop the prompt tokens)
        preds = [pred[batch["input_ids"].shape[1]:] for pred in preds]

        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

        labels = np.where(labels != -100, labels, tokenizer.pad_token_id).tolist()
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        metric.add_batch(
            predictions=decoded_preds,
            references=decoded_labels,
        )

    result = metric.compute(use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    return result
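
A hypothetical call site for compute_rouge_metric follows; the gpt2 checkpoint, the toy one-example dataset, and the generation settings are illustrative assumptions, with column names matching what collate_fn above reads (evaluate's rouge metric additionally requires the rouge_score package).

import nltk
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

nltk.download("punkt")  # sentence tokenizer used by postprocess_text

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# A minimal eval_dataset with the columns collate_fn expects.
prompt = tokenizer("Summarize: The quick brown fox jumps over the lazy dog.")
target = tokenizer("A fox jumps over a dog.")
eval_dataset = Dataset.from_dict({
    "decoder_input_ids": [prompt["input_ids"]],
    "decoder_labels": [target["input_ids"]],
    "decoder_attention_mask": [prompt["attention_mask"]],
})

gen_kwargs = {"max_new_tokens": 32, "num_beams": 1}  # illustrative settings
# training_args is accepted but unused by the current implementation.
result = compute_rouge_metric(model, tokenizer, eval_dataset,
                              training_args=None, gen_kwargs=gen_kwargs)
print(result)  # e.g. {'rouge1': ..., 'rouge2': ..., 'rougeL': ..., 'rougeLsum': ...}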
