diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4e82e5448f..492e36c85d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1070,12 +1070,13 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: ret = False for cmd in commands: try: - cmd.result = await self.parse_response( + result = await self.parse_response( connection, cmd.args[0], **cmd.kwargs ) except Exception as e: - cmd.result = e + result = e ret = True + cmd.set_node_result(self.name, result) # Release connection self._free.append(connection) @@ -1514,7 +1515,7 @@ async def _execute( allow_redirections: bool = True, ) -> List[Any]: todo = [ - cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception) + cmd for cmd in stack if not cmd.unwrap_result() or cmd.get_all_exceptions() ] nodes = {} @@ -1530,12 +1531,11 @@ async def _execute( raise RedisClusterException( f"No targets were found to execute {cmd.args} command on" ) - if len(target_nodes) > 1: - raise RedisClusterException(f"Too many targets for command {cmd.args}") - node = target_nodes[0] - if node.name not in nodes: - nodes[node.name] = (node, []) - nodes[node.name][1].append(cmd) + cmd.target_nodes = target_nodes + for node in target_nodes: + if node.name not in nodes: + nodes[node.name] = (node, []) + nodes[node.name][1].append(cmd) errors = await asyncio.gather( *( @@ -1548,25 +1548,38 @@ async def _execute( if allow_redirections: # send each errored command individually for cmd in todo: - if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): - try: - cmd.result = await client.execute_command( - *cmd.args, **cmd.kwargs - ) - except Exception as e: - cmd.result = e + for name, exc in cmd.get_all_exceptions(): + if isinstance(exc, (TryAgainError, MovedError, AskError)): + try: + result = await client.execute_command( + *cmd.args, **cmd.kwargs + ) + except Exception as e: + result = e + + if isinstance(result, dict): + cmd.result = result + else: + cmd.set_node_result(name, result) + + # We have already retried the command on all nodes. + break if raise_on_error: for cmd in todo: - result = cmd.result - if isinstance(result, Exception): + name_exc = cmd.get_first_exception() + if name_exc: + name, exc = name_exc command = " ".join(map(safe_str, cmd.args)) + # Note: this will only raise the first exception, but that is + # consistent with RedisCluster.execute_command. msg = ( f"Command # {cmd.position + 1} ({command}) of pipeline " - f"caused error: {result.args}" + f"caused error on node {name}: " + f"{exc.args}" ) - result.args = (msg,) + result.args[1:] - raise result + exc.args = (msg,) + exc.args[1:] + raise exc default_node = nodes.get(client.get_default_node().name) if default_node is not None: @@ -1574,14 +1587,19 @@ async def _execute( # to replace it. # Note: when the error is raised we'll reset the default node in the # caller function. + has_exc = False for cmd in default_node[1]: # Check if it has a command that failed with a relevant # exception - if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY: - client.replace_default_node() + for name, exc in cmd.get_all_exceptions(): + if type(exc) in self.__class__.ERRORS_ALLOW_RETRY: + client.replace_default_node() + has_exc = True + break + if has_exc: break - return [cmd.result for cmd in stack] + return [cmd.unwrap_result() for cmd in stack] def _split_command_across_slots( self, command: str, *keys: KeyT @@ -1620,7 +1638,28 @@ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: self.args = args self.kwargs = kwargs self.position = position - self.result: Union[Any, Exception] = None + self.result: Dict[str, Union[Any, Exception]] = {} + self.target_nodes = None + + def set_node_result(self, node_name: str, result: Union[Any, Exception]): + self.result[node_name] = result + + def unwrap_result( + self, + ) -> Optional[Union[Any, Exception, Dict[str, Union[Any, Exception]]]]: + if len(self.result) == 0: + return None + if len(self.result) == 1: + return next(iter(self.result.values())) + return self.result + + def get_first_exception(self) -> Optional[Tuple[str, Exception]]: + return next( + ((n, r) for n, r in self.result.items() if isinstance(r, Exception)), None + ) + + def get_all_exceptions(self) -> List[Tuple[str, Exception]]: + return [(n, r) for n, r in self.result.items() if isinstance(r, Exception)] def __repr__(self) -> str: return f"[{self.position}] {self.args} ({self.kwargs})"