diff --git a/redis/background.py b/redis/background.py
new file mode 100644
index 0000000000..6466649859
--- /dev/null
+++ b/redis/background.py
@@ -0,0 +1,89 @@
+import asyncio
+import threading
+from typing import Callable
+
+class BackgroundScheduler:
+    """
+    Schedules execution of background tasks either in a separate thread or in the running event loop.
+    """
+    def __init__(self):
+        self._next_timer = None
+
+    def __del__(self):
+        if self._next_timer:
+            self._next_timer.cancel()
+
+    def run_once(self, delay: float, callback: Callable, *args):
+        """
+        Runs a callable task once after a given delay in seconds.
+        """
+        # Run the loop in a separate thread to avoid blocking the main thread.
+        loop = asyncio.new_event_loop()
+        thread = threading.Thread(
+            target=_start_event_loop_in_thread,
+            args=(loop, self._call_later, delay, callback, *args),
+            daemon=True
+        )
+        thread.start()
+
+    def run_recurring(
+        self,
+        interval: float,
+        callback: Callable,
+        *args
+    ):
+        """
+        Runs a recurring callable task with the given interval in seconds.
+        """
+        # Run the loop in a separate thread to avoid blocking the main thread.
+        loop = asyncio.new_event_loop()
+
+        thread = threading.Thread(
+            target=_start_event_loop_in_thread,
+            args=(loop, self._call_later_recurring, interval, callback, *args),
+            daemon=True
+        )
+        thread.start()
+
+    def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args):
+        self._next_timer = loop.call_later(delay, callback, *args)
+
+    def _call_later_recurring(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        interval: float,
+        callback: Callable,
+        *args
+    ):
+        self._call_later(
+            loop, interval, self._execute_recurring, loop, interval, callback, *args
+        )
+
+    def _execute_recurring(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        interval: float,
+        callback: Callable,
+        *args
+    ):
+        """
+        Executes a recurring callable task with the given interval in seconds.
+        """
+        callback(*args)
+
+        self._call_later(
+            loop, interval, self._execute_recurring, loop, interval, callback, *args
+        )
+
+
+def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args):
+    """
+    Starts the event loop in a thread and schedules the callback as soon as the event loop is ready.
+    Needed to be able to schedule tasks using loop.call_later.
+
+    :param event_loop: Event loop to run in the dedicated thread.
+    :return:
+    """
+    asyncio.set_event_loop(event_loop)
+    event_loop.call_soon(call_soon_cb, event_loop, *args)
+    event_loop.run_forever()
\ No newline at end of file
diff --git a/redis/client.py b/redis/client.py
index 0e05b6f542..060fc29493 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -603,7 +603,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options):
         conn.send_command(*args, **options)
         return self.parse_response(conn, command_name, **options)
 
-    def _close_connection(self, conn) -> None:
+    def _close_connection(self, conn, error, *args) -> None:
         """
         Close the connection before retrying.
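Not part of the diff: a minimal sketch of how the BackgroundScheduler added above is meant to be used (redis/multidb/client.py later in this PR relies on it for periodic health checks). It assumes only the API shown in background.py; the callback and the delay/interval values are illustrative.

    from time import sleep

    from redis.background import BackgroundScheduler

    def report_health():
        # Stand-in for the periodic work (e.g. database health checks).
        print("running periodic checks")

    scheduler = BackgroundScheduler()
    scheduler.run_once(1.0, print, "one-shot task")  # runs once after 1 second, off the main thread
    scheduler.run_recurring(5.0, report_health)      # runs every 5 seconds, off the main thread

    sleep(12)  # keep the main thread alive; the scheduler threads are daemons
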
@@ -633,7 +633,7 @@ def _execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda _: self._close_connection(conn), + lambda error: self._close_connection(conn, error, *args), ) finally: if self._single_connection_client: diff --git a/redis/data_structure.py b/redis/data_structure.py new file mode 100644 index 0000000000..5b0df7f017 --- /dev/null +++ b/redis/data_structure.py @@ -0,0 +1,75 @@ +import threading +from typing import List, Any, TypeVar, Generic, Union + +from redis.typing import Number + +T = TypeVar('T') + +class WeightedList(Generic[T]): + """ + Thread-safe weighted list. + """ + def __init__(self): + self._items: List[tuple[Any, Number]] = [] + self._lock = threading.RLock() + + def add(self, item: Any, weight: float) -> None: + """Add item with weight, maintaining sorted order""" + with self._lock: + # Find insertion point using binary search + left, right = 0, len(self._items) + while left < right: + mid = (left + right) // 2 + if self._items[mid][1] < weight: + right = mid + else: + left = mid + 1 + + self._items.insert(left, (item, weight)) + + def remove(self, item): + """Remove first occurrence of item""" + with self._lock: + for i, (stored_item, weight) in enumerate(self._items): + if stored_item == item: + self._items.pop(i) + return weight + raise ValueError("Item not found") + + def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Number]]: + """Get all items within weight range""" + with self._lock: + result = [] + for item, weight in self._items: + if min_weight <= weight <= max_weight: + result.append((item, weight)) + return result + + def get_top_n(self, n: int) -> List[tuple[Any, Number]]: + """Get top N the highest weighted items""" + with self._lock: + return [(item, weight) for item, weight in self._items[:n]] + + def update_weight(self, item, new_weight: float): + with self._lock: + """Update weight of an item""" + old_weight = self.remove(item) + self.add(item, new_weight) + return old_weight + + def __iter__(self): + """Iterate in descending weight order""" + with self._lock: + items_copy = self._items.copy() # Create snapshot as lock released after each 'yield' + + for item, weight in items_copy: + yield item, weight + + def __len__(self): + with self._lock: + return len(self._items) + + def __getitem__(self, index) -> tuple[Any, Number]: + with self._lock: + item, weight = self._items[index] + return item, weight \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index b86c66b082..fdb42a04d5 100644 --- a/redis/event.py +++ b/redis/event.py @@ -2,7 +2,7 @@ import threading from abc import ABC, abstractmethod from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Type from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider @@ -42,6 +42,11 @@ def dispatch(self, event: object): async def dispatch_async(self, event: object): pass + @abstractmethod + def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]): + """Register additional listeners.""" + pass + class EventException(Exception): """ @@ -56,11 +61,14 @@ def __init__(self, exception: Exception, event: object): class EventDispatcher(EventDispatcherInterface): # TODO: Make dispatcher to accept external mappings. 
- def __init__(self): + def __init__( + self, + event_listeners: Optional[Dict[Type[object], List[EventListenerInterface]]] = None, + ): """ - Mapping should be extended for any new events or listeners to be added. + Dispatcher that dispatches events to listeners associated with given event. """ - self._event_listeners_mapping = { + self._event_listeners_mapping: Dict[Type[object], List[EventListenerInterface]]= { AfterConnectionReleasedEvent: [ ReAuthConnectionListener(), ], @@ -77,17 +85,35 @@ def __init__(self): ], } + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + + if event_listeners: + self.register_listeners(event_listeners) + def dispatch(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + with self._lock: + listeners = self._event_listeners_mapping.get(type(event), []) - for listener in listeners: - listener.listen(event) + for listener in listeners: + listener.listen(event) async def dispatch_async(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + with self._async_lock: + listeners = self._event_listeners_mapping.get(type(event), []) - for listener in listeners: - await listener.listen(event) + for listener in listeners: + await listener.listen(event) + + def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): + with self._lock: + for event_type in event_listeners: + if event_type in self._event_listeners_mapping: + self._event_listeners_mapping[event_type] = list( + set(self._event_listeners_mapping[event_type] + event_listeners[event_type]) + ) + else: + self._event_listeners_mapping[event_type] = event_listeners[event_type] class AfterConnectionReleasedEvent: @@ -225,6 +251,31 @@ def nodes(self) -> dict: def credential_provider(self) -> Union[CredentialProvider, None]: return self._credential_provider +class OnCommandFailEvent: + """ + Event fired whenever a command fails during the execution. 
+ """ + def __init__( + self, + command: tuple, + exception: Exception, + client, + ): + self._command = command + self._exception = exception + self._client = client + + @property + def command(self) -> tuple: + return self._command + + @property + def exception(self) -> Exception: + return self._exception + + @property + def client(self): + return self._client class ReAuthConnectionListener(EventListenerInterface): """ diff --git a/redis/multidb/__init__.py b/redis/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py new file mode 100644 index 0000000000..9211173c83 --- /dev/null +++ b/redis/multidb/circuit.py @@ -0,0 +1,108 @@ +from abc import abstractmethod, ABC +from enum import Enum +from typing import Callable + +import pybreaker + +class State(Enum): + CLOSED = 'closed' + OPEN = 'open' + HALF_OPEN = 'half-open' + +class CircuitBreaker(ABC): + @property + @abstractmethod + def grace_period(self) -> float: + """The grace period in seconds when the circle should be kept open.""" + pass + + @grace_period.setter + @abstractmethod + def grace_period(self, grace_period: float): + """Set the grace period in seconds.""" + + @property + @abstractmethod + def state(self) -> State: + """The current state of the circuit.""" + pass + + @state.setter + @abstractmethod + def state(self, state: State): + """Set current state of the circuit.""" + pass + + @property + @abstractmethod + def database(self): + """Database associated with this circuit.""" + pass + + @database.setter + @abstractmethod + def database(self, database): + """Set database associated with this circuit.""" + pass + + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + +class PBListener(pybreaker.CircuitBreakerListener): + def __init__( + self, + cb: Callable[[CircuitBreaker, State, State], None], + database, + ): + """Wrapper for callback to be compatible with pybreaker implementation.""" + self._cb = cb + self._database = database + + def state_change(self, cb, old_state, new_state): + cb = PBCircuitBreakerAdapter(cb) + cb.database = self._database + old_state = State(value=old_state.name) + new_state = State(value=new_state.name) + self._cb(cb, old_state, new_state) + + +class PBCircuitBreakerAdapter(CircuitBreaker): + def __init__(self, cb: pybreaker.CircuitBreaker): + """Adapter for pybreaker CircuitBreaker.""" + self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + self._database = None + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + listener = PBListener(cb, self.database) + self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py new file mode 100644 index 0000000000..78ce039868 --- /dev/null +++ b/redis/multidb/client.py @@ -0,0 
+1,221 @@ +import threading +import socket + +from redis.background import BackgroundScheduler +from redis.exceptions import ConnectionError, TimeoutError +from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck + + +class MultiDBClient(RedisModuleCommands, CoreCommands, SentinelCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. + """ + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.health_checks + self._health_check_interval = config.health_check_interval + self._failure_detectors = config.failure_detectors + self._failover_strategy = config.failover_strategy + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + failover_strategy=self._failover_strategy, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self._initialized = False + self._hc_lock = threading.RLock() + self._bg_scheduler = BackgroundScheduler() + + def _initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + + # Initial databases check to define initial state + self._check_databases_health() + + # Starts recurring health checks on the background. + self._bg_scheduler.run_recurring( + self._health_check_interval, + self._check_databases_health, + ) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. + database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + database.state = DBState.ACTIVE + self._command_executor.active_database = database + is_active_db_found = True + elif database.circuit.state == CBState.CLOSED and is_active_db_found: + database.state = DBState.PASSIVE + else: + database.state = DBState.DISCONNECTED + + if not is_active_db_found: + raise NoValidDatabaseException('Initial connection failed - no active database found') + + self._initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + def set_active_database(self, database: AbstractDatabase) -> None: + """ + Promote one of the existing databases to become an active. 
+ """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + highest_weighted_db.state = DBState.PASSIVE + database.state = DBState.ACTIVE + self._command_executor.active_database = database + return + + raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + + def add_database(self, database: AbstractDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError('Given database already exists') + + self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + self._change_active_database(database, highest_weighted_db) + + def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): + if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + new_database.state = DBState.ACTIVE + self._command_executor.active_database = new_database + highest_weight_database.state = DBState.PASSIVE + + def remove_database(self, database: Database): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + highest_weighted_db.state = DBState.ACTIVE + self._command_executor.active_database = highest_weighted_db + + def update_database_weight(self, database: AbstractDatabase, weight: float): + """ + Updates a database from the database list. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: FailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + with self._hc_lock: + self._health_checks.append(healthcheck) + + def execute_command(self, *args, **options): + """ + Executes a single command and return its result. + """ + if not self._initialized: + self._initialize() + + return self._command_executor.execute_command(*args, **options) + + def _check_db_health(self, database: AbstractDatabase) -> None: + """ + Runs health checks on the given database until first failure. 
+ """ + is_healthy = True + + with self._hc_lock: + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break + + try: + is_healthy = health_check.check_health(database) + + if not is_healthy and database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + except (ConnectionError, TimeoutError, socket.timeout): + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + is_healthy = False + + + def _check_databases_health(self): + """ + Runs health checks as recurring task. + """ + for database, _ in self._databases: + self._check_db_health(database) + + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + if new_state == CBState.HALF_OPEN: + self._check_db_health(circuit.database) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN \ No newline at end of file diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py new file mode 100644 index 0000000000..60dbeca36b --- /dev/null +++ b/redis/multidb/command_executor.py @@ -0,0 +1,164 @@ +import socket +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from redis.exceptions import ConnectionError, TimeoutError +from redis.event import EventDispatcherInterface, OnCommandFailEvent +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.circuit import State as CBState +from redis.multidb.event import RegisterCommandFailure +from redis.multidb.failover import FailoverStrategy +from redis.multidb.failure_detector import FailureDetector + + +class CommandExecutor(ABC): + + @property + @abstractmethod + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def active_database(self) -> Union[Database, None]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: AbstractDatabase) -> None: + """Sets currently active database.""" + pass + + @property + @abstractmethod + def failover_strategy(self) -> FailoverStrategy: + """Returns failover strategy.""" + pass + + @property + @abstractmethod + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" + pass + + @auto_fallback_interval.setter + @abstractmethod + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" + pass + + @abstractmethod + def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + +class DefaultCommandExecutor(CommandExecutor): + + def __init__( + self, + failure_detectors: List[FailureDetector], + 
databases: Databases, + failover_strategy: FailoverStrategy, + event_dispatcher: EventDispatcherInterface, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + :param failure_detectors: List of failure detectors. + :param databases: List of databases. + :param failover_strategy: Strategy that defines the failover logic. + :param event_dispatcher: Event dispatcher. + :param auto_fallback_interval: Interval between fallback attempts. Fallback to a new database according to + failover_strategy. + """ + self._failure_detectors = failure_detectors + self._databases = databases + self._failover_strategy = failover_strategy + self._event_dispatcher = event_dispatcher + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + self._active_database: Union[Database, None] = None + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def failure_detectors(self) -> List[FailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def databases(self) -> Databases: + return self._databases + + @property + def active_database(self) -> Optional[AbstractDatabase]: + return self._active_database + + @active_database.setter + def active_database(self, database: AbstractDatabase) -> None: + self._active_database = database + + @property + def failover_strategy(self) -> FailoverStrategy: + return self._failover_strategy + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def execute_command(self, *args, **options): + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + self._active_database = self._failover_strategy.database + self._schedule_next_fallback() + + try: + return self._active_database.client.execute_command(*args, **options) + except (ConnectionError, TimeoutError, socket.timeout) as e: + # Register command failure + self._event_dispatcher.dispatch(OnCommandFailEvent(args, e, self.active_database.client)) + + # Retry until failure detector will trigger opening of circuit + return self.execute_command(*args, **options) + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) + + def _setup_event_dispatcher(self): + """ + Registers command failure event listener. 
+ """ + event_listener = RegisterCommandFailure(self._failure_detectors, self._databases) + self._event_dispatcher.register_listeners({ + OnCommandFailEvent: [event_listener], + }) \ No newline at end of file diff --git a/redis/multidb/config.py b/redis/multidb/config.py new file mode 100644 index 0000000000..a349409e9f --- /dev/null +++ b/redis/multidb/config.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass, field +from typing import List, Type, Union + +import pybreaker + +from redis import Redis, Sentinel +from redis.asyncio import RedisCluster +from redis.backoff import ExponentialWithJitterBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcher, EventDispatcherInterface +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.database import Database, Databases +from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy +from redis.retry import Retry + +DEFAULT_GRACE_PERIOD = 5.0 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) +DEFAULT_FAILURES_THRESHOLD = 100 +DEFAULT_FAILURES_DURATION = 2 +DEFAULT_FAILOVER_RETRIES = 3 +DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_AUTO_FALLBACK_INTERVAL = -1 + +def default_health_checks() -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF)), + ] + +def default_failure_detectors() -> List[FailureDetector]: + return [ + CommandFailureDetector(threshold=DEFAULT_FAILURES_THRESHOLD, duration=DEFAULT_FAILURES_DURATION), + ] + +def default_failover_strategy() -> FailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + +def default_circuit_breaker() -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=DEFAULT_GRACE_PERIOD) + return PBCircuitBreakerAdapter(circuit_breaker) + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + +@dataclass +class DatabaseConfig: + weight: float + client_kwargs: dict = field(default_factory=dict) + circuit: CircuitBreaker = field(default_factory=default_circuit_breaker) + +@dataclass +class MultiDbConfig: + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster, Sentinel]] = Redis + failure_detectors: List[FailureDetector] = field(default_factory=default_failure_detectors) + health_checks: List[HealthCheck] = field(default_factory=default_health_checks) + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + failover_strategy: FailoverStrategy = field(default_factory=default_failover_strategy) + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + client = self.client_class(**database_config.client_kwargs) + databases.add( + Database(client=client, circuit=database_config.circuit, weight=database_config.weight), + database_config.weight + ) + + return databases + diff --git a/redis/multidb/database.py b/redis/multidb/database.py new file mode 100644 index 
0000000000..7a655b151f --- /dev/null +++ b/redis/multidb/database.py @@ -0,0 +1,118 @@ +import redis +from abc import ABC, abstractmethod +from enum import Enum +from typing import Union + +from redis import RedisCluster, Sentinel +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker +from redis.typing import Number + + +class State(Enum): + ACTIVE = 0 + PASSIVE = 1 + DISCONNECTED = 2 + +class AbstractDatabase(ABC): + @property + @abstractmethod + def client(self) -> Union[redis.Redis, RedisCluster, Sentinel]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + + @property + @abstractmethod + def weight(self) -> float: + """The weight of this database in compare to others. Used to determine the database failover to.""" + pass + + @weight.setter + @abstractmethod + def weight(self, weight: float): + """Set the weight of this database in compare to others.""" + pass + + @property + @abstractmethod + def state(self) -> State: + """The state of the current database.""" + pass + + @state.setter + @abstractmethod + def state(self, state: State): + """Set the state of the current database.""" + pass + + @property + @abstractmethod + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: CircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + +Databases = WeightedList[tuple[AbstractDatabase, Number]] + +class Database(AbstractDatabase): + def __init__( + self, + client: Union[redis.Redis, RedisCluster, Sentinel], + circuit: CircuitBreaker, + weight: float, + state: State = State.DISCONNECTED, + ): + """ + param: client: Client instance for communication with the database. + param: circuit: Circuit breaker for the current database. + param: weight: Weight of current database. Database with the highest weight becomes Active. + param: state: State of the current database. + """ + self._client = client + self._cb = circuit + self._cb.database = self + self._weight = weight + self._state = state + + @property + def client(self) -> Union[redis.Redis, RedisCluster, Sentinel]: + return self._client + + @client.setter + def client(self, client: Union[redis.Redis, RedisCluster, Sentinel]): + self._client = client + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, state: State): + self._state = state + + @property + def circuit(self) -> CircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: CircuitBreaker): + self._cb = circuit diff --git a/redis/multidb/event.py b/redis/multidb/event.py new file mode 100644 index 0000000000..3d366dab77 --- /dev/null +++ b/redis/multidb/event.py @@ -0,0 +1,28 @@ +from typing import List + +from redis.event import EventListenerInterface, OnCommandFailEvent +from redis.multidb.config import Databases +from redis.multidb.failure_detector import FailureDetector + + +class RegisterCommandFailure(EventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. 
+ """ + def __init__(self, failure_detectors: List[FailureDetector], databases: Databases): + self._failure_detectors = failure_detectors + self._databases = databases + + def listen(self, event: OnCommandFailEvent) -> None: + matching_database = None + + for database, _ in self._databases: + if event.client == database.client: + matching_database = database + break + + if matching_database is None: + return + + for failure_detector in self._failure_detectors: + failure_detector.register_failure(matching_database, event.exception, event.command) diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py new file mode 100644 index 0000000000..80fdb9409a --- /dev/null +++ b/redis/multidb/exception.py @@ -0,0 +1,2 @@ +class NoValidDatabaseException(Exception): + pass \ No newline at end of file diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py new file mode 100644 index 0000000000..a4c825aac1 --- /dev/null +++ b/redis/multidb/failover.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod + +from redis.data_structure import WeightedList +from redis.multidb.database import Databases +from redis.multidb.database import AbstractDatabase +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.retry import Retry + + +class FailoverStrategy(ABC): + + @property + @abstractmethod + def database(self) -> AbstractDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the databases strategy operates on.""" + pass + +class WeightBasedFailoverStrategy(FailoverStrategy): + """ + Choose the active database with the highest weight. + """ + def __init__( + self, + retry: Retry + ): + self._retry = retry + self._retry.update_supported_errors([NoValidDatabaseException]) + self._databases = WeightedList() + + @property + def database(self) -> AbstractDatabase: + return self._retry.call_with_retry( + lambda: self._get_active_database(), + lambda _: self._dummy_fail() + ) + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + def _get_active_database(self) -> AbstractDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') + + def _dummy_fail(self): + pass diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py new file mode 100644 index 0000000000..7cb5d5db07 --- /dev/null +++ b/redis/multidb/failure_detector.py @@ -0,0 +1,66 @@ +import threading +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List, Type + +from typing_extensions import Optional + +from redis.multidb.circuit import State as CBState + +class FailureDetector(ABC): + + @abstractmethod + def register_failure(self, database, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + +class CommandFailureDetector(FailureDetector): + """ + Detects a failure based on a threshold of failed commands during a specific period of time. + """ + + def __init__( + self, + threshold: int, + duration: float, + error_types: Optional[List[Type[Exception]]] = None, + ) -> None: + """ + :param threshold: Threshold of failed commands over the duration after which database will be marked as failed. 
+ :param duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. + :param error_types: List of exception that has to be registered. By default, all exceptions are registered. + """ + self._threshold = threshold + self._duration = duration + self._error_types = error_types + self._start_time: datetime = datetime.now() + self._end_time: datetime = self._start_time + timedelta(seconds=self._duration) + self._failures_within_duration: List[tuple[datetime, tuple]] = [] + self._lock = threading.RLock() + + def register_failure(self, database, exception: Exception, cmd: tuple) -> None: + failure_time = datetime.now() + + if not self._start_time < failure_time < self._end_time: + self._reset() + + with self._lock: + if self._error_types: + if type(exception) in self._error_types: + self._failures_within_duration.append((datetime.now(), cmd)) + else: + self._failures_within_duration.append((datetime.now(), cmd)) + + self._check_threshold(database) + + def _check_threshold(self, database): + with self._lock: + if len(self._failures_within_duration) >= self._threshold: + database.circuit.state = CBState.OPEN + self._reset() + + def _reset(self) -> None: + with self._lock: + self._start_time = datetime.now() + self._end_time = self._start_time + timedelta(seconds=self._duration) + self._failures_within_duration = [] \ No newline at end of file diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py new file mode 100644 index 0000000000..a96b9cf815 --- /dev/null +++ b/redis/multidb/healthcheck.py @@ -0,0 +1,57 @@ +from abc import abstractmethod, ABC +from redis.retry import Retry + + +class HealthCheck(ABC): + + @property + @abstractmethod + def retry(self) -> Retry: + """The retry object to use for health checks.""" + pass + + @abstractmethod + def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class AbstractHealthCheck(HealthCheck): + def __init__( + self, + retry: Retry, + ) -> None: + self._retry = retry + + @property + def retry(self) -> Retry: + return self._retry + + @abstractmethod + def check_health(self, database) -> bool: + pass + + +class EchoHealthCheck(AbstractHealthCheck): + def __init__( + self, + retry: Retry, + ) -> None: + """ + Check database healthiness by sending an echo request. 
+ """ + super().__init__( + retry=retry, + ) + def check_health(self, database) -> bool: + return self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: self._dummy_fail() + ) + + def _returns_echoed_message(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + actual_message = database.client.execute_command('ECHO', "healthcheck") + return actual_message in expected_message + + def _dummy_fail(self): + pass \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 7eaccb1acb..fc316ea720 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,7 @@ ) from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url from redis.credentials import CredentialProvider +from redis.event import EventDispatcherInterface from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_tls_certificates @@ -581,6 +582,11 @@ def mock_connection() -> ConnectionInterface: mock_connection = Mock(spec=ConnectionInterface) return mock_connection +@pytest.fixture() +def mock_ed() -> EventDispatcherInterface: + mock_ed = Mock(spec=EventDispatcherInterface) + return mock_ed + @pytest.fixture() def cache_key(request) -> CacheKey: diff --git a/tests/test_background.py b/tests/test_background.py new file mode 100644 index 0000000000..4b3a5377c1 --- /dev/null +++ b/tests/test_background.py @@ -0,0 +1,60 @@ +from time import sleep + +import pytest + +from redis.background import BackgroundScheduler + +class TestBackgroundScheduler: + def test_run_once(self): + execute_counter = 0 + one = 'arg1' + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_once(0.1, callback, one, two) + assert execute_counter == 0 + + sleep(0.15) + + assert execute_counter == 1 + + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ] + ) + def test_run_recurring(self, interval, timeout, call_count): + execute_counter = 0 + one = 'arg1' + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_recurring(interval, callback, one, two) + assert execute_counter == 0 + + sleep(timeout) + + assert execute_counter == call_count \ No newline at end of file diff --git a/tests/test_data_structure.py b/tests/test_data_structure.py new file mode 100644 index 0000000000..31ac5c4316 --- /dev/null +++ b/tests/test_data_structure.py @@ -0,0 +1,79 @@ +import concurrent +import random +from concurrent.futures import ThreadPoolExecutor +from time import sleep + +from redis.data_structure import WeightedList + + +class TestWeightedList: + def test_add_items(self): + wlist = WeightedList() + + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] + + def test_remove_items(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.remove('item2') == 2.0 + assert wlist.remove('item4') == 4.0 + + assert wlist.get_top_n(4) == [('item3', 4.0), 
('item1', 3.0)] + + def test_get_by_weight_range(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_by_weight_range(2.0, 3.0) == [('item1', 3.0), ('item2', 2.0)] + + def test_update_weights(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] + + wlist.update_weight('item2', 5.0) + + assert wlist.get_top_n(4) == [('item2', 5.0), ('item3', 4.0), ('item4', 4.0), ('item1', 3.0)] + + def test_thread_safety(self) -> None: + """Test thread safety with concurrent operations""" + wl = WeightedList() + + def worker(worker_id): + for i in range(100): + # Add items + wl.add(f"item_{worker_id}_{i}", random.uniform(0, 100)) + + # Read operations + try: + length = len(wl) + if length > 0: + top_items = wl.get_top_n(min(5, length)) + items_in_range = wl.get_by_weight_range(20, 80) + except Exception as e: + print(f"Error in worker {worker_id}: {e}") + + sleep(0.001) # Small delay + + # Run multiple workers concurrently + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + concurrent.futures.wait(futures) + + assert len(wl) == 500 \ No newline at end of file diff --git a/tests/test_event.py b/tests/test_event.py new file mode 100644 index 0000000000..27526abeaf --- /dev/null +++ b/tests/test_event.py @@ -0,0 +1,55 @@ +from unittest.mock import Mock, AsyncMock + +from redis.event import EventListenerInterface, EventDispatcher, AsyncEventListenerInterface + + +class TestEventDispatcher: + def test_register_listeners(self): + mock_event = Mock(spec=object) + mock_event_listener = Mock(spec=EventListenerInterface) + listener_called = 0 + + def callback(event): + nonlocal listener_called + listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + dispatcher.dispatch(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=EventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + dispatcher.dispatch(mock_event) + + assert listener_called == 3 + + async def test_register_listeners_async(self): + mock_event = Mock(spec=object) + mock_event_listener = AsyncMock(spec=AsyncEventListenerInterface) + listener_called = 0 + + async def callback(event): + nonlocal listener_called + listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=AsyncEventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 3 \ No newline at end of file diff --git a/tests/test_multidb/__init__.py b/tests/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py new file mode 100644 index 0000000000..ad2057a118 --- /dev/null +++ b/tests/test_multidb/conftest.py @@ -0,0 +1,112 @@ +from unittest.mock import Mock + +import pytest + +from redis import Redis +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.database import Database, State, Databases +from redis.multidb.failover import FailoverStrategy +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck +from tests.conftest import mock_ed + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + +@pytest.fixture() +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) + +@pytest.fixture() +def mock_fd() -> FailureDetector: + return Mock(spec=FailureDetector) + +@pytest.fixture() +def mock_fs() -> FailoverStrategy: + return Mock(spec=FailoverStrategy) + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_multi_db_config( + request, mock_fd, mock_fs, mock_hc, mock_ed +) -> MultiDbConfig: + hc_interval = request.param.get('hc_interval', None) + if hc_interval is None: + hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL + + auto_fallback_interval = request.param.get('auto_fallback_interval', None) + if auto_fallback_interval is None: + auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_checks=[mock_hc], + health_check_interval=hc_interval, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed + ) + + return config + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs \ No newline at end of file diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py new file mode 100644 index 0000000000..7dc642373b --- /dev/null +++ 
b/tests/test_multidb/test_circuit.py @@ -0,0 +1,52 @@ +import pybreaker +import pytest + +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker + + +class TestPBCircuitBreaker: + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, + ], + indirect=True, + ) + def test_cb_correctly_configured(self, mock_db): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + assert adapter.state == CbState.CLOSED + + adapter.state = CbState.OPEN + assert adapter.state == CbState.OPEN + + adapter.state = CbState.HALF_OPEN + assert adapter.state == CbState.HALF_OPEN + + adapter.state = CbState.CLOSED + assert adapter.state == CbState.CLOSED + + assert adapter.grace_period == 5 + adapter.grace_period = 10 + + assert adapter.grace_period == 10 + + adapter.database = mock_db + assert adapter.database == mock_db + + def test_cb_executes_callback_on_state_changed(self): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + called_count = 0 + + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + nonlocal called_count + assert old_state == CbState.CLOSED + assert new_state == CbState.HALF_OPEN + assert isinstance(cb, PBCircuitBreakerAdapter) + called_count += 1 + + adapter.on_state_changed(callback) + adapter.state = CbState.HALF_OPEN + + assert called_count == 1 \ No newline at end of file diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py new file mode 100644 index 0000000000..c2ade264b5 --- /dev/null +++ b/tests/test_multidb/test_client.py @@ -0,0 +1,585 @@ +from time import sleep +from unittest.mock import patch, Mock + +import pybreaker +import pytest + +from redis.event import EventDispatcher, OnCommandFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF +from redis.multidb.database import State as DBState, AbstractDatabase +from redis.multidb.client import MultiDBClient +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.retry import Retry +from tests.test_multidb.conftest import create_weighted_list + + +class TestMultiDbClient: + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: 
+ assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert client.set('key', 'value') == 'OK1' + sleep(0.15) + assert client.set('key', 'value') == 'OK2' + sleep(0.1) + assert client.set('key', 'value') == 'OK' + sleep(0.1) + assert client.set('key', 'value') == 'OK1' + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': 
CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.15) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.22) + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): + client.set('key', 'value') + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.DISCONNECTED + assert mock_db1.state == DBState.DISCONNECTED + assert mock_db2.state == DBState.DISCONNECTED + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + for hc in mock_multi_db_config.health_checks: + 
hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match='Given database already exists'): + client.add_database(mock_db) + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK2' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 2 + + assert mock_db.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE + + client.add_database(mock_db1) + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.remove_database(mock_db1) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, 
mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = OnCommandFailEvent( + command=('SET', 'key', 'value'), + exception=Exception(), + client=mock_db1.client + ) + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=FailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert 
mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + client.add_health_check(another_hc) + client._check_db_health(mock_db1) + + assert another_hc.check_health.call_count == 1 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db.client.execute_command.return_value = 'OK' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.set_active_database(mock_db) + assert client.set('key', 'value') == 'OK' + + assert mock_db.state == DBState.ACTIVE + assert mock_db1.state == DBState.PASSIVE + assert mock_db2.state == DBState.PASSIVE + + with pytest.raises(ValueError, match='Given database is not a member of database list'): + client.set_active_database(Mock(spec=AbstractDatabase)) + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): + client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..54c6d38e1d --- /dev/null +++ b/tests/test_multidb/test_command_executor.py @@ -0,0 +1,174 @@ +from time import sleep +from unittest.mock import PropertyMock + +import pytest + +from redis.event import EventDispatcher, OnCommandFailEvent +from redis.multidb.circuit import State as CBState +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.failure_detector import CommandFailureDetector +from tests.test_multidb.conftest import create_weighted_list + + +class TestDefaultCommandExecutor: + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed + 
) + + executor.active_database = mock_db1 + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + + executor.active_database = mock_db2 + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.circuit.state = CBState.OPEN + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.weight = 0.1 + sleep(0.15) + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + mock_db1.weight = 0.7 + sleep(0.15) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 5 + fd = CommandFailureDetector(threshold, 1) + ed = EventDispatcher() + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + # Event fired if command against mock_db1 would fail + command_fail_event = OnCommandFailEvent( + command=('SET', 'key', 'value'), + 
exception=Exception(), + client=mock_db1.client + ) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + + # Simulate failing command events that lead to a failure detection + for i in range(threshold): + ed.dispatch(command_fail_event) + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + + command_fail_event = OnCommandFailEvent( + command=('SET', 'key', 'value'), + exception=Exception(), + client=mock_db2.client + ) + + for i in range(threshold): + ed.dispatch(command_fail_event) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_selector.call_count == 3 \ No newline at end of file diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py new file mode 100644 index 0000000000..a810eea676 --- /dev/null +++ b/tests/test_multidb/test_config.py @@ -0,0 +1,121 @@ +from unittest.mock import Mock +from redis.connection import ConnectionPool +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD +from redis.multidb.database import Database +from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector +from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.multidb.failover import WeightBasedFailoverStrategy, FailoverStrategy + + +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), + DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), + DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), + ] + + config = MultiDbConfig( + databases_config=db_configs + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + i+=1 + + assert len(config.failure_detectors) == 1 + assert isinstance(config.failure_detectors[0], CommandFailureDetector) + assert len(config.health_checks) == 1 + assert isinstance(config.health_checks[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance(config.failover_strategy, WeightBasedFailoverStrategy) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + + def test_overridden_config(self): + grace_period = 2 + mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=FailoverStrategy) + auto_fallback_interval = 10 
+ db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 + ), + ] + + config = MultiDbConfig( + databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i+=1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + + assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.weight == 1.0 + assert isinstance(config.circuit, PBCircuitBreakerAdapter) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + ) + + assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit \ No newline at end of file diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py new file mode 100644 index 0000000000..06390c4e2e --- /dev/null +++ b/tests/test_multidb/test_failover.py @@ -0,0 +1,117 @@ +from unittest.mock import PropertyMock + +import pytest + +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.retry import Retry + + +class TestWeightBasedFailoverStrategy: + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + databases = WeightedList() + databases.add(mock_db, 
mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + assert failover_strategy.database == mock_db1 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + assert failover_strategy.database == mock_db + assert state_mock.call_count == 4 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert failover_strategy.database + + assert state_mock.call_count == 4 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert failover_strategy.database \ No newline at end of file diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..8e0c1bcbad --- /dev/null +++ b/tests/test_multidb/test_failure_detector.py @@ -0,0 +1,131 @@ +from time import sleep + +import pytest + +from redis.multidb.failure_detector import CommandFailureDetector +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError + + +class TestCommandFailureDetector: + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1) + assert mock_db.circuit.state == 
CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): + fd = CommandFailureDetector(5, 0.3) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + sleep(0.1) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + # 4 more failures, as the last one already refreshed the timer + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): + fd = CommandFailureDetector(5, 0.3) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + sleep(0.4) + + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def 
test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(mock_db, ConnectionError(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py new file mode 100644 index 0000000000..9601638913 --- /dev/null +++ b/tests/test_multidb/test_healthcheck.py @@ -0,0 +1,41 @@ +from redis.backoff import ExponentialBackoff +from redis.multidb.database import Database, State +from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError +from redis.retry import Retry + + +class TestEchoHealthCheck: + def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): + """ + Mock a mix of errors and successful responses to verify that the health check retries + according to the given configuration. + """ + mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + + assert hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 + + def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): + """ + Mock a mix of errors and successful responses to verify that the health check retries + according to the given configuration. + """ + mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + + assert hc.check_health(db) == False + assert mock_client.execute_command.call_count == 3 + + def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): + mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] + mock_cb.state = CBState.HALF_OPEN + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + + assert hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 \ No newline at end of file
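For reviewers, the short sketch below shows how the pieces exercised by these suites are expected to fit together: a MultiDbConfig built from DatabaseConfig entries is handed to MultiDBClient, which routes commands to the highest-weighted healthy database. It is a minimal illustration assembled only from calls that appear in the tests above; the MultiDBClient import path, hostnames, ports, and weights are assumptions for the example and are not part of this diff.

# Minimal usage sketch (illustrative only, not part of the change set).
from redis.multidb.client import MultiDBClient  # import path assumed from the test layout
from redis.multidb.config import MultiDbConfig, DatabaseConfig

config = MultiDbConfig(
    databases_config=[
        # Placeholder endpoints; the weight determines which database is preferred.
        DatabaseConfig(client_kwargs={'host': 'dc1.example.com', 'port': 6379}, weight=1.0),
        DatabaseConfig(client_kwargs={'host': 'dc2.example.com', 'port': 6379}, weight=0.9),
    ],
)

client = MultiDBClient(config)
client.set('key', 'value')  # served by the highest-weighted healthy database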