Skip to content

PYTHON-4927 - Add missing CSOT prose tests #1987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion pymongo/_csot.py
Original file line number Diff line number Diff line change
@@ -16,13 +16,24 @@

from __future__ import annotations

import contextlib
import functools
import inspect
import time
from collections import deque
from contextlib import AbstractContextManager
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
Generator,
MutableMapping,
Optional,
TypeVar,
cast,
)

if TYPE_CHECKING:
from pymongo.write_concern import WriteConcern
@@ -54,6 +65,17 @@ def remaining() -> Optional[float]:
return DEADLINE.get() - time.monotonic()


@contextlib.contextmanager
def reset() -> Generator:
timeout = get_timeout()
if timeout is None:
deadline_token = DEADLINE.set(DEADLINE.get())
else:
deadline_token = DEADLINE.set(DEADLINE.get() + timeout)
yield
DEADLINE.reset(deadline_token)


def clamp_remaining(max_timeout: float) -> float:
"""Return the remaining timeout clamped to a max value."""
timeout = remaining()
12 changes: 10 additions & 2 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
@@ -473,7 +473,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:

def _within_time_limit(start_time: float) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
timeout = _csot.get_timeout()
if timeout:
return time.monotonic() - start_time < timeout
else:
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT


_T = TypeVar("_T")
@@ -512,6 +516,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._timeout = client.options.timeout

async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
@@ -597,6 +602,7 @@ def _inherit_option(self, name: str, val: _T) -> _T:
return parent_val
return getattr(self.client, name)

@_csot.apply
async def with_transaction(
self,
callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]],
@@ -697,7 +703,8 @@ async def callback(session, custom_arg, custom_kwarg=None):
ret = await callback(self)
except Exception as exc:
if self.in_transaction:
await self.abort_transaction()
with _csot.reset():
await self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
@@ -816,6 +823,7 @@ async def commit_transaction(self) -> None:
finally:
self._transaction.state = _TxnState.COMMITTED

@_csot.apply
async def abort_transaction(self) -> None:
"""Abort a multi-statement transaction.
3 changes: 2 additions & 1 deletion pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
@@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float:
timeout = _csot.remaining()
if timeout is None:
return self._settings.server_selection_timeout
return timeout
else:
return min(timeout, self._settings.server_selection_timeout)

async def select_servers(
self,
12 changes: 10 additions & 2 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
@@ -472,7 +472,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:

def _within_time_limit(start_time: float) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
timeout = _csot.get_timeout()
if timeout:
return time.monotonic() - start_time < timeout
else:
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT


_T = TypeVar("_T")
@@ -511,6 +515,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._timeout = client.options.timeout

def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
@@ -596,6 +601,7 @@ def _inherit_option(self, name: str, val: _T) -> _T:
return parent_val
return getattr(self.client, name)

@_csot.apply
def with_transaction(
self,
callback: Callable[[ClientSession], _T],
@@ -694,7 +700,8 @@ def callback(session, custom_arg, custom_kwarg=None):
ret = callback(self)
except Exception as exc:
if self.in_transaction:
self.abort_transaction()
with _csot.reset():
self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
@@ -813,6 +820,7 @@ def commit_transaction(self) -> None:
finally:
self._transaction.state = _TxnState.COMMITTED

@_csot.apply
def abort_transaction(self) -> None:
"""Abort a multi-statement transaction.
3 changes: 2 additions & 1 deletion pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
@@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float:
timeout = _csot.remaining()
if timeout is None:
return self._settings.server_selection_timeout
return timeout
else:
return min(timeout, self._settings.server_selection_timeout)

