# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

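"""ReAct agent demo: answers multi-step Chinese questions with a single web-QA tool.

The tool is backed either by a DensePassageRetriever over a local FAISS index built
from the DuReader dev set, or by a SerpAPI-based WebRetriever, and answers are
generated by the LLM selected via --llm_name.
"""
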
import argparse
import glob
import os
from distutils.util import strtobool

from pipelines.agents import Agent, Tool
from pipelines.agents.base import ToolsManager
from pipelines.document_stores import FAISSDocumentStore
from pipelines.nodes import (
    CharacterTextSplitter,
    DensePassageRetriever,
    DocxToTextConverter,
    FileTypeClassifier,
    PDFToTextConverter,
    PromptNode,
    TextConverter,
    WebRetriever,
)
from pipelines.nodes.prompt.prompt_template import PromptTemplate
from pipelines.pipelines import Pipeline, WebQAPipeline
from pipelines.utils import fetch_archive_from_http

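# Few-shot ReAct prompt, kept in Chinese because the agent's tool/observation regexes
# below match the Chinese keywords. It tells the model it may use one tool, 搜索 (Search),
# and to iterate 思考 (Thought) / 工具 (Tool) / 工具输入 (Tool Input) / 观察 (Observation)
# steps until it can emit 最终答案 (Final Answer), as shown in the worked Harry Potter example.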
few_shot_prompt = """
你是一个乐于助人、知识渊博的人工智能助手。为了实现正确回答复杂问题的目标,您可以使用以下工具:
搜索: 当你需要用谷歌搜索问题时很有用。你应该问一些有针对性的问题,例如,谁是安东尼·迪雷尔的兄弟?
要回答问题,你需要经历多个步骤,包括逐步思考和选择合适的工具及其输入;工具将以观察作为回应。当您准备好接受最终答案时,回答"最终答案":
示例:
##
问题: 哈利波特的作者是谁?
思考: 让我们一步一步地思考。要回答这个问题,我们首先需要了解哈利波特是什么。
工具: 搜索
工具输入: 哈利波特是什么?
观察: 哈利波特是一系列非常受欢迎的魔幻小说,以及后来的电影和衍生作品。
思考: 我们了解到哈利波特是一系列魔幻小说。现在我们需要找到这些小说的作者是谁。
工具: 搜索
工具输入: 哈利波特的作者是谁?
观察: 哈利波特系列的作者是J.K.罗琳(J.K. Rowling)。
思考: 根据搜索结果,哈利波特系列的作者是J.K.罗琳。所以最终答案是J.K.罗琳。
最终答案: J.K.罗琳
##
问题: {query}
思考:{transcript}
"""

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Device to run the dense QA system on; defaults to gpu.")
parser.add_argument("--index_name", default='dureader_index', type=str, help="The name of the ANN index.")
parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.")
parser.add_argument("--retriever", choices=['dense', 'SerpAPI'], default="dense", help="The type of retriever.")
parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of the query after tokenization.")
parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of a passage after tokenization.")
parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size the retriever uses to extract passage embeddings for building the ANN index.")
parser.add_argument("--query_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The name or path of the query embedding model.")
parser.add_argument("--passage_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The name or path of the passage embedding model.")
parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path of the retrieval model.")
parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding dimension of the ANN index.")
parser.add_argument("--search_api_key", default=None, type=str, help="The SerpAPI key.")
parser.add_argument('--embed_title', default=False, type=strtobool, help="Whether to embed the document title into the passage embedding.")
parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="The type of the retrieval model.")
parser.add_argument('--llm_name', choices=['ernie-bot', 'THUDM/chatglm-6b', "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b", help="The LLM used by the agent and the QA prompt node.")
parser.add_argument("--api_key", default=None, type=str, help="The API key of the LLM service.")
parser.add_argument("--secret_key", default=None, type=str, help="The secret key of the LLM service.")
args = parser.parse_args()
# yapf: enable
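
# Example invocations (a sketch, assuming this file is saved as react_agent_example.py;
# the keys are placeholders you must supply yourself):
#   python react_agent_example.py --retriever dense --device gpu --llm_name THUDM/chatglm-6b
#   python react_agent_example.py --retriever SerpAPI --search_api_key <serpapi-key> --llm_name gpt-3.5-turbo --api_key <openai-key>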
| 79 | + |
| 80 | + |
def indexing_files(retriever, document_store, filepaths, chunk_size):
    """Convert, split and index every file under `filepaths` into `document_store`."""
    try:
        text_converter = TextConverter()
        pdf_converter = PDFToTextConverter()
        doc_converter = DocxToTextConverter()

        # Split converted documents into chunks of roughly `chunk_size` characters.
        text_splitter = CharacterTextSplitter(separator="\f", chunk_size=chunk_size, chunk_overlap=0, filters=["\n"])
        pdf_splitter = CharacterTextSplitter(
            separator="\f",
            chunk_size=chunk_size,
            chunk_overlap=0,
            filters=['([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))'],
        )
        # Classify each input file by type, route it to the matching converter,
        # then split, embed and store the resulting chunks.
        file_classifier = FileTypeClassifier()
        indexing_pipeline = Pipeline()
        indexing_pipeline.add_node(component=file_classifier, name="file_classifier", inputs=["File"])
        indexing_pipeline.add_node(component=doc_converter, name="DocConverter", inputs=["file_classifier.output_4"])
        indexing_pipeline.add_node(component=text_converter, name="TextConverter", inputs=["file_classifier.output_1"])
        indexing_pipeline.add_node(component=pdf_converter, name="PDFConverter", inputs=["file_classifier.output_2"])

        indexing_pipeline.add_node(
            component=text_splitter, name="TextSplitter", inputs=["TextConverter", "DocConverter"]
        )
        indexing_pipeline.add_node(component=pdf_splitter, name="PDFSplitter", inputs=["PDFConverter"])
        indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["TextSplitter", "PDFSplitter"])
        indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])
        files = glob.glob(os.path.join(filepaths, "*.*"), recursive=True)
        indexing_pipeline.run(file_paths=files)
    except Exception as e:
        print(f"Indexing failed: {e}")
| 112 | + |
| 113 | + |
def get_faiss_retriever(use_gpu):
    """Return a DensePassageRetriever backed by a FAISS index, building the index if needed."""
    faiss_document_store = "faiss_document_store.db"
    if os.path.exists(args.index_name) and os.path.exists(faiss_document_store):
        # Connect to the existing FAISS index.
        document_store = FAISSDocumentStore.load(args.index_name)
        retriever = DensePassageRetriever(
            document_store=document_store,
            query_embedding_model=args.query_embedding_model,
            passage_embedding_model=args.passage_embedding_model,
            params_path=args.params_path,
            output_emb_size=args.embedding_dim if args.model_type in ["ernie_search", "neural_search"] else None,
            max_seq_len_query=args.max_seq_len_query,
            max_seq_len_passage=args.max_seq_len_passage,
            batch_size=args.retriever_batch_size,
            use_gpu=use_gpu,
            embed_title=args.embed_title,
        )
    else:
        # Download the DuReader dev set and build a new FAISS index from it.
        dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip"
        zip_dir = "data/dureader_dev"
        fetch_archive_from_http(url=dureader_data, output_dir=zip_dir)

        document_store = FAISSDocumentStore(embedding_dim=args.embedding_dim, faiss_index_factory_str="Flat")
        retriever = DensePassageRetriever(
            document_store=document_store,
            query_embedding_model=args.query_embedding_model,
            passage_embedding_model=args.passage_embedding_model,
            params_path=args.params_path,
            output_emb_size=args.embedding_dim if args.model_type in ["ernie_search", "neural_search"] else None,
            max_seq_len_query=args.max_seq_len_query,
            max_seq_len_passage=args.max_seq_len_passage,
            batch_size=args.retriever_batch_size,
            use_gpu=use_gpu,
            embed_title=args.embed_title,
            top_k=5,
        )
        filepaths = "data/dureader_dev/dureader_dev"
        indexing_files(retriever, document_store, filepaths, chunk_size=500)
        document_store.save(args.index_name)
    return retriever
| 154 | + |
| 155 | + |
def search_and_action_example(web_retriever):
    """Run a few-shot ReAct agent whose only tool is a web QA pipeline over `web_retriever`."""
    # QA prompt (in Chinese, named "文档问答" = document QA): answer the question from the
    # retrieved passages and keep the answer short.
    qa_template = PromptTemplate(
        name="文档问答",
        prompt_text="使用以下段落作为来源回答以下问题。"
        "答案应该简短,最多几个字。\n"
        "段落:\n{documents}\n"
        "问题: {query}\n\n"
        "说明: 考虑以上所有段落及其相应的分数,得出答案。 "
        "虽然一个段落可能得分很高, "
        "但重要的是要考虑同一候选答案的所有段落,以便准确回答。\n\n"
        "在考虑了所有的可能性之后,最终答案是:\n",
    )
    pn = PromptNode(
        args.llm_name,
        max_length=512,
        default_prompt_template=qa_template,
        api_key=args.api_key,
        secret_key=args.secret_key,
    )

    pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=pn)

    # A separate PromptNode drives the agent; generation stops at "观察: " (Observation)
    # so the observation comes from the tool rather than from the LLM itself.
    prompt_node = PromptNode(
        args.llm_name, max_length=512, api_key=args.api_key, secret_key=args.secret_key, stop_words=["观察: "]
    )

    # The single tool, named "搜索" (Search): useful when the question needs a web search.
    web_qa_tool = Tool(
        name="搜索",
        pipeline_or_node=pipeline,
        description="当你需要用谷歌搜索问题时很有用。",
        output_variable="results",
    )
    few_shot_agent_template = PromptTemplate("few-shot-react", prompt_text=few_shot_prompt)
    # Initialize the agent with the PromptNode to use and the available tools.
    agent = Agent(
        prompt_node=prompt_node,
        prompt_template=few_shot_agent_template,
        tools_manager=ToolsManager(
            tools=[web_qa_tool],
            tool_pattern=r"工具:\s*(\w+)\s*工具输入:\s*(?:\"([\s\S]*?)\"|((?:.|\n)*))\s*",
            observation_prefix="观察: ",
            llm_prefix="思考: ",
        ),
        max_steps=8,
        final_answer_pattern=r"最终答案\s*:\s*(.*)",
        observation_prefix="观察: ",
        llm_prefix="思考: ",
    )
    # Two sample multi-hop questions (in Chinese): "How tall is Fan Bingbing?" and
    # "To whom did Wu Zetian pass the throne?"
    hotpot_questions = ["范冰冰的身高是多少?", "武则天传位给了谁?"]
    for question in hotpot_questions:
        result = agent.run(query=question)
        print(f"\n{result['transcript']}")
| 209 | + |
| 210 | + |
if __name__ == "__main__":
    if args.retriever == "dense":
        use_gpu = args.device == "gpu"
        web_retriever = get_faiss_retriever(use_gpu)
    else:
        # SerpAPI-backed web search; get a key at https://serpapi.com/dashboard
        web_retriever = WebRetriever(api_key=args.search_api_key, engine="bing", top_search_results=2)
    search_and_action_example(web_retriever)