Skip to content

Commit e9517ed

Browse files
authored
Create utils.py
1 parent f95460d commit e9517ed

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
from transformers import StoppingCriteria
3+
4+
5+
class StopWordsCriteria(StoppingCriteria):
6+
7+
def __init__(self, stop_indices: list):
8+
self.stop_indices = stop_indices
9+
10+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
11+
# do not support batch inference
12+
for i in range(len(self.stop_indices)):
13+
if self.stop_indices[-1-i] != input_ids[0][-1-i]:
14+
return False
15+
return True

0 commit comments

Comments
 (0)