def select_servers(
self,
48 changes: 47 additions & 1 deletion test/asynchronous/test_client.py
Original file line number Diff line number Diff line change
@@ -64,6 +64,7 @@
from test.utils import (
NTHREADS,
CMAPListener,
EventListener,
FunctionCallRecorder,
async_get_pool,
async_wait_until,
@@ -114,7 +115,13 @@
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
from pymongo.monitoring import (
ConnectionClosedEvent,
ConnectionCreatedEvent,
ConnectionReadyEvent,
ServerHeartbeatListener,
ServerHeartbeatStartedEvent,
)
from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions
from pymongo.read_preferences import ReadPreference
from pymongo.server_description import ServerDescription
@@ -2585,5 +2592,44 @@ async def test_direct_client_maintains_pool_to_arbiter(self):
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#4-background-connection-pooling
class TestClientCSOTProse(AsyncIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command
@async_client_context.require_auth
@async_client_context.require_version_min(4, 4, -1)
@async_client_context.require_failCommand_appName
async def test_02_timeoutMS_refreshed_for_each_handshake_command(self):
listener = CMAPListener()

async with self.fail_point(
{
"mode": {"times": 1},
"data": {
"failCommands": ["hello", "isMaster", "saslContinue"],
"blockConnection": True,
"blockTimeMS": 15,
"appName": "refreshTimeoutBackgroundPoolTest",
},
}
):
_ = await self.async_single_client(
minPoolSize=1,
timeoutMS=20,
appname="refreshTimeoutBackgroundPoolTest",
event_listeners=[listener],
)

async def predicate():
return (
listener.event_count(ConnectionCreatedEvent) == 1
and listener.event_count(ConnectionReadyEvent) == 1
)

await async_wait_until(
predicate,
"didn't ever see a ConnectionCreatedEvent and a ConnectionReadyEvent",
)


if __name__ == "__main__":
unittest.main()
27 changes: 27 additions & 0 deletions test/asynchronous/test_collection.py
Original file line number Diff line number Diff line change
@@ -64,6 +64,7 @@
InvalidDocument,
InvalidName,
InvalidOperation,
NetworkTimeout,
OperationFailure,
WriteConcernError,
)
@@ -2277,6 +2278,32 @@ async def afind(*args, **kwargs):
for helper, args in helpers:
await helper(*args, let={}) # type: ignore

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts
@async_client_context.require_standalone
@async_client_context.require_version_min(4, 4, -1)
@async_client_context.require_failCommand_fail_point
async def test_01_multi_batch_inserts(self):
client = await self.async_single_client(read_preference=ReadPreference.PRIMARY_PREFERRED)
await client.db.coll.drop()

async with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 1010},
}
):
listener = OvertCommandListener()
client2 = await self.async_single_client(
timeoutMS=2000,
read_preference=ReadPreference.PRIMARY_PREFERRED,
event_listeners=[listener],
)
docs = [{"a": "b" * 1000000} for _ in range(50)]
with self.assertRaises(NetworkTimeout):
await client2.db.coll.insert_many(docs)

self.assertEqual(2, len(listener.started_events))


if __name__ == "__main__":
unittest.main()
109 changes: 109 additions & 0 deletions test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
@@ -86,6 +86,7 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
NetworkTimeout,
OperationFailure,
ServerSelectionTimeoutError,
WriteError,
@@ -3133,5 +3134,113 @@ async def test_explicit_session_errors_when_unsupported(self):
await self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s)


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#3-clientencryption
class TestCSOTProse(AsyncEncryptionIntegrationTest):
mongocryptd_client: AsyncMongoClient
MONGOCRYPTD_PORT = 27020
LOCAL_MASTERKEY = Binary(
base64.b64decode(
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
),
UUID_SUBTYPE,
)

async def asyncSetUp(self) -> None:
self.listener = OvertCommandListener()
self.client = await self.async_single_client(
read_preference=ReadPreference.PRIMARY_PREFERRED, event_listeners=[self.listener]
)
await self.client.keyvault.datakeys.drop()
self.key_vault_client = await self.async_rs_or_single_client(
timeoutMS=50, event_listeners=[self.listener]
)
self.client_encryption = self.create_client_encryption(
key_vault_namespace="keyvault.datakeys",
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
key_vault_client=self.key_vault_client,
codec_options=OPTS,
)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, -1)
async def test_01_create_data_key(self):
async with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
await self.client_encryption.create_data_key("local")

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("insert", events[0].command_name)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, -1)
async def test_02_encrypt(self):
data_key_id = await self.client_encryption.create_data_key("local")
self.assertEqual(4, data_key_id.subtype)
async with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
await self.client_encryption.encrypt(
"hello",
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
key_id=data_key_id,
)

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("find", events[0].command_name)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, -1)
async def test_03_decrypt(self):
data_key_id = await self.client_encryption.create_data_key("local")
self.assertEqual(4, data_key_id.subtype)

