#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
18
+ import evaluate
19
+ import nltk
20
+ import numpy as np
21
+ import torch
22
+ from torch .utils .data import DataLoader
23
+
24
@torch.no_grad()
def compute_rouge_metric(model, tokenizer, eval_dataset, training_args, gen_kwargs):
    """Evaluate ``model`` on ``eval_dataset`` and return ROUGE scores.

    For each example, generates a continuation with ``model.generate`` and
    scores it against the reference labels using the ``evaluate`` ROUGE
    implementation (with stemming).

    Args:
        model: causal LM with ``generate``; switched to eval mode here, and
            its config's bos/eos/pad token ids are overwritten from the
            tokenizer.
        tokenizer: tokenizer providing bos/eos/pad token ids and
            ``batch_decode``.
        eval_dataset: iterable of dicts with ``decoder_input_ids``,
            ``decoder_labels`` and ``decoder_attention_mask`` entries
            (lists of token ids).
        training_args: unused; kept for interface compatibility with callers.
        gen_kwargs: extra keyword arguments forwarded to ``model.generate``.

    Returns:
        dict mapping ROUGE metric names to percentages rounded to 4 decimals.
    """
    model.eval()
    # Make sure generation uses the tokenizer's special-token ids.
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    # Metric
    metric = evaluate.load("rouge")

    def collate_fn(batch):
        # Pad each field to the batch maximum: inputs with EOS, labels with
        # -100 (replaced before decoding below), attention mask with 0.
        input_ids = [torch.tensor(ins["decoder_input_ids"]) for ins in batch]
        labels = [torch.tensor(ins["decoder_labels"]) for ins in batch]
        attention_mask = [torch.tensor(ins["decoder_attention_mask"]) for ins in batch]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=tokenizer.eos_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100)
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            attention_mask, batch_first=True, padding_value=0)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

    # TODO: support batch_size > 1 (prompts of unequal length would need
    # left padding for decoder-only generation; right padding via
    # pad_sequence above is only safe at batch size 1).
    eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=1)

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        # rougeLSum expects a newline after each sentence.
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels

    for batch in eval_dataloader:
        preds = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            **gen_kwargs,
        )
        labels = batch["labels"].cpu().numpy()
        preds = preds.cpu().numpy()

        # Replace -100s used for padding as we can't decode them.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id).tolist()
        # Keep only the generated continuation, dropping the prompt tokens.
        preds = [pred[batch["input_ids"].shape[1]:] for pred in preds]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

        labels = np.where(labels != -100, labels, tokenizer.pad_token_id).tolist()
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing before scoring.
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        metric.add_batch(
            predictions=decoded_preds,
            references=decoded_labels,
        )

    result = metric.compute(use_stemmer=True)
    # Report as percentages rounded to 4 decimal places.
    return {k: round(v * 100, 4) for k, v in result.items()}