
Commit 546fca6

[Paddle-pipelines] Add ReAct zh examples (#6095)

* Add ReAct zh examples
* Refactor chatglm
* Add ChatGPT agents example
* Add Chinese prompts
* Fix agents dead-loop bug
* Add semantic search for ReAct
* Update ReAct_example_cn
* Add ernie_bot invocation
* Add unit tests
1 parent c3b921b commit 546fca6

File tree

24 files changed: +563 -56 lines


pipelines/examples/agents/ReAct_example.py

Lines changed: 7 additions & 4 deletions
@@ -77,28 +77,31 @@
 Final Answer: Gainsville, Florida
 ##
 Question: {query}
-Thought:
+Thought:{transcript}
 """

 # yapf: disable
 parser = argparse.ArgumentParser()
 parser.add_argument("--search_api_key", default=None, type=str, help="The SerpAPI key.")
+parser.add_argument('--llm_name', choices=['THUDM/chatglm-6b', "THUDM/chatglm-6b-v1.1", "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b-v1.1", help="The chatbot models")
+parser.add_argument("--api_key", default=None, type=str, help="The API Key.")
 args = parser.parse_args()
 # yapf: enable


 def search_and_action_example():
     pn = PromptNode(
-        "THUDM/chatglm-6b",
-        max_length=512,
+        args.llm_name,
+        max_length=256,
+        api_key=args.api_key,
         default_prompt_template="question-answering-with-document-scores",
     )

     # https://serpapi.com/dashboard
     web_retriever = WebRetriever(api_key=args.search_api_key, top_search_results=2)
     pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=pn)

-    prompt_node = PromptNode("THUDM/chatglm-6b", max_length=512, stop_words=["Observation:"])
+    prompt_node = PromptNode(args.llm_name, api_key=args.api_key, max_length=512, stop_words=["Observation:"])

     web_qa_tool = Tool(
         name="Search",
Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import glob
+import os
+
+from pipelines.agents import Agent, Tool
+from pipelines.agents.base import ToolsManager
+from pipelines.document_stores import FAISSDocumentStore
+from pipelines.nodes import (
+    CharacterTextSplitter,
+    DensePassageRetriever,
+    DocxToTextConverter,
+    FileTypeClassifier,
+    PDFToTextConverter,
+    PromptNode,
+    TextConverter,
+    WebRetriever,
+)
+from pipelines.nodes.prompt.prompt_template import PromptTemplate
+from pipelines.pipelines import Pipeline, WebQAPipeline
+from pipelines.utils import fetch_archive_from_http
+
+few_shot_prompt = """
+你是一个乐于助人、知识渊博的人工智能助手。为了实现正确回答复杂问题的目标,您可以使用以下工具:
+搜索: 当你需要用谷歌搜索问题时很有用。你应该问一些有针对性的问题,例如,谁是安东尼·迪雷尔的兄弟?
+要回答问题,你需要经历多个步骤,包括逐步思考和选择合适的工具及其输入;工具将以观察作为回应。当您准备好接受最终答案时,回答"最终答案":
+示例:
+##
+问题: 哈利波特的作者是谁?
+思考: 让我们一步一步地思考。要回答这个问题,我们首先需要了解哈利波特是什么。
+工具: 搜索
+工具输入: 哈利波特是什么?
+观察: 哈利波特是一系列非常受欢迎的魔幻小说,以及后来的电影和衍生作品。
+思考: 我们了解到哈利波特是一系列魔幻小说。现在我们需要找到这些小说的作者是谁。
+工具: 搜索
+工具输入: 哈利波特的作者是谁?
+观察: 哈利波特系列的作者是J.K.罗琳(J.K. Rowling)。
+思考: 根据搜索结果,哈利波特系列的作者是J.K.罗琳。所以最终答案是J.K.罗琳。
+最终答案: J.K.罗琳
+##
+问题: {query}
+思考:{transcript}
+"""
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.")
+parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.")
+parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.")
+parser.add_argument("--retriever", choices=['dense', 'SerpAPI'], default="dense", help="The type of Retriever.")
+parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.")
+parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.")
+parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.")
+parser.add_argument("--query_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The query_embedding_model path")
+parser.add_argument("--passage_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The passage_embedding_model path")
+parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path")
+parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index")
+parser.add_argument("--search_api_key", default=None, type=str, help="The SerpAPI key.")
+parser.add_argument('--embed_title', default=False, type=bool, help="Whether to embed the title into the passage embedding")
+parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="The ernie model types")
+parser.add_argument('--llm_name', choices=['ernie-bot', 'THUDM/chatglm-6b', "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b", help="The chatbot models")
+parser.add_argument("--api_key", default=None, type=str, help="The API Key.")
+parser.add_argument("--secret_key", default=None, type=str, help="The secret key.")
+args = parser.parse_args()
+# yapf: enable
+
+
+def indexing_files(retriever, document_store, filepaths, chunk_size):
+    try:
+        text_converter = TextConverter()
+        pdf_converter = PDFToTextConverter()
+        doc_converter = DocxToTextConverter()
+
+        text_splitter = CharacterTextSplitter(separator="\f", chunk_size=chunk_size, chunk_overlap=0, filters=["\n"])
+        pdf_splitter = CharacterTextSplitter(
+            separator="\f",
+            chunk_size=chunk_size,
+            chunk_overlap=0,
+            filters=['([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))'],
+        )
+        file_classifier = FileTypeClassifier()
+        indexing_pipeline = Pipeline()
+        indexing_pipeline.add_node(component=file_classifier, name="file_classifier", inputs=["File"])
+        indexing_pipeline.add_node(component=doc_converter, name="DocConverter", inputs=["file_classifier.output_4"])
+        indexing_pipeline.add_node(component=text_converter, name="TextConverter", inputs=["file_classifier.output_1"])
+        indexing_pipeline.add_node(component=pdf_converter, name="PDFConverter", inputs=["file_classifier.output_2"])
+
+        indexing_pipeline.add_node(
+            component=text_splitter, name="TextSplitter", inputs=["TextConverter", "DocConverter"]
+        )
+        indexing_pipeline.add_node(component=pdf_splitter, name="PDFSplitter", inputs=["PDFConverter"])
+        indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["TextSplitter", "PDFSplitter"])
+        indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])
+        files = glob.glob(filepaths + "/*.*", recursive=True)
+        indexing_pipeline.run(file_paths=files)
+    except Exception as e:
+        print(e)
+        pass
+
+
+def get_faiss_retriever(use_gpu):
+    faiss_document_store = "faiss_document_store.db"
+    if os.path.exists(args.index_name) and os.path.exists(faiss_document_store):
+        # connect to an existing FAISS index
+        document_store = FAISSDocumentStore.load(args.index_name)
+        retriever = DensePassageRetriever(
+            document_store=document_store,
+            query_embedding_model=args.query_embedding_model,
+            passage_embedding_model=args.passage_embedding_model,
+            params_path=args.params_path,
+            output_emb_size=args.embedding_dim if args.model_type in ["ernie_search", "neural_search"] else None,
+            max_seq_len_query=args.max_seq_len_query,
+            max_seq_len_passage=args.max_seq_len_passage,
+            batch_size=args.retriever_batch_size,
+            use_gpu=use_gpu,
+            embed_title=args.embed_title,
+        )
+    else:
+        dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip"
+        zip_dir = "data/dureader_dev"
+        fetch_archive_from_http(url=dureader_data, output_dir=zip_dir)
+
+        document_store = FAISSDocumentStore(embedding_dim=args.embedding_dim, faiss_index_factory_str="Flat")
+        retriever = DensePassageRetriever(
+            document_store=document_store,
+            query_embedding_model=args.query_embedding_model,
+            passage_embedding_model=args.passage_embedding_model,
+            params_path=args.params_path,
+            output_emb_size=args.embedding_dim if args.model_type in ["ernie_search", "neural_search"] else None,
+            max_seq_len_query=args.max_seq_len_query,
+            max_seq_len_passage=args.max_seq_len_passage,
+            batch_size=args.retriever_batch_size,
+            use_gpu=use_gpu,
+            embed_title=args.embed_title,
+            top_k=5,
+        )
+        filepaths = "data/dureader_dev/dureader_dev"
+        indexing_files(retriever, document_store, filepaths, chunk_size=500)
+        document_store.save(args.index_name)
+    return retriever
+
+
+def search_and_action_example(web_retriever):
+    qa_template = PromptTemplate(
+        name="文档问答",
+        prompt_text="使用以下段落作为来源回答以下问题。"
+        "答案应该简短,最多几个字。\n"
+        "段落:\n{documents}\n"
+        "问题: {query}\n\n"
+        "说明: 考虑以上所有段落及其相应的分数,得出答案。 "
+        "虽然一个段落可能得分很高, "
+        "但重要的是要考虑同一候选答案的所有段落,以便准确回答。\n\n"
+        "在考虑了所有的可能性之后,最终答案是:\n",
+    )
+    pn = PromptNode(
+        args.llm_name,
+        max_length=512,
+        default_prompt_template=qa_template,
+        api_key=args.api_key,
+        secret_key=args.secret_key,
+    )
+
+    pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=pn)
+
+    prompt_node = PromptNode(
+        args.llm_name, max_length=512, api_key=args.api_key, secret_key=args.secret_key, stop_words=["观察: "]
+    )
+
+    web_qa_tool = Tool(
+        name="搜索",
+        pipeline_or_node=pipeline,
+        description="当你需要用谷歌搜索问题时很有用。",
+        output_variable="results",
+    )
+    few_shot_agent_template = PromptTemplate("few-shot-react", prompt_text=few_shot_prompt)
+    # Initialize the Agent, specifying the PromptNode to use and the Tools
+    agent = Agent(
+        prompt_node=prompt_node,
+        prompt_template=few_shot_agent_template,
+        tools_manager=ToolsManager(
+            tools=[web_qa_tool],
+            tool_pattern=r"工具:\s*(\w+)\s*工具输入:\s*(?:\"([\s\S]*?)\"|((?:.|\n)*))\s*",
+            observation_prefix="观察: ",
+            llm_prefix="思考: ",
+        ),
+        max_steps=8,
+        final_answer_pattern=r"最终答案\s*:\s*(.*)",
+        observation_prefix="观察: ",
+        llm_prefix="思考: ",
+    )
+    hotpot_questions = ["范冰冰的身高是多少?", "武则天传位给了谁?"]
+    for question in hotpot_questions:
+        result = agent.run(query=question)
+        print(f"\n{result['transcript']}")
+
+
+if __name__ == "__main__":
+    if args.retriever == "dense":
+        use_gpu = True if args.device == "gpu" else False
+        web_retriever = get_faiss_retriever(use_gpu)
+    else:
+        # https://serpapi.com/dashboard
+        web_retriever = WebRetriever(api_key=args.search_api_key, engine="bing", top_search_results=2)
+    search_and_action_example(web_retriever)
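
For non-Chinese readers: few_shot_prompt is a Chinese rendering of the English ReAct few-shot prompt, with 搜索 = Search, 问题 = Question, 思考 = Thought, 工具/工具输入 = Tool/Tool Input, 观察 = Observation, and 最终答案 = Final Answer; the worked example walks through finding the author of Harry Potter (J.K. Rowling). A quick illustrative check (not part of the commit) that the Chinese tool_pattern extracts a tool name and input; Python's \w is Unicode-aware by default, so (\w+) matches 搜索:

import re

# Illustrative check (not part of the commit) of the Chinese tool_pattern
# passed to ToolsManager above.
tool_pattern = r"工具:\s*(\w+)\s*工具输入:\s*(?:\"([\s\S]*?)\"|((?:.|\n)*))\s*"
response = "思考: 需要先搜索。\n工具: 搜索\n工具输入: 哈利波特的作者是谁?\n"

match = re.search(tool_pattern, response)
if match:
    tool_name = match.group(1)                     # "搜索" (Search)
    tool_input = match.group(2) or match.group(3)  # quoted or free-form input
    print(tool_name, tool_input.strip())
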

pipelines/pipelines/agents/agent_step.py

Lines changed: 5 additions & 2 deletions
@@ -36,6 +36,8 @@ def __init__(
         final_answer_pattern: Optional[str] = None,
         prompt_node_response: str = "",
         transcript: str = "",
+        observation_prefix: str = "Observation: ",
+        llm_prefix: str = "Thought: ",
     ):
         """
         :param current_step: The current step in the execution of the agent.
@@ -50,6 +52,8 @@ def __init__(
         self.final_answer_pattern = final_answer_pattern or r"^([\s\S]+)$"
         self.prompt_node_response = prompt_node_response
         self.transcript = transcript
+        self.observation_prefix = observation_prefix
+        self.llm_prefix = llm_prefix

     def create_next_step(self, prompt_node_response: Any, current_step: Optional[int] = None) -> AgentStep:
         """
@@ -119,7 +123,7 @@ def completed(self, observation: Optional[str]) -> None:
         :param observation: received observation from the Agent environment.
         """
         self.transcript += (
-            f"{self.prompt_node_response}\nObservation: {observation}\nThought:"
+            f"{self.prompt_node_response}\n{self.observation_prefix} {observation}\n{self.llm_prefix} "
             if observation
             else self.prompt_node_response
         )
@@ -149,7 +153,6 @@ def parse_final_answer(self) -> Optional[str]:
         """
         # Search for a match with the final answer pattern in the prompt node response
         final_answer_match = re.search(self.final_answer_pattern, self.prompt_node_response)
-
         if final_answer_match:
             # If a match is found, get the first group (i.e., the content inside the parentheses of the regex pattern)
             final_answer = final_answer_match.group(1)
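
The completed() change is what the Chinese example depends on: the hard-coded "Observation:"/"Thought:" markers become the configurable observation_prefix/llm_prefix. A standalone rendition of the new transcript update (simplified; the sample observation text is invented for illustration):

# Simplified, standalone rendition of the transcript update in completed();
# the observation text below is invented for illustration.
prompt_node_response = "工具: 搜索\n工具输入: 武则天传位给了谁?"
observation = "武则天退位后,由其子唐中宗李显即位。"
observation_prefix = "观察: "
llm_prefix = "思考: "

transcript = ""
if observation:
    # Mirrors the new f-string: response, prefixed observation, then the
    # llm prefix so the next generation continues as a "思考" turn.
    transcript += f"{prompt_node_response}\n{observation_prefix} {observation}\n{llm_prefix} "
else:
    transcript += prompt_node_response
print(transcript)
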

pipelines/pipelines/agents/base.py

Lines changed: 10 additions & 5 deletions
@@ -134,6 +134,8 @@ def __init__(
         self,
         tools: Optional[List[Tool]] = None,
         tool_pattern: str = r"Tool:\s*(\w+)\s*Tool Input:\s*(?:\"([\s\S]*?)\"|((?:.|\n)*))\s*",
+        observation_prefix: str = "Observation: ",
+        llm_prefix: str = "Thought: ",
     ):
         """
         :param tools: A list of tools to add to the ToolManager. Each tool must have a unique name.
@@ -143,6 +145,8 @@ def __init__(
         self._tools: Dict[str, Tool] = {tool.name: tool for tool in tools} if tools else {}
         self.tool_pattern = tool_pattern
         self.callback_manager = Events(("on_tool_start", "on_tool_finish", "on_tool_error"))
+        self.observation_prefix = observation_prefix
+        self.llm_prefix = llm_prefix

     def add_tool(self, tool: Tool):
         """
@@ -193,8 +197,8 @@ def run_tool(self, llm_response: str, params: Optional[Dict[str, Any]] = None) -
             tool_result = tool.run(tool_input, params)
             self.callback_manager.on_tool_finish(
                 tool_result,
-                observation_prefix="Observation: ",
-                llm_prefix="Thought: ",
+                observation_prefix=self.observation_prefix,
+                llm_prefix=self.llm_prefix,
                 color=tool.logging_color,
             )
         except Exception as e:
@@ -241,6 +245,8 @@ def __init__(
         prompt_parameters_resolver: Optional[Callable] = None,
         max_steps: int = 8,
         final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
+        observation_prefix: str = "Observation: ",
+        llm_prefix: str = "Thought: ",
     ):
         """
         Creates an Agent instance.
@@ -269,6 +275,8 @@ def __init__(
         self.prompt_node = prompt_node
         prompt_template = prompt_template or "zero-shot-react"
         resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
+        self.observation_prefix = observation_prefix
+        self.llm_prefix = llm_prefix
         if not resolved_prompt_template:
             raise ValueError(
                 f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name."
@@ -394,18 +402,15 @@
     def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = None):
         # plan next step using the LLM
         prompt_node_response = self._plan(query, current_step)
-
         # from the LLM response, create the next step
         next_step = current_step.create_next_step(prompt_node_response)
         self.callback_manager.on_agent_step(next_step)
-
         # run the tool selected by the LLM
         observation = self.tm.run_tool(next_step.prompt_node_response, params) if not next_step.is_last() else None

         # save the input, output and observation to memory (if memory is enabled)
         memory_data = self.prepare_data_for_memory(input=query, output=prompt_node_response, observation=observation)
         self.memory.save(data=memory_data)
-
         # update the next step with the observation
         next_step.completed(observation)
         return next_step
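
With the English markers lifted into constructor arguments, the same Agent loop can run entirely in Chinese. A small illustrative check (not part of the commit) of the Chinese final_answer_pattern that the example script passes to Agent; matching it is what terminates the ReAct loop:

import re

# Illustrative check (not part of the commit) of the Chinese final-answer
# pattern from the example script; a match ends the agent loop.
final_answer_pattern = r"最终答案\s*:\s*(.*)"
response = "思考: 根据搜索结果,哈利波特系列的作者是J.K.罗琳。\n最终答案: J.K.罗琳"

match = re.search(final_answer_pattern, response)
print(match.group(1) if match else "no final answer yet")  # -> J.K.罗琳
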

pipelines/pipelines/nodes/llm/chatglm.py

Lines changed: 14 additions & 2 deletions
@@ -21,7 +21,14 @@


 class ChatGLMBot(BaseComponent):
-    def __init__(self, batch_size: int = 2, max_seq_length: int = 1024, tgt_length: int = 256, **kwargs):
+    def __init__(
+        self,
+        model_name_or_path="THUDM/chatglm-6b-v1.1",
+        batch_size: int = 2,
+        max_seq_length: int = 2048,
+        tgt_length: int = 2048,
+        **kwargs
+    ):
         """
         Initialize the ChatGLMBot instance.

@@ -32,18 +39,23 @@ def __init__(self, batch_size: int = 2, max_seq_length: int = 1024, tgt_length:
         self.kwargs = kwargs
         self.chatglm = Taskflow(
             "text2text_generation",
+            model=model_name_or_path,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
             tgt_length=tgt_length,
             **self.kwargs,
         )

+    def predict(self, query, stream=False):
+        result = self.chatglm(query)
+        return result
+
     def run(self, query, stream=False):
         """
         Use the chatbot to generate answers.
         :param query: The user's input/query to be sent to the chatGLM.
         :param stream: Whether to use streaming mode when making the request. Currently not in use. Defaults to False.
         """
         logger.info(query)
-        result = self.chatglm(query)
+        result = self.predict(query=query, stream=stream)
         return result, "output_1"
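
A hedged usage sketch of the refactored ChatGLMBot: the constructor arguments and the (result, "output_1") return shape follow the diff, but the exact structure of the Taskflow output is an assumption:

# Usage sketch; constructor arguments and the (result, "output_1") return
# shape follow the diff, while the Taskflow output structure is assumed.
from pipelines.nodes.llm.chatglm import ChatGLMBot

bot = ChatGLMBot(model_name_or_path="THUDM/chatglm-6b-v1.1", tgt_length=512)

raw = bot.predict("你好,请介绍一下你自己。")      # direct Taskflow call
result, edge = bot.run("你好,请介绍一下你自己。")  # pipeline entry point
assert edge == "output_1"  # pipelines routes node output on this edge name
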