encrypted = await self.client_encryption.encrypt(
"hello", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id
)
self.assertEqual(6, encrypted.subtype)

await self.key_vault_client.close()
self.key_vault_client = await self.async_rs_or_single_client(
timeoutMS=50, event_listeners=[self.listener]
)
await self.client_encryption.close()
self.client_encryption = self.create_client_encryption(
key_vault_namespace="keyvault.datakeys",
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
key_vault_client=self.key_vault_client,
codec_options=OPTS,
)

async with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
await self.client_encryption.decrypt(encrypted)

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("find", events[0].command_name)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.helpers import anext
from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.errors import ConfigurationError, InvalidOperation, NetworkTimeout, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
from pymongo.read_concern import ReadConcern

40 changes: 40 additions & 0 deletions test/asynchronous/test_transactions.py
Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@
ConfigurationError,
ConnectionFailure,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
@@ -386,6 +387,45 @@ async def find_raw_batches(*args, **kwargs):
if isinstance(res, (AsyncCommandCursor, AsyncCursor)):
await res.to_list()

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#10-convenient-transactions
@async_client_context.require_transactions
@async_client_context.require_version_min(4, 4, -1)
@async_client_context.require_failCommand_fail_point
async def test_10_convenient_transactions_csot(self):
await self.client.db.coll.drop()

listener = OvertCommandListener()

async with self.fail_point(
{
"mode": {"times": 2},
"data": {
"failCommands": ["insert", "abortTransaction"],
"blockConnection": True,
"blockTimeMS": 200,
},
}
):
client = await self.async_rs_or_single_client(
timeoutMS=150,
event_listeners=[listener],
)
session = client.start_session()

async def callback(s):
await client.db.coll.insert_one({"_id": 1}, session=s)

with self.assertRaises(NetworkTimeout):
await session.with_transaction(callback)

started = listener.started_command_names()
failed = listener.failed_command_names()

self.assertIn("insert", started)
self.assertIn("abortTransaction", started)
self.assertIn("insert", failed)
self.assertIn("abortTransaction", failed)


class PatchSessionTimeout:
"""Patches the client_session's with_transaction timeout for testing."""
48 changes: 47 additions & 1 deletion test/test_client.py
Original file line number Diff line number Diff line change
@@ -63,6 +63,7 @@
from test.utils import (
NTHREADS,
CMAPListener,
EventListener,
FunctionCallRecorder,
assertRaisesExactly,
delay,
@@ -102,7 +103,13 @@
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
from pymongo.monitoring import (
ConnectionClosedEvent,
ConnectionCreatedEvent,
ConnectionReadyEvent,
ServerHeartbeatListener,
ServerHeartbeatStartedEvent,
)
from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions
from pymongo.read_preferences import ReadPreference
from pymongo.server_description import ServerDescription
@@ -2541,5 +2548,44 @@ def test_direct_client_maintains_pool_to_arbiter(self):
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#4-background-connection-pooling
class TestClientCSOTProse(IntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command
@client_context.require_auth
@client_context.require_version_min(4, 4, -1)
@client_context.require_failCommand_appName
def test_02_timeoutMS_refreshed_for_each_handshake_command(self):
listener = CMAPListener()

with self.fail_point(
{
"mode": {"times": 1},
"data": {
"failCommands": ["hello", "isMaster", "saslContinue"],
"blockConnection": True,
"blockTimeMS": 15,
"appName": "refreshTimeoutBackgroundPoolTest",
},
}
):
_ = self.single_client(
minPoolSize=1,
timeoutMS=20,
appname="refreshTimeoutBackgroundPoolTest",
event_listeners=[listener],
)

def predicate():
return (
listener.event_count(ConnectionCreatedEvent) == 1
and listener.event_count(ConnectionReadyEvent) == 1
)

wait_until(
predicate,
"didn't ever see a ConnectionCreatedEvent and a ConnectionReadyEvent",
)


if __name__ == "__main__":
unittest.main()
27 changes: 27 additions & 0 deletions test/test_collection.py
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@
InvalidDocument,
InvalidName,
InvalidOperation,
NetworkTimeout,
OperationFailure,
WriteConcernError,
)
@@ -2254,6 +2255,32 @@ def afind(*args, **kwargs):
for helper, args in helpers:
helper(*args, let={}) # type: ignore

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts
@client_context.require_standalone
@client_context.require_version_min(4, 4, -1)
@client_context.require_failCommand_fail_point
def test_01_multi_batch_inserts(self):
client = self.single_client(read_preference=ReadPreference.PRIMARY_PREFERRED)
client.db.coll.drop()

with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 1010},
}
):
listener = OvertCommandListener()
client2 = self.single_client(
timeoutMS=2000,
read_preference=ReadPreference.PRIMARY_PREFERRED,
event_listeners=[listener],
)
docs = [{"a": "b" * 1000000} for _ in range(50)]
with self.assertRaises(NetworkTimeout):
client2.db.coll.insert_many(docs)

