Skip to content

Commit a19467d

Browse files
authored
[NeuralChat] Add bm25 into enabled retrievers and add Uts (intel#1313)
Add bm25 into enabled retrievers and add Uts Signed-off-by: XuhuiRen <[email protected]>
1 parent f432a7a commit a19467d

File tree

4 files changed

+63
-4
lines changed

4 files changed

+63
-4
lines changed

intel_extension_for_transformers/langchain/retrievers/bge_reranker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626
from FlagEmbedding import FlagReranker
2727

2828
class BgeReranker(BaseDocumentCompressor):
29-
model_name:str = 'bge_reranker_model_path'
3029
top_n: int = 3 # Number of documents to return.
31-
model:FlagReranker = FlagReranker(model_name)
30+
model:FlagReranker
3231
"""CrossEncoder instance to use for reranking."""
3332

3433
def bge_rerank(self, query, docs):
3534
model_inputs = [[query, doc] for doc in docs]
3635
scores = self.model.compute_score(model_inputs)
36+
if len(docs) == 1:
37+
return [(0, scores)]
3738
results = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
3839
return results[:self.top_n]
3940

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(self,
8787
allowed_retrieval_type: ClassVar[Collection[str]] = (
8888
"default",
8989
"child_parent",
90+
'bm25',
9091
)
9192
allowed_generation_mode: ClassVar[Collection[str]] = (
9293
"accuracy",

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retriever_adapter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ def __init__(self, retrieval_type='default', document_store=None, child_document
3636
self.retrieval_type = retrieval_type
3737
if enable_rerank:
3838
from intel_extension_for_transformers.langchain.retrievers.bge_reranker import BgeReranker
39-
self.reranker = BgeReranker(model_name = reranker_model, top_n=top_n)
39+
from FlagEmbedding import FlagReranker
40+
reranker = FlagReranker(reranker_model)
41+
self.reranker = BgeReranker(model = reranker, top_n=top_n)
4042
else:
4143
self.reranker = None
4244

4345
if self.retrieval_type == "default":
4446
self.retriever = VectorStoreRetriever(vectorstore=document_store, **kwargs)
45-
if self.retrieval_type == "bm25":
47+
elif self.retrieval_type == "bm25":
4648
self.retriever = BM25Retriever.from_documents(docs, **kwargs)
4749
elif self.retrieval_type == "child_parent":
4850
self.retriever = ChildParentRetriever(parentstore=document_store, \

intel_extension_for_transformers/neural_chat/tests/ci/plugins/retrieval/test_parameters.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,61 @@ def test_accuracy_mode(self):
7979
plugins.retrieval.args = {}
8080
plugins.retrieval.enable = False
8181

82+
class TestBM25Retriever(unittest.TestCase):
83+
def setUp(self):
84+
if os.path.exists("./bm25"):
85+
shutil.rmtree("./bm25", ignore_errors=True)
86+
return super().setUp()
87+
88+
def tearDown(self) -> None:
89+
if os.path.exists("./bm25"):
90+
shutil.rmtree("./bm25", ignore_errors=True)
91+
return super().tearDown()
92+
93+
def test_accuracy_mode(self):
94+
plugins.retrieval.args = {}
95+
plugins.retrieval.enable = True
96+
plugins.retrieval.args["input_path"] = "../assets/docs/sample.txt"
97+
plugins.retrieval.args["persist_directory"] = "./bm25"
98+
plugins.retrieval.args["retrieval_type"] = 'bm25'
99+
plugins.retrieval.args["mode"] = 'accuracy'
100+
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
101+
plugins=plugins)
102+
chatbot = build_chatbot(config)
103+
response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
104+
print(response)
105+
self.assertIsNotNone(response)
106+
plugins.retrieval.args = {}
107+
plugins.retrieval.enable = False
108+
109+
class TestRerank(unittest.TestCase):
110+
def setUp(self):
111+
if os.path.exists("./rerank"):
112+
shutil.rmtree("./rerank", ignore_errors=True)
113+
return super().setUp()
114+
115+
def tearDown(self) -> None:
116+
if os.path.exists("./rerank"):
117+
shutil.rmtree("./rerank", ignore_errors=True)
118+
return super().tearDown()
119+
120+
def test_general_mode(self):
121+
plugins.retrieval.args = {}
122+
plugins.retrieval.enable = True
123+
plugins.retrieval.args["input_path"] = "../assets/docs/sample.txt"
124+
plugins.retrieval.args["persist_directory"] = "./rerank"
125+
plugins.retrieval.args["retrieval_type"] = 'default'
126+
plugins.retrieval.args['enable_rerank'] = True
127+
plugins.retrieval.args['reranker_model'] = 'BAAI/bge-reranker-base'
128+
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
129+
plugins=plugins)
130+
chatbot = build_chatbot(config)
131+
response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
132+
print(response)
133+
self.assertIsNotNone(response)
134+
plugins.retrieval.args = {}
135+
plugins.retrieval.enable = False
136+
82137
class TestGeneralMode(unittest.TestCase):
83138
def setUp(self):
84139
if os.path.exists("./general_mode"):

0 commit comments

Comments
 (0)