44
55import ast
66import contextlib
7- from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Set , Union , get_args , get_origin
7+ from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Set , TypedDict , Union , get_args , get_origin
88
99from jinja2 import Environment , TemplateSyntaxError , meta
1010from jinja2 .nativetypes import NativeEnvironment
@@ -24,6 +24,13 @@ class RouteConditionException(Exception):
2424 """Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""
2525
2626
27+ class Route (TypedDict ):
28+ condition : str
29+ output : Union [str , List [str ]]
30+ output_name : Union [str , List [str ]]
31+ output_type : Union [type , List [type ]]
32+
33+
2734@component
2835class ConditionalRouter :
2936 """
@@ -108,7 +115,7 @@ class ConditionalRouter:
108115
109116 def __init__ ( # pylint: disable=too-many-positional-arguments
110117 self ,
111- routes : List [Dict ],
118+ routes : List [Route ],
112119 custom_filters : Optional [Dict [str , Callable ]] = None ,
113120 unsafe : bool = False ,
114121 validate_output_type : bool = False ,
@@ -179,7 +186,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
179186 - Some variables are only needed for specific routing conditions
180187 - You're building flexible pipelines where not all inputs are guaranteed to be present
181188 """
182- self .routes : List [dict ] = routes
189+ self .routes : List [Route ] = routes
183190 self .custom_filters = custom_filters or {}
184191 self ._unsafe = unsafe
185192 self ._validate_output_type = validate_output_type
@@ -199,7 +206,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
199206 self ._validate_routes (routes )
200207 # Inspect the routes to determine input and output types.
201208 input_types : Set [str ] = set () # let's just store the name, type will always be Any
202- output_types : Dict [str , str ] = {}
209+ output_types : Dict [str , Union [ type , List [ type ]] ] = {}
203210
204211 for route in routes :
205212 # extract inputs
@@ -248,8 +255,12 @@ def to_dict(self) -> Dict[str, Any]:
248255 """
249256 serialized_routes = []
250257 for route in self .routes :
251- # output_type needs to be serialized to a string
252- serialized_routes .append ({** route , "output_type" : serialize_type (route ["output_type" ])})
258+ serialized_output_type = (
259+ [serialize_type (t ) for t in route ["output_type" ]]
260+ if isinstance (route ["output_type" ], list )
261+ else serialize_type (route ["output_type" ])
262+ )
263+ serialized_routes .append ({** route , "output_type" : serialized_output_type })
253264 se_filters = {name : serialize_callable (filter_func ) for name , filter_func in self .custom_filters .items ()}
254265 return default_to_dict (
255266 self ,
@@ -274,7 +285,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
274285 routes = init_params .get ("routes" )
275286 for route in routes :
276287 # output_type needs to be deserialized from a string to a type
277- route ["output_type" ] = deserialize_type (route ["output_type" ])
288+ if isinstance (route ["output_type" ], list ):
289+ route ["output_type" ] = [deserialize_type (t ) for t in route ["output_type" ]]
290+ else :
291+ route ["output_type" ] = deserialize_type (route ["output_type" ])
278292
279293 # Since the custom_filters are typed as optional in the init signature, we catch the
280294 # case where they are not present in the serialized data and set them to an empty dict.
@@ -355,7 +369,7 @@ def run(self, **kwargs):
355369
356370 raise NoRouteSelectedException (f"No route fired. Routes: { self .routes } " )
357371
358- def _validate_routes (self , routes : List [Dict ]):
372+ def _validate_routes (self , routes : List [Route ]):
359373 """
360374 Validates a list of routes.
361375
@@ -401,8 +415,7 @@ def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]
401415 """
402416 variables = set ()
403417 for template in templates :
404- ast = env .parse (template )
405- variables .update (meta .find_undeclared_variables (ast ))
418+ variables .update (meta .find_undeclared_variables (env .parse (template )))
406419 return variables
407420
408421 def _validate_template (self , env : Environment , template_text : str ):
0 commit comments