diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 0039cea540..704f402e4b 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -624,13 +624,6 @@ async def close(self, close_connection_pool: Optional[bool] = None) -> None: """ await self.aclose(close_connection_pool) - async def _send_command_parse_response(self, conn, command_name, *args, **options): - """ - Send a command and parse the response - """ - await conn.send_command(*args) - return await self.parse_response(conn, command_name, **options) - async def _disconnect_raise(self, conn: Connection, error: Exception): """ Close the connection and raise an exception @@ -655,10 +648,12 @@ async def execute_command(self, *args, **options): if self.single_connection_client: await self._single_conn_lock.acquire() try: + await conn.retry.call_with_retry( + lambda: conn.send_command(*args), + lambda error: self._disconnect_raise(conn, error), + ) return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), + lambda: self.parse_response(conn, command_name, **options), lambda error: self._disconnect_raise(conn, error), ) finally: @@ -1383,10 +1378,12 @@ async def immediate_execute_command(self, *args, **options): conn = await self.connection_pool.get_connection() self.connection = conn + await conn.retry.call_with_retry( + lambda: conn.send_command(*args), + lambda error: self._disconnect_reset_raise(conn, error), + ) return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), + lambda: self.parse_response(conn, command_name, **options), lambda error: self._disconnect_reset_raise(conn, error), ) diff --git a/redis/client.py b/redis/client.py index 2c4a1fadff..d16334e1fa 100755 --- a/redis/client.py +++ b/redis/client.py @@ -590,13 +590,6 @@ def close(self) -> None: if self.auto_close_connection_pool: self.connection_pool.disconnect() - def _send_command_parse_response(self, conn, command_name, *args, **options): - """ - Send a command and parse the response - """ - conn.send_command(*args, **options) - return self.parse_response(conn, command_name, **options) - def _disconnect_raise(self, conn, error): """ Close the connection and raise an exception @@ -623,10 +616,12 @@ def _execute_command(self, *args, **options): if self._single_connection_client: self.single_connection_lock.acquire() try: + conn.retry.call_with_retry( + lambda: conn.send_command(*args, **options), + lambda error: self._disconnect_raise(conn, error), + ) return conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), + lambda: self.parse_response(conn, command_name, **options), lambda error: self._disconnect_raise(conn, error), ) finally: @@ -1408,10 +1403,12 @@ def immediate_execute_command(self, *args, **options): conn = self.connection_pool.get_connection() self.connection = conn + conn.retry.call_with_retry( + lambda: conn.send_command(*args, **options), + lambda error: self._disconnect_reset_raise(conn, error), + ) return conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), + lambda: self.parse_response(conn, command_name, **options), lambda error: self._disconnect_reset_raise(conn, error), ) diff --git a/tests/conftest.py b/tests/conftest.py index 7eaccb1acb..89f16046c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,13 @@ CacheKey, EvictionPolicy, ) -from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url +from redis.connection import ( + Connection, + ConnectionInterface, + SSLConnection, + parse_url, + ConnectionPool, +) from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry @@ -582,6 +588,12 @@ def mock_connection() -> ConnectionInterface: return mock_connection +@pytest.fixture() +def mock_pool() -> ConnectionPool: + mock_pool = Mock(spec=ConnectionPool) + return mock_pool + + @pytest.fixture() def cache_key(request) -> CacheKey: command = request.param.get("command") diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 340d146ea3..fd7a816cd7 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,6 +1,7 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager from typing import Union +from unittest.mock import Mock import pytest import pytest_asyncio @@ -8,7 +9,7 @@ from packaging.version import Version from redis.asyncio import Sentinel from redis.asyncio.client import Monitor -from redis.asyncio.connection import Connection, parse_url +from redis.asyncio.connection import Connection, parse_url, ConnectionPool from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.credentials import CredentialProvider @@ -219,6 +220,18 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): yield mocked +@pytest_asyncio.fixture() +def mock_connection() -> Connection: + mock_connection = Mock(spec=Connection) + return mock_connection + + +@pytest_asyncio.fixture() +def mock_pool() -> ConnectionPool: + mock_pool = Mock(spec=ConnectionPool) + return mock_pool + + @pytest_asyncio.fixture() async def credential_provider(request) -> CredentialProvider: return get_credential_provider(request) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 38764d30cd..3de70737e3 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -2,7 +2,7 @@ import socket import types from errno import ECONNREFUSED -from unittest.mock import patch +from unittest.mock import patch, AsyncMock import pytest import redis @@ -20,7 +20,7 @@ parse_url, ) from redis.asyncio.retry import Retry -from redis.backoff import NoBackoff +from redis.backoff import NoBackoff, ExponentialBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt @@ -91,7 +91,7 @@ async def get_conn(): await asyncio.gather(r.set("a", "b"), r.set("c", "d")) assert init_call_count == 1 - assert command_call_count == 2 + assert command_call_count == 4 r.connection = None # it was a Mock await r.aclose() @@ -315,6 +315,59 @@ async def get_redis_connection(): await r1.aclose() +@pytest.mark.onlynoncluster +async def test_client_do_not_retry_write_on_read_failure(mock_connection, mock_pool): + mock_connection.send_command.return_value = True + mock_connection.read_response.side_effect = [ + ConnectionError, + ConnectionError, + b"OK", + ] + mock_connection.retry = Retry(ExponentialBackoff(), 3) + mock_connection.retry_on_error = (ConnectionError,) + mock_pool.get_connection = AsyncMock(return_value=mock_connection) + mock_pool.connection_kwargs = {} + + r = Redis( + connection_pool=mock_pool, + retry=Retry(ExponentialBackoff(), 3), + single_connection_client=True, + ) + await r.set("key", "value") + + # If read from socket fails, writes won't be executed. + mock_connection.send_command.assert_called_once_with("SET", "key", "value") + mock_connection.read_response.call_count = 3 + + +@pytest.mark.onlynoncluster +async def test_pipeline_immediate_do_not_retry_write_on_read_failure( + mock_connection, mock_pool +): + mock_connection.send_command.return_value = True + mock_connection.read_response.side_effect = [ + ConnectionError, + ConnectionError, + b"OK", + ] + mock_connection.retry = Retry(ExponentialBackoff(), 3) + mock_connection.retry_on_error = (ConnectionError,) + mock_pool.get_connection = AsyncMock(return_value=mock_connection) + mock_pool.connection_kwargs = {} + + r = Redis( + connection_pool=mock_pool, + retry=Retry(ExponentialBackoff(), 3), + single_connection_client=True, + ) + pipe = r.pipeline(transaction=False) + await pipe.immediate_execute_command("SET", "key", "value") + + # If read from socket fails, writes won't be executed. + mock_connection.send_command.assert_called_once_with("SET", "key", "value") + mock_connection.read_response.call_count = 3 + + async def test_close_is_aclose(request): """Verify close() calls aclose()""" calls = 0 diff --git a/tests/test_connection.py b/tests/test_connection.py index 9664146ce5..3260a2ae34 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -10,10 +10,11 @@ from unittest.mock import call, patch import pytest + import redis from redis import ConnectionPool, Redis from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser -from redis.backoff import NoBackoff +from redis.backoff import NoBackoff, ExponentialBackoff from redis.cache import ( CacheConfig, CacheEntry, @@ -249,6 +250,51 @@ def get_redis_connection(): r1.close() +@pytest.mark.onlynoncluster +def test_client_do_not_retry_write_on_read_failure(mock_connection, mock_pool): + mock_connection.send_command.return_value = True + mock_connection.read_response.side_effect = [ + ConnectionError, + ConnectionError, + b"OK", + ] + mock_connection.retry = Retry(ExponentialBackoff(), 3) + mock_connection.retry_on_error = (ConnectionError,) + mock_pool.get_connection.return_value = mock_connection + mock_pool.connection_kwargs = {} + + r = Redis(connection_pool=mock_pool, retry=Retry(ExponentialBackoff(), 3)) + r.set("key", "value") + + # If read from socket fails, writes won't be executed. + mock_connection.send_command.assert_called_once_with("SET", "key", "value") + mock_connection.read_response.call_count = 3 + + +@pytest.mark.onlynoncluster +def test_pipeline_immediate_do_not_retry_write_on_read_failure( + mock_connection, mock_pool +): + mock_connection.send_command.return_value = True + mock_connection.read_response.side_effect = [ + ConnectionError, + ConnectionError, + b"OK", + ] + mock_connection.retry = Retry(ExponentialBackoff(), 3) + mock_connection.retry_on_error = (ConnectionError,) + mock_pool.get_connection.return_value = mock_connection + mock_pool.connection_kwargs = {} + + r = Redis(connection_pool=mock_pool, retry=Retry(ExponentialBackoff(), 3)) + pipe = r.pipeline(transaction=False) + pipe.immediate_execute_command("SET", "key", "value") + + # If read from socket fails, writes won't be executed. + mock_connection.send_command.assert_called_once_with("SET", "key", "value") + mock_connection.read_response.call_count = 3 + + @pytest.mark.skipif(sys.version_info == (3, 9), reason="Flacky test on Python 3.9") @pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) def test_redis_connection_pool(request, from_url):