Skip to content

Commit bb949ef

Browse files
authored
feat: Add Agent state-mapping parameters to MCPTool (#2501)
* Add Agent state-mapping parameters to MCPTool
* PR feedback, add Agent state MCPTool integration test
1 parent dee4f0a commit bb949ef

File tree

2 files changed

+180
-5
lines changed

2 files changed

+180
-5
lines changed

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,11 @@ class MCPTool(Tool):
845845
- The JSON contains the structured response from the MCP server
846846
- Use json.loads() to parse the response into a dictionary
847847
848+
State-mapping support:
849+
- MCPTool supports state-mapping parameters (`outputs_to_string`, `inputs_from_state`, `outputs_to_state`)
850+
- These enable integration with Agent state for automatic parameter injection and output handling
851+
- See the `__init__` method documentation for details on each parameter
852+
848853
Example using Streamable HTTP:
849854
```python
850855
import json
@@ -902,6 +907,9 @@ def __init__(
902907
connection_timeout: int = 30,
903908
invocation_timeout: int = 30,
904909
eager_connect: bool = False,
910+
outputs_to_string: dict[str, Any] | None = None,
911+
inputs_from_state: dict[str, str] | None = None,
912+
outputs_to_state: dict[str, dict[str, Any]] | None = None,
905913
):
906914
"""
907915
Initialize the MCP tool.
@@ -914,6 +922,17 @@ def __init__(
914922
:param eager_connect: If True, connect to server during initialization.
915923
If False (default), defer connection until warm_up or first tool use,
916924
whichever comes first.
925+
:param outputs_to_string: Optional dictionary defining how tool outputs should be converted into a string.
926+
If the source is provided, only the specified output key is sent to the handler.
927+
If the source is omitted, the whole tool result is sent to the handler.
928+
Example: `{"source": "docs", "handler": my_custom_function}`
929+
:param inputs_from_state: Optional dictionary mapping state keys to tool parameter names.
930+
Example: `{"repository": "repo"}` maps state's "repository" to tool's "repo" parameter.
931+
:param outputs_to_state: Optional dictionary defining how tool outputs map to keys within state as well as
932+
optional handlers. If the source is provided, only the specified output key is sent
933+
to the handler.
934+
Example with source: `{"documents": {"source": "docs", "handler": custom_handler}}`
935+
Example without source: `{"documents": {"handler": custom_handler}}`
917936
:raises MCPConnectionError: If connection to the server fails
918937
:raises MCPToolNotFoundError: If no tools are available or the requested tool is not found
919938
:raises TimeoutError: If connection times out
@@ -924,6 +943,9 @@ def __init__(
924943
self._connection_timeout = connection_timeout
925944
self._invocation_timeout = invocation_timeout
926945
self._eager_connect = eager_connect
946+
self._outputs_to_string = outputs_to_string
947+
self._inputs_from_state = inputs_from_state
948+
self._outputs_to_state = outputs_to_state
927949
self._client: MCPClient | None = None
928950
self._worker: _MCPClientSessionManager | None = None
929951
self._lock = threading.RLock()
@@ -934,7 +956,15 @@ def __init__(
934956
# without discovering the remote schema during validation.
935957
# Tool parameters/schema will be replaced with the correct schema (from the MCP server) on first use.
936958
params = {"type": "object", "properties": {}, "additionalProperties": True}
937-
super().__init__(name=name, description=description or "", parameters=params, function=self._invoke_tool)
959+
super().__init__(
960+
name=name,
961+
description=description or "",
962+
parameters=params,
963+
function=self._invoke_tool,
964+
outputs_to_string=outputs_to_string,
965+
inputs_from_state=inputs_from_state,
966+
outputs_to_state=outputs_to_state,
967+
)
938968
return
939969

940970
logger.debug(f"TOOL: Initializing MCPTool '{name}'")
@@ -950,7 +980,19 @@ def __init__(
950980
description=description or tool_info.description or "",
951981
parameters=tool_info.inputSchema,
952982
function=self._invoke_tool,
983+
outputs_to_string=outputs_to_string,
984+
inputs_from_state=inputs_from_state,
985+
outputs_to_state=outputs_to_state,
953986
)
987+
988+
# Remove inputs_from_state keys from parameters schema if present
989+
# This matches the behavior of ComponentTool
990+
if inputs_from_state and "properties" in self.parameters:
991+
for key in inputs_from_state.values():
992+
self.parameters["properties"].pop(key, None)
993+
if "required" in self.parameters and key in self.parameters["required"]:
994+
self.parameters["required"].remove(key)
995+
954996
logger.debug(f"TOOL: Initialization complete for '{name}'")
955997

956998
except Exception as e:
@@ -1069,13 +1111,21 @@ def warm_up(self) -> None:
10691111
tool = self._connect_and_initialize(self.name)
10701112
self.parameters = tool.inputSchema
10711113

1114+
# Remove inputs_from_state keys from parameters schema if present
1115+
# This matches the behavior of ComponentTool
1116+
if self._inputs_from_state and "properties" in self.parameters:
1117+
for key in self._inputs_from_state.values():
1118+
self.parameters["properties"].pop(key, None)
1119+
if "required" in self.parameters and key in self.parameters["required"]:
1120+
self.parameters["required"].remove(key)
1121+
10721122
def to_dict(self) -> dict[str, Any]:
10731123
"""
10741124
Serializes the MCPTool to a dictionary.
10751125
10761126
The serialization preserves all information needed to recreate the tool,
1077-
including server connection parameters and timeout settings. Note that the
1078-
active connection is not maintained.
1127+
including server connection parameters, timeout settings, and state-mapping parameters.
1128+
Note that the active connection is not maintained.
10791129
10801130
:returns: Dictionary with serialized data in the format:
10811131
`{"type": fully_qualified_class_name, "data": {parameters}}`
@@ -1087,6 +1137,9 @@ def to_dict(self) -> dict[str, Any]:
10871137
"connection_timeout": self._connection_timeout,
10881138
"invocation_timeout": self._invocation_timeout,
10891139
"eager_connect": self._eager_connect,
1140+
"outputs_to_string": self._outputs_to_string,
1141+
"inputs_from_state": self._inputs_from_state,
1142+
"outputs_to_state": self._outputs_to_state,
10901143
}
10911144
return {
10921145
"type": generate_qualified_class_name(type(self)),
@@ -1099,8 +1152,8 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
10991152
Deserializes the MCPTool from a dictionary.
11001153
11011154
This method reconstructs an MCPTool instance from a serialized dictionary,
1102-
including recreating the server_info object. A new connection will be established
1103-
to the MCP server during initialization.
1155+
including recreating the server_info object and state-mapping parameters.
1156+
A new connection will be established to the MCP server during initialization.
11041157
11051158
:param data: Dictionary containing serialized tool data
11061159
:returns: A fully initialized MCPTool instance
@@ -1121,6 +1174,11 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
11211174
invocation_timeout = inner_data.get("invocation_timeout", 30)
11221175
eager_connect = inner_data.get("eager_connect", False) # because False is the default
11231176

1177+
# Handle state-mapping parameters
1178+
outputs_to_string = inner_data.get("outputs_to_string")
1179+
inputs_from_state = inner_data.get("inputs_from_state")
1180+
outputs_to_state = inner_data.get("outputs_to_state")
1181+
11241182
# Create a new MCPTool instance with the deserialized parameters
11251183
# This will establish a new connection to the MCP server
11261184
return cls(
@@ -1130,6 +1188,9 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
11301188
connection_timeout=connection_timeout,
11311189
invocation_timeout=invocation_timeout,
11321190
eager_connect=eager_connect,
1191+
outputs_to_string=outputs_to_string,
1192+
inputs_from_state=inputs_from_state,
1193+
outputs_to_state=outputs_to_state,
11331194
)
11341195

11351196
def close(self):

integrations/mcp/tests/test_mcp_tool.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,69 @@ def test_mcp_tool_serde(self, mcp_tool_cleanup):
133133

134134
assert isinstance(new_tool._server_info, InMemoryServerInfo)
135135

136+
def test_mcp_tool_state_mapping_parameters(self, mcp_tool_cleanup):
137+
"""Test that MCPTool correctly initializes with state-mapping parameters."""
138+
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
139+
140+
# Create tool with state-mapping parameters
141+
# Map state key "state_a" to tool parameter "a"
142+
tool = MCPTool(
143+
name="add",
144+
server_info=server_info,
145+
eager_connect=False,
146+
outputs_to_string={"source": "result", "handler": str},
147+
inputs_from_state={"state_a": "a"},
148+
outputs_to_state={"result": {"source": "output", "handler": str}},
149+
)
150+
mcp_tool_cleanup(tool)
151+
152+
# Verify the parameters are stored correctly
153+
assert tool._outputs_to_string == {"source": "result", "handler": str}
154+
assert tool._inputs_from_state == {"state_a": "a"}
155+
assert tool._outputs_to_state == {"result": {"source": "output", "handler": str}}
156+
157+
# Warm up the tool to trigger schema adjustment
158+
tool.warm_up()
159+
160+
# Verify that "a" was removed from parameters since it's in inputs_from_state
161+
assert "a" not in tool.parameters["properties"]
162+
assert "a" not in tool.parameters.get("required", [])
163+
# Verify that "b" is still present (not removed)
164+
assert "b" in tool.parameters["properties"]
165+
assert "b" in tool.parameters["required"]
166+
167+
def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):
168+
"""Test serialization and deserialization of MCPTool with state-mapping parameters."""
169+
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
170+
171+
# Create tool with state-mapping parameters
172+
tool = MCPTool(
173+
name="add",
174+
server_info=server_info,
175+
eager_connect=False,
176+
outputs_to_string={"source": "result"},
177+
inputs_from_state={"filter": "query_filter"},
178+
outputs_to_state={"result": {"source": "output"}},
179+
)
180+
mcp_tool_cleanup(tool)
181+
182+
# Test serialization (to_dict)
183+
tool_dict = tool.to_dict()
184+
185+
# Verify state-mapping parameters are serialized
186+
assert tool_dict["data"]["outputs_to_string"] == {"source": "result"}
187+
assert tool_dict["data"]["inputs_from_state"] == {"filter": "query_filter"}
188+
assert tool_dict["data"]["outputs_to_state"] == {"result": {"source": "output"}}
189+
190+
# Test deserialization (from_dict)
191+
new_tool = MCPTool.from_dict(tool_dict)
192+
mcp_tool_cleanup(new_tool)
193+
194+
# Verify state-mapping parameters are restored
195+
assert new_tool._outputs_to_string == {"source": "result"}
196+
assert new_tool._inputs_from_state == {"filter": "query_filter"}
197+
assert new_tool._outputs_to_state == {"result": {"source": "output"}}
198+
136199
@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
137200
@pytest.mark.integration
138201
def test_pipeline_warmup_with_mcp_tool(self):
@@ -155,3 +218,54 @@ def test_pipeline_warmup_with_mcp_tool(self):
155218
finally:
156219
if tool:
157220
tool.close()
221+
222+
@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
223+
@pytest.mark.integration
224+
def test_agent_with_state_mapping(self):
225+
"""Test Agent with MCPTool using state-mapping to inject location from state."""
226+
227+
# Create MCPTool with state-mapping that injects home_city from state as timezone parameter
228+
server_info = StdioServerInfo(command="uvx", args=["mcp-server-time", "--local-timezone=Europe/Berlin"])
229+
tool = MCPTool(
230+
name="get_current_time",
231+
server_info=server_info,
232+
inputs_from_state={"home_city": "timezone"}, # Inject home_city from state as timezone
233+
)
234+
235+
try:
236+
# Build Agent with state schema that includes home_city
237+
agent = Agent(
238+
chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
239+
tools=[tool],
240+
state_schema={"home_city": {"type": str}},
241+
)
242+
pipeline = Pipeline()
243+
pipeline.add_component("agent", agent)
244+
245+
# Ask for time without mentioning the location - it should use home_city from state
246+
user_input_msg = ChatMessage.from_user(text="What time is it at home?")
247+
result = pipeline.run(
248+
{
249+
"agent": {
250+
"messages": [user_input_msg],
251+
"home_city": "America/New_York", # Inject New York as home city
252+
}
253+
}
254+
)
255+
256+
# Verify the agent got the time for New York
257+
final_message = result["agent"]["messages"][-1].text
258+
259+
# The response should mention time
260+
assert any(keyword in final_message.lower() for keyword in ["time", "o'clock", "am", "pm"]), (
261+
f"Expected time in response: {final_message}"
262+
)
263+
264+
# Verify the response mentions New York (proving state-mapping injected the timezone)
265+
# The user never mentioned location, but timezone info should appear in the response
266+
assert any(keyword in final_message for keyword in ["New York", "New_York"]), (
267+
f"Expected timezone reference (New York) to confirm state-mapping: {final_message}"
268+
)
269+
finally:
270+
if tool:
271+
tool.close()

0 commit comments

Comments
 (0)