Skip to content

Commit c9e8cda

Browse files
sjrljulian-risch
authored andcommitted
fix: Fix serialization and deserialization of ConditionalRouter with multiple outputs (#9490)
* Fix sede of ConditionalRouter with multiple outputs * Add reno
1 parent b0f3d19 commit c9e8cda

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

haystack/components/routers/conditional_router.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import ast
66
import contextlib
7-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin
7+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, TypedDict, Union, get_args, get_origin
88

99
from jinja2 import Environment, TemplateSyntaxError, meta
1010
from jinja2.nativetypes import NativeEnvironment
@@ -24,6 +24,13 @@ class RouteConditionException(Exception):
2424
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""
2525

2626

27+
class Route(TypedDict):
28+
condition: str
29+
output: Union[str, List[str]]
30+
output_name: Union[str, List[str]]
31+
output_type: Union[type, List[type]]
32+
33+
2734
@component
2835
class ConditionalRouter:
2936
"""
@@ -108,7 +115,7 @@ class ConditionalRouter:
108115

109116
def __init__( # pylint: disable=too-many-positional-arguments
110117
self,
111-
routes: List[Dict],
118+
routes: List[Route],
112119
custom_filters: Optional[Dict[str, Callable]] = None,
113120
unsafe: bool = False,
114121
validate_output_type: bool = False,
@@ -179,7 +186,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
179186
- Some variables are only needed for specific routing conditions
180187
- You're building flexible pipelines where not all inputs are guaranteed to be present
181188
"""
182-
self.routes: List[dict] = routes
189+
self.routes: List[Route] = routes
183190
self.custom_filters = custom_filters or {}
184191
self._unsafe = unsafe
185192
self._validate_output_type = validate_output_type
@@ -199,7 +206,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
199206
self._validate_routes(routes)
200207
# Inspect the routes to determine input and output types.
201208
input_types: Set[str] = set() # let's just store the name, type will always be Any
202-
output_types: Dict[str, str] = {}
209+
output_types: Dict[str, Union[type, List[type]]] = {}
203210

204211
for route in routes:
205212
# extract inputs
@@ -248,8 +255,12 @@ def to_dict(self) -> Dict[str, Any]:
248255
"""
249256
serialized_routes = []
250257
for route in self.routes:
251-
# output_type needs to be serialized to a string
252-
serialized_routes.append({**route, "output_type": serialize_type(route["output_type"])})
258+
serialized_output_type = (
259+
[serialize_type(t) for t in route["output_type"]]
260+
if isinstance(route["output_type"], list)
261+
else serialize_type(route["output_type"])
262+
)
263+
serialized_routes.append({**route, "output_type": serialized_output_type})
253264
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
254265
return default_to_dict(
255266
self,
@@ -274,7 +285,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
274285
routes = init_params.get("routes")
275286
for route in routes:
276287
# output_type needs to be deserialized from a string to a type
277-
route["output_type"] = deserialize_type(route["output_type"])
288+
if isinstance(route["output_type"], list):
289+
route["output_type"] = [deserialize_type(t) for t in route["output_type"]]
290+
else:
291+
route["output_type"] = deserialize_type(route["output_type"])
278292

279293
# Since the custom_filters are typed as optional in the init signature, we catch the
280294
# case where they are not present in the serialized data and set them to an empty dict.
@@ -355,7 +369,7 @@ def run(self, **kwargs):
355369

356370
raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
357371

358-
def _validate_routes(self, routes: List[Dict]):
372+
def _validate_routes(self, routes: List[Route]):
359373
"""
360374
Validates a list of routes.
361375
@@ -401,8 +415,7 @@ def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]
401415
"""
402416
variables = set()
403417
for template in templates:
404-
ast = env.parse(template)
405-
variables.update(meta.find_undeclared_variables(ast))
418+
variables.update(meta.find_undeclared_variables(env.parse(template)))
406419
return variables
407420

408421
def _validate_template(self, env: Environment, template_text: str):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
In ConditionalRouter fixed the to_dict and from_dict methods to properly handle the case when output_type is a List of types or a List of strings. This occurs when a user specifies a route in ConditionalRouter to have multiple outputs.

test/components/routers/test_conditional_router.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,24 @@ def test_multiple_outputs_validation(self):
615615
}
616616
]
617617
)
618+
619+
def test_sede_multiple_outputs(self):
620+
routes = [
621+
{
622+
"condition": "{{phone_num|get_area_code == 123}}",
623+
"output": ["{{phone_num}}", "{{phone_num|get_area_code}}"],
624+
"output_name": ["phone_num", "area_code"],
625+
"output_type": [str, int],
626+
},
627+
{
628+
"condition": "{{phone_num|get_area_code != 123}}",
629+
"output": ["{{phone_num}}", "{{phone_num|get_area_code}}"],
630+
"output_name": ["phone_num", "area_code"],
631+
"output_type": [str, int],
632+
},
633+
]
634+
635+
router = ConditionalRouter(routes, custom_filters={"get_area_code": custom_filter_to_sede})
636+
reloaded_router = ConditionalRouter.from_dict(router.to_dict())
637+
assert reloaded_router.custom_filters == router.custom_filters
638+
assert reloaded_router.routes == router.routes

0 commit comments

Comments
 (0)