self.assertEqual(2, len(listener.started_events))


if __name__ == "__main__":
unittest.main()
9 changes: 6 additions & 3 deletions test/test_csot.py
Original file line number Diff line number Diff line change
@@ -12,20 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the CSOT unified spec tests."""
"""Test the CSOT unified spec and prose tests."""
from __future__ import annotations

import os
import sys
from test.utils import OvertCommandListener

from pymongo.read_concern import ReadConcern

sys.path[0:0] = [""]

from test import IntegrationTest, client_context, unittest
from test.unified_format import generate_test_classes

import pymongo
from pymongo import _csot
from pymongo.errors import PyMongoError
from pymongo import ReadPreference, WriteConcern, _csot
from pymongo.errors import NetworkTimeout, PyMongoError

# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "csot")
109 changes: 109 additions & 0 deletions test/test_encryption.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,7 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
NetworkTimeout,
OperationFailure,
ServerSelectionTimeoutError,
WriteError,
@@ -3115,5 +3116,113 @@ def test_explicit_session_errors_when_unsupported(self):
self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s)


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#3-clientencryption
class TestCSOTProse(EncryptionIntegrationTest):
mongocryptd_client: MongoClient
MONGOCRYPTD_PORT = 27020
LOCAL_MASTERKEY = Binary(
base64.b64decode(
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
),
UUID_SUBTYPE,
)

def setUp(self) -> None:
self.listener = OvertCommandListener()
self.client = self.single_client(
read_preference=ReadPreference.PRIMARY_PREFERRED, event_listeners=[self.listener]
)
self.client.keyvault.datakeys.drop()
self.key_vault_client = self.rs_or_single_client(
timeoutMS=50, event_listeners=[self.listener]
)
self.client_encryption = self.create_client_encryption(
key_vault_namespace="keyvault.datakeys",
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
key_vault_client=self.key_vault_client,
codec_options=OPTS,
)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, -1)
def test_01_create_data_key(self):
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
self.client_encryption.create_data_key("local")

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("insert", events[0].command_name)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, -1)
def test_02_encrypt(self):
data_key_id = self.client_encryption.create_data_key("local")
self.assertEqual(4, data_key_id.subtype)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
self.client_encryption.encrypt(
"hello",
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
key_id=data_key_id,
)

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("find", events[0].command_name)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, -1)
def test_03_decrypt(self):
data_key_id = self.client_encryption.create_data_key("local")
self.assertEqual(4, data_key_id.subtype)

encrypted = self.client_encryption.encrypt(
"hello", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id
)
self.assertEqual(6, encrypted.subtype)

self.key_vault_client.close()
self.key_vault_client = self.rs_or_single_client(
timeoutMS=50, event_listeners=[self.listener]
)
self.client_encryption.close()
self.client_encryption = self.create_client_encryption(
key_vault_namespace="keyvault.datakeys",
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
key_vault_client=self.key_vault_client,
codec_options=OPTS,
)

with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
self.client_encryption.decrypt(encrypted)

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("find", events[0].command_name)


if __name__ == "__main__":
unittest.main()
80 changes: 80 additions & 0 deletions test/test_gridfs_bucket.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@
from gridfs.errors import CorruptGridFile, NoFile
from pymongo.errors import (
ConfigurationError,
NetworkTimeout,
NotPrimaryError,
ServerSelectionTimeoutError,
WriteConcernError,
@@ -525,5 +526,84 @@ def test_gridfs_secondary_lazy(self):
self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"data")


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#6-gridfs---upload
class TestGridFsCSOT(IntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#uploads-via-openuploadstream-can-be-timed-out
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, -1)
def test_06_01_uploads_via_open_upload_stream_can_be_timed_out(self):
self.client.db.fs.files.drop()
self.client.db.fs.chunks.drop()

with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 200},
}
):
client = self.single_client(timeoutMS=150)
bucket = gridfs.GridFSBucket(client.db)
upload_stream = bucket.open_upload_stream("filename")
upload_stream.write(b"0x12")
with self.assertRaises(NetworkTimeout):
upload_stream.close()

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#aborting-an-upload-stream-can-be-timed-out
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, -1)
def test_06_02_aborting_an_upload_stream_can_be_timed_out(self):
self.client.db.fs.files.drop()
self.client.db.fs.chunks.drop()

with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["delete"], "blockConnection": True, "blockTimeMS": 200},
}
):
client = self.single_client(timeoutMS=150)
bucket = gridfs.GridFSBucket(client.db, chunk_size_bytes=2)
upload_stream = bucket.open_upload_stream("filename")
upload_stream.write(b"0x010x020x030x04")
with self.assertRaises(NetworkTimeout):
upload_stream.abort()

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#7-gridfs---download
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, -1)
def test_07_gridfs_download_csot(self):
self.client.db.fs.files.drop()
self.client.db.fs.chunks.drop()

id = ObjectId("000000000000000000000005")

self.client.db.fs.files.insert_one(
{
"_id": id,
"length": 10,
"chunkSize": 4,
"uploadDate": {"$date": "1970-01-01T00:00:00.000Z"},
"md5": "57d83cd477bfb1ccd975ab33d827a92b",
"filename": "length-10",
"contentType": "application/octet-stream",
"aliases": [],
"metadata": {},
}
)

client = self.single_client(timeoutMS=150)
bucket = gridfs.GridFSBucket(client.db)
download_stream = bucket.open_download_stream(id)

with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 200},
}
):
with self.assertRaises(NetworkTimeout):
download_stream.read()


if __name__ == "__main__":
unittest.main()
101 changes: 100 additions & 1 deletion test/test_server_selection.py
Original file line number Diff line number Diff line change
@@ -17,9 +17,10 @@

import os
import sys
import time

from pymongo import MongoClient, ReadPreference
from pymongo.errors import ServerSelectionTimeoutError
from pymongo.errors import NetworkTimeout, ServerSelectionTimeoutError
from pymongo.hello import HelloCompat
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
@@ -200,5 +201,103 @@ def test_server_selector_bypassed(self):
self.assertEqual(selector.call_count, 0)


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#8-server-selection
class TestServerSelectionCSOT(IntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-if-timeoutms-is-not-set
@client_context.require_version_min(4, 4, -1)
def test_08_01_server_selection_timeoutMS_honored(self):
client = self.single_client("mongodb://invalid/?serverSelectionTimeoutMS=10")
with self.assertRaises(ServerSelectionTimeoutError):
start = time.time_ns() * 1000
client.admin.command({"ping": 1})

end = time.time_ns() * 1000

self.assertLessEqual(start - end, 15)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-honored-for-server-selection-if-its-lower-than-serverselectiontimeoutms
@client_context.require_version_min(4, 4, -1)
def test_08_02_timeoutMS_honored_for_server_selection_if_lower(self):
client = self.single_client("mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20")
with self.assertRaises(ServerSelectionTimeoutError):
start = time.time_ns() * 1_000_000
client.admin.command({"ping": 1})
end = time.time_ns() * 1_000_000

self.assertLessEqual(start - end, 15)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-server-selection-if-its-lower-than-timeoutms
@client_context.require_version_min(4, 4, -1)
def test_08_03_serverselectiontimeoutms_honored_for_server_selection_if_lower(self):
client = self.single_client("mongodb://invalid/?timeoutMS=20&serverSelectionTimeoutMS=10")
with self.assertRaises(ServerSelectionTimeoutError):
start = time.time_ns() * 1_000_000
client.admin.command({"ping": 1})

end = time.time_ns() * 1_000_000

self.assertLessEqual(start - end, 15)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-server-selection-if-timeoutms0
@client_context.require_version_min(4, 4, -1)
def test_08_04_serverselectiontimeoutms_honored_for_server_selection_if_zero_timeoutms(self):
client = self.single_client("mongodb://invalid/?timeoutMS=0&serverSelectionTimeoutMS=10")
with self.assertRaises(ServerSelectionTimeoutError):
start = time.time_ns() * 1_000_000
client.admin.command({"ping": 1})

end = time.time_ns() * 1_000_000

self.assertLessEqual(start - end, 15)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-honored-for-connection-handshake-commands-if-its-lower-than-serverselectiontimeoutms
@client_context.require_auth
@client_context.require_version_min(4, 4, -1)
@client_context.require_failCommand_fail_point
def test_08_05_timeoutms_honored_for_handshake_if_lower(self):
with self.fail_point(
{
"mode": {"times": 1},
"data": {
"failCommands": ["saslContinue"],
"blockConnection": True,
"blockTimeMS": 15,
},
}
):
client = self.single_client(timeoutMS=10, serverSelectionTimeoutMS=20)
with self.assertRaises(NetworkTimeout):
start = time.time_ns() * 1_000_000
client.db.coll.insert_one({"x": 1})

end = time.time_ns() * 1_000_000

self.assertLessEqual(start - end, 15)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-connection-handshake-commands-if-its-lower-than-timeoutms
@client_context.require_auth
@client_context.require_version_min(4, 4, -1)
@client_context.require_failCommand_fail_point
def test_08_06_serverSelectionTimeoutMS_honored_for_handshake_if_lower(self):
with self.fail_point(
{
"mode": {"times": 1},
"data": {
"failCommands": ["saslContinue"],
"blockConnection": True,
"blockTimeMS": 15,
},
}
):
client = self.single_client(timeoutMS=20, serverSelectionTimeoutMS=10)
with self.assertRaises(NetworkTimeout):
start = time.time_ns() * 1_000_000
client.db.coll.insert_one({"x": 1})

end = time.time_ns() * 1_000_000

self.assertLessEqual(start - end, 15)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/test_session.py
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo import ASCENDING, MongoClient, monitoring
from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.errors import ConfigurationError, InvalidOperation, NetworkTimeout, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
from pymongo.read_concern import ReadConcern
from pymongo.synchronous.command_cursor import CommandCursor
40 changes: 40 additions & 0 deletions test/test_transactions.py
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@
ConfigurationError,
ConnectionFailure,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
@@ -378,6 +379,45 @@ def find_raw_batches(*args, **kwargs):
if isinstance(res, (CommandCursor, Cursor)):
res.to_list()

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#10-convenient-transactions
@client_context.require_transactions
@client_context.require_version_min(4, 4, -1)
@client_context.require_failCommand_fail_point
def test_10_convenient_transactions_csot(self):
self.client.db.coll.drop()

listener = OvertCommandListener()

with self.fail_point(
{
"mode": {"times": 2},
"data": {
"failCommands": ["insert", "abortTransaction"],
"blockConnection": True,
"blockTimeMS": 200,
},
}
):
client = self.rs_or_single_client(
timeoutMS=150,
event_listeners=[listener],
)
session = client.start_session()

def callback(s):
client.db.coll.insert_one({"_id": 1}, session=s)

with self.assertRaises(NetworkTimeout):
session.with_transaction(callback)

started = listener.started_command_names()
failed = listener.failed_command_names()

self.assertIn("insert", started)
self.assertIn("abortTransaction", started)
self.assertIn("insert", failed)
self.assertIn("abortTransaction", failed)


class PatchSessionTimeout:
"""Patches the client_session's with_transaction timeout for testing."""
4 changes: 4 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -178,6 +178,10 @@ def started_command_names(self) -> List[str]:
"""Return list of command names started."""
return [event.command_name for event in self.started_events]

def failed_command_names(self) -> List[str]:
"""Return list of command names failed."""
return [event.command_name for event in self.failed_events]

def reset(self) -> None:
"""Reset the state of this listener."""
self.results.clear()