Commit 2de8830

[NeuralChat] enable lm_eval during training. (intel#1363)
* enable lm_eval during training.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 90646c1 commit 2de8830

File tree

5 files changed: +143 -5 lines changed


intel_extension_for_transformers/llm/evaluation/lm_eval/evaluator.py

Lines changed: 15 additions & 1 deletion
@@ -34,7 +34,7 @@
 MODEL_REGISTRY = {
     "hf-causal": huggingface.AutoCausalLM,
     "hf-seq2seq": huggingface.AutoSeq2SeqLM,
-
+    "simple-hf-causal": huggingface.HFModelAdapter,
 }

 def itrex_bootstrap_stderr(f, xs, iters):

@@ -69,6 +69,8 @@ def evaluate(model,
              output_base_path=None,
              seed=1234,
              user_model=None,
+             user_tokenizer=None,
+             warmup=False,
              model_format='torch'
              ):
    """Instantiate and evaluate a model on a list of tasks.

@@ -125,6 +127,18 @@ def evaluate(model,
     }
     if user_model:
         kwargs["init_empty_weights"] = True
+
+    if device == "hpu":
+        # if hpu, set user_model
+        kwargs["user_model"] = user_model
+        if model == "hf-causal":
+            model = "simple-hf-causal"
+    if model == "simple-hf-causal":
+        kwargs["warmup"] = warmup
+
+    if user_tokenizer:
+        kwargs["user_tokenizer"] = user_tokenizer
+
     lm = get_model(model).create_from_arg_string(
         model_args, kwargs
     )
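
A minimal sketch of how the extended entry point might be called with an already-loaded model, assuming the same evaluate import that finetuning.py uses below; the checkpoint name, batch size, and task list are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate

# Placeholder checkpoint; during training this would be the model being fine-tuned.
model_name = "facebook/opt-125m"
user_model = AutoModelForCausalLM.from_pretrained(model_name)
user_tokenizer = AutoTokenizer.from_pretrained(model_name)

results = evaluate(
    model="simple-hf-causal",          # routes to the new HFModelAdapter
    model_args=f"pretrained={model_name},dtype=float16",
    tasks=["truthfulqa_mc", "lambada_openai"],
    device="hpu",                      # user_model is only forwarded on hpu in this patch
    batch_size=8,
    user_model=user_model,             # reuse the in-memory weights instead of reloading
    user_tokenizer=user_tokenizer,     # reuse the in-memory tokenizer
    warmup=True,                       # pre-run the Gaudi shape buckets once
)
print(results["results"])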

intel_extension_for_transformers/llm/evaluation/lm_eval/models/huggingface.py

Lines changed: 78 additions & 1 deletion
@@ -29,7 +29,7 @@
 from transformers import BatchEncoding

 from lm_eval import utils
-from lm_eval.base import BaseLM
+from lm_eval.base import BaseLM, CacheHook
 import re

 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]

@@ -1078,3 +1078,80 @@ def stop_sequences_criteria(
         ],
     ]
 )
+
+
+class HFModelAdapter(HuggingFaceAutoLM):
+    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+    AUTO_PEFT_CLASS = peft.PeftModel
+
+    def __init__(self, *args, user_model=None, user_tokenizer=None, **kwargs):
+        self.cache_hook = CacheHook(None)
+        self.model = user_model
+        if user_tokenizer is None:
+            self.tokenizer = self._create_auto_tokenizer(
+                pretrained=kwargs["pretrained"],
+                revision="main",
+                subfolder=None,)
+        else:
+            self.tokenizer = user_tokenizer
+        self._batch_size = kwargs["batch_size"]
+        self._add_special_tokens = None
+        self.model_format = "torch"
+        self.buckets = [16, 32, 64, 128, 189, 284]
+        self._device = kwargs["device"]
+        if self._device == "hpu":
+            from optimum.habana.checkpoint_utils import model_is_optimized  # pylint: disable=E0611, E0401
+            self.static_shapes = model_is_optimized(self.model.config)
+        else:
+            self.static_shapes = False
+        if kwargs["warmup"]:
+            print("lm-eval warmup for Gaudi.")
+            self.warm_up()
+
+    def warm_up(self):
+        for bucket_size in reversed(self.buckets):
+            inps = torch.ones((self._batch_size, bucket_size), dtype=torch.int64)
+            self._model_call(inps)
+        pass
+
+    @property
+    def eot_token_id(self):
+        return self.model.config.eos_token_id
+
+    @property
+    def max_length(self):
+        return self.buckets[-1]
+
+    @property
+    def max_gen_toks(self):
+        raise NotImplementedError()
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @property
+    def device(self):
+        # We need to do padding ourselves, otherwise we'll end up with recompilations
+        # Returning 'cpu' to keep tensors on CPU in lm_eval code
+        return "cpu"
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise NotImplementedError()
+
+    def find_bucket(self, length):
+        return [b for b in self.buckets if b >= length][0]
+
+    def _model_call(self, inps):
+        bs, seq_length = inps.shape
+        padding_length = 0
+        if self.static_shapes:
+            bucket_length = self.find_bucket(seq_length)
+            padding_length = bucket_length - seq_length
+            inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
+        logits = self.model(inps.to(self._device))["logits"].cpu()
+
+        if self.static_shapes and padding_length > 0:
+            logits = logits[:, :-padding_length, :]
+        logits = logits.to(torch.float32)
+        return logits
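
The bucket table above exists to keep input shapes static on Gaudi: each batch is right-padded up to the smallest bucket that fits it, so the graph is compiled once per bucket instead of once per sequence length, and the padded positions are sliced back off the logits afterwards. A standalone sketch of that padding step (the bucket sizes match the adapter; the pad id is illustrative):

import torch
import torch.nn.functional as F

buckets = [16, 32, 64, 128, 189, 284]   # same table as HFModelAdapter.buckets
pad_token_id = 0                        # illustrative; the adapter uses model.config.pad_token_id

def find_bucket(length):
    # Smallest bucket that can hold the sequence.
    return [b for b in buckets if b >= length][0]

inps = torch.ones((4, 50), dtype=torch.int64)    # batch of 4, sequence length 50
bucket_length = find_bucket(inps.shape[1])       # -> 64
padding_length = bucket_length - inps.shape[1]   # -> 14
padded = F.pad(inps, (0, padding_length), value=pad_token_id)
print(padded.shape)                              # torch.Size([4, 64])
# After the forward pass, logits[:, :-padding_length, :] drops the padded positions again.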

intel_extension_for_transformers/llm/finetuning/eval_utils.py

Lines changed: 32 additions & 2 deletions
@@ -25,6 +25,7 @@
 from transformers.trainer_utils import speed_metrics
 from transformers.debug_utils import DebugOption
 import math
+from transformers import TrainerCallback

 @torch.no_grad()
 def compute_rouge_metric(model, tokenizer, eval_dataset, training_args, gen_kwargs):

@@ -155,10 +156,39 @@ def evaluate_plus_ppl(

     output.metrics[f"{metric_key_prefix}_ppl"] = math.exp(output.metrics[f"{metric_key_prefix}_loss"])

-    self.log(output.metrics)
-
     self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

+    self.log(output.metrics)
+
     self._memory_tracker.stop_and_update_metrics(output.metrics)

     return output.metrics
+
+
+class LMEvalCallback(TrainerCallback):
+    def __init__(self, lm_eval_func, device=None):
+        self.lm_eval = lm_eval_func
+        self.device = device
+        self.warmup = True
+
+    def on_evaluate(self, args, state, control, **kwargs):
+        if not state.is_local_process_zero:
+            return
+        if self.device == "hpu":
+            results = self.lm_eval(user_model=kwargs["model"],
+                                   user_tokenizer=kwargs["tokenizer"],
+                                   warmup=self.warmup)
+            self.warmup = False
+        else:
+            results = self.lm_eval(model="simple-hf-causal",
+                                   user_model=kwargs["model"],
+                                   user_tokenizer=kwargs["tokenizer"],
+                                   warmup=False)
+        task_metrics = {}
+        for task_name in results["results"]:
+            for metric in results["results"][task_name]:
+                if "stderr" in metric:
+                    continue
+                metric_name = task_name + "_" + metric
+                task_metrics[metric_name] = results["results"][task_name][metric]
+        kwargs["metrics"].update(task_metrics)
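
A sketch of attaching the new callback to a Trainer outside of finetuning.py. The import paths follow the file locations in this commit, the checkpoint name, batch size, and limit are placeholders, and lm_eval_func is bound with functools.partial around evaluate, mirroring what finetuning.py does below:

from functools import partial
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
from intel_extension_for_transformers.llm.finetuning.eval_utils import LMEvalCallback

# Bind everything except the per-call arguments the callback supplies
# (user_model, user_tokenizer, warmup).
lm_eval_func = partial(evaluate,
                       model="hf-causal",
                       model_args="pretrained=facebook/opt-125m,tokenizer=facebook/opt-125m,dtype=float16",
                       device="hpu",
                       batch_size=8,
                       tasks=["truthfulqa_mc", "lambada_openai"],
                       limit=100)

lm_eval_callback = LMEvalCallback(lm_eval_func, device="hpu")

# Pass it to any Trainer; on every evaluation the harness metrics are merged
# into the Trainer's eval metrics on the main process only, e.g.:
# trainer = Trainer(model=model, args=training_args, ..., callbacks=[lm_eval_callback])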

intel_extension_for_transformers/llm/finetuning/finetuning.py

Lines changed: 17 additions & 0 deletions
@@ -565,6 +565,21 @@ def concatenate_data(dataset, max_seq_length):
     if model_dtype == torch.bfloat16:
         model = model.to(model_dtype)

+    lm_eval_callback = None
+    if training_args.do_eval and finetune_args.do_lm_eval:
+        from .eval_utils import LMEvalCallback
+        from functools import partial
+        from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
+        lm_eval_func = partial(evaluate,
+                               model="hf-causal",
+                               model_args='pretrained='+model_args.model_name_or_path+\
+                                   ',tokenizer='+model_args.model_name_or_path+',dtype=float16',
+                               device=finetune_args.device,
+                               batch_size=training_args.per_device_eval_batch_size,
+                               tasks=finetune_args.lm_eval_tasks,
+                               limit=data_args.max_eval_samples)
+        lm_eval_callback = LMEvalCallback(lm_eval_func, device=finetune_args.device)
+
     if finetune_args.device != 'hpu':
         # Initialize our Trainer
         trainer = Trainer(

@@ -574,6 +589,7 @@ def concatenate_data(dataset, max_seq_length):
             eval_dataset=eval_dataset if training_args.do_eval else None,
             tokenizer=tokenizer,
             data_collator=data_collator,
+            callbacks=[lm_eval_callback] if lm_eval_callback is not None else None
         )
     else:
         from optimum.habana import GaudiConfig, GaudiTrainer  # pylint: disable=E0611 E0401

@@ -590,6 +606,7 @@ def concatenate_data(dataset, max_seq_length):
             eval_dataset=eval_dataset if training_args.do_eval else None,
             tokenizer=tokenizer,
             data_collator=data_collator,
+            callbacks=[lm_eval_callback] if lm_eval_callback is not None else None
         )

     trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
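
The callback is only built when both training_args.do_eval and finetune_args.do_lm_eval are set. A sketch of the corresponding argument setup, assuming FinetuningArguments is imported from intel_extension_for_transformers.neural_chat.config (the file touched below) and that its remaining fields keep their defaults:

from transformers import TrainingArguments
from intel_extension_for_transformers.neural_chat.config import FinetuningArguments

training_args = TrainingArguments(
    output_dir="./finetuned_model",       # placeholder
    do_eval=True,
    evaluation_strategy="epoch",
    per_device_eval_batch_size=8,
)

finetune_args = FinetuningArguments(
    do_lm_eval=True,                      # turns the lm-eval callback on
    lm_eval_tasks=["truthfulqa_mc", "lambada_openai"],
    device="hpu",                         # the hpu path reuses the live model and warms up once
)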

intel_extension_for_transformers/neural_chat/config.py

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ class FinetuningArguments:
         metadata={"help": "whether to run the LM evaluation with EleutherAI/lm-evaluation-harness"},
     )
     lm_eval_tasks: Optional[List[str]] = field(
-        default_factory=lambda: ["truthfulqa_mc"],
+        default_factory=lambda: ["truthfulqa_mc", "lambada_openai"],
         metadata={"help": "tasks list for accuracy validation with EleutherAI/lm-evaluation-harness."},
     )
     qlora: bool = field(
