Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/11876.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactored tests to use ``create_autospec()`` for more robust mocking -- by :user:`soheil-star01`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ Serhiy Storchaka
Shubh Agarwal
Simon Kennedy
Sin-Woo Bang
Soheil Dolatabadi
Stanislas Plum
Stanislav Prokop
Stefan Tjarks
Expand Down
21 changes: 14 additions & 7 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,10 @@ async def test_content_encoding( # type: ignore[misc]
"post", URL("http://python.org/"), data="foo", compress="deflate", loop=loop
)
with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In these cases, it's the mock.patch() that's missing it:

Suggested change
with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer:
with mock.patch("aiohttp.client_reqrep.StreamWriter", autospec=True, spec_set=True) as m_writer:

The other lines can then likely be removed entirely.

m_writer.return_value.write_headers = mock.AsyncMock()
mock_writer_instance = mock.create_autospec(
StreamWriter, instance=True, spec_set=True
)
m_writer.return_value = mock_writer_instance
resp = await req._send(conn)
assert req.headers["TRANSFER-ENCODING"] == "chunked"
assert req.headers["CONTENT-ENCODING"] == "deflate"
Expand Down Expand Up @@ -1019,7 +1022,10 @@ async def test_content_encoding_header( # type: ignore[misc]
loop=loop,
)
with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer:
m_writer.return_value.write_headers = mock.AsyncMock()
mock_writer_instance = mock.create_autospec(
StreamWriter, instance=True, spec_set=True
)
m_writer.return_value = mock_writer_instance
resp = await req._send(conn)

assert not m_writer.return_value.enable_compression.called
Expand Down Expand Up @@ -1108,8 +1114,10 @@ async def test_chunked_explicit(
"post", URL("http://python.org/"), chunked=True, loop=loop
)
with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer:
m_writer.return_value.write_headers = mock.AsyncMock()
m_writer.return_value.write_eof = mock.AsyncMock()
mock_writer_instance = mock.create_autospec(
StreamWriter, instance=True, spec_set=True
)
m_writer.return_value = mock_writer_instance
resp = await req._send(conn)

assert "chunked" == req.headers["TRANSFER-ENCODING"]
Expand Down Expand Up @@ -1977,8 +1985,7 @@ async def test_update_body_closes_previous_payload(
req = make_client_request("POST", URL("http://python.org/"))

# Create a mock payload that tracks if it was closed
mock_payload = mock.Mock(spec=payload.Payload)
mock_payload.close = mock.AsyncMock()
mock_payload = mock.create_autospec(payload.Payload, spec_set=True, instance=True)

# Set initial payload
req._body = mock_payload
Expand Down Expand Up @@ -2101,7 +2108,7 @@ async def test_expect100_with_body_becomes_empty(
) -> None:
"""Test that write_bytes handles body becoming empty after expect100 handling."""
# Create a mock writer and connection
mock_writer = mock.AsyncMock()
mock_writer = mock.create_autospec(StreamWriter, instance=True, spec_set=True)
mock_conn = mock.Mock()

# Create a request
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from aiohttp.connector import Connection
from aiohttp.helpers import TimerNoop
from aiohttp.multipart import BadContentDispositionHeader
from aiohttp.tracing import Trace


class WriterMock(mock.AsyncMock):
Expand Down Expand Up @@ -1262,8 +1263,7 @@ def test_redirect_history_in_exception() -> None:
async def test_response_read_triggers_callback(
loop: asyncio.AbstractEventLoop, session: ClientSession
) -> None:
trace = mock.Mock()
trace.send_response_chunk_received = mock.AsyncMock()
trace = mock.create_autospec(Trace, instance=True, spec_set=True)
response_method = "get"
response_url = URL("http://def-cl-resp.org")
response_body = b"This is response"
Expand Down
133 changes: 116 additions & 17 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import deque
from collections.abc import Awaitable, Callable, Iterator
from http.cookies import BaseCookie, SimpleCookie
from types import SimpleNamespace
from typing import Any, NoReturn, TypedDict, cast
from unittest import mock
from uuid import uuid4
Expand All @@ -25,9 +26,19 @@
from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector
from aiohttp.cookiejar import CookieJar
from aiohttp.http import RawResponseMessage
from aiohttp.payload import Payload
from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer
from aiohttp.test_utils import TestServer
from aiohttp.tracing import Trace
from aiohttp.tracing import (
Trace,
TraceRequestChunkSentParams,
TraceRequestEndParams,
TraceRequestExceptionParams,
TraceRequestHeadersSentParams,
TraceRequestRedirectParams,
TraceRequestStartParams,
TraceResponseChunkReceivedParams,
)


class _Params(TypedDict):
Expand Down Expand Up @@ -557,9 +568,8 @@ async def test_reraise_os_error(
err = OSError(1, "permission error")
req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req._send = mock.AsyncMock(side_effect=err)
req._body = mock.Mock()
req._body.close = mock.AsyncMock()
req._send.side_effect = err
req._body = mock.create_autospec(Payload, spec_set=True, instance=True)
session = await create_session(request_class=req_factory)

async def create_connection(
Expand Down Expand Up @@ -589,9 +599,8 @@ class UnexpectedException(BaseException):
err = UnexpectedException("permission error")
req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req._send = mock.AsyncMock(side_effect=err)
req._body = mock.Mock()
req._body.close = mock.AsyncMock()
req._send.side_effect = err
req._body = mock.create_autospec(Payload, spec_set=True, instance=True)
session = await create_session(request_class=req_factory)

connections = []
Expand Down Expand Up @@ -637,7 +646,7 @@ async def test_ws_connect_allowed_protocols( # type: ignore[misc]
ws_key: str,
key_data: bytes,
) -> None:
resp = mock.create_autospec(aiohttp.ClientResponse)
resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True, instance=True)
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
Expand All @@ -646,7 +655,6 @@ async def test_ws_connect_allowed_protocols( # type: ignore[misc]
}
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req._body = None # No body for WebSocket upgrade requests
Expand Down Expand Up @@ -700,7 +708,7 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc]
ws_key: str,
key_data: bytes,
) -> None:
resp = mock.create_autospec(aiohttp.ClientResponse)
resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True, instance=True)
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
Expand All @@ -709,7 +717,6 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc]
}
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req._body = None # No body for WebSocket upgrade requests
Expand Down Expand Up @@ -923,15 +930,41 @@ async def test_request_tracing(
async def handler(request: web.Request) -> web.Response:
return web.json_response({"ok": True})

# Define callback signatures
async def on_request_start_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestStartParams,
) -> None:
pass

async def on_request_end_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestEndParams,
) -> None:
pass

async def on_request_redirect_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestRedirectParams,
) -> None:
pass

app = web.Application()
app.router.add_post("/", handler)

trace_config_ctx = mock.Mock()
body = "This is request body"
gathered_req_headers: CIMultiDict[str] = CIMultiDict()
on_request_start = mock.AsyncMock()
on_request_redirect = mock.AsyncMock()
on_request_end = mock.AsyncMock()

# Create mocks with signatures(above)
on_request_start = mock.create_autospec(on_request_start_callback, spec_set=True)
on_request_end = mock.create_autospec(on_request_end_callback, spec_set=True)
on_request_redirect = mock.create_autospec(
on_request_redirect_callback, spec_set=True
)

with io.BytesIO() as gathered_req_body, io.BytesIO() as gathered_res_body:

Expand Down Expand Up @@ -1006,20 +1039,86 @@ async def root_handler(request: web.Request) -> web.Response:
async def redirect_handler(request: web.Request) -> NoReturn:
raise web.HTTPFound("/")

# Define callback signatures
async def on_request_start_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestStartParams,
) -> None:
pass

async def on_request_end_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestEndParams,
) -> None:
pass

async def on_request_redirect_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestRedirectParams,
) -> None:
pass

async def on_request_exception_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestExceptionParams,
) -> None:
pass

async def on_request_chunk_sent_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestChunkSentParams,
) -> None:
pass

async def on_response_chunk_received_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceResponseChunkReceivedParams,
) -> None:
pass

async def on_request_headers_sent_callback(
session: ClientSession,
trace_config_ctx: SimpleNamespace,
params: TraceRequestHeadersSentParams,
) -> None:
pass

app = web.Application()
app.router.add_get("/", root_handler)
app.router.add_get("/redirect", redirect_handler)

mocks = [mock.AsyncMock() for _ in range(7)]
(
on_request_start = mock.create_autospec(on_request_start_callback, spec_set=True)
on_request_redirect = mock.create_autospec(
on_request_redirect_callback, spec_set=True
)
on_request_end = mock.create_autospec(on_request_end_callback, spec_set=True)
on_request_exception = mock.create_autospec(
on_request_exception_callback, spec_set=True
)
on_request_chunk_sent = mock.create_autospec(
on_request_chunk_sent_callback, spec_set=True
)
on_response_chunk_received = mock.create_autospec(
on_response_chunk_received_callback, spec_set=True
)
on_request_headers_sent = mock.create_autospec(
on_request_headers_sent_callback, spec_set=True
)
mocks = [
on_request_start,
on_request_redirect,
on_request_end,
on_request_exception,
on_request_chunk_sent,
on_response_chunk_received,
on_request_headers_sent,
) = mocks
]

trace_config = aiohttp.TraceConfig(
trace_config_ctx_factory=mock.Mock(return_value=mock.Mock())
Expand Down
33 changes: 21 additions & 12 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
client,
hdrs,
)
from aiohttp._websocket.writer import WebSocketWriter as RealWebSocketWriter
from aiohttp.http import WS_KEY
from aiohttp.http_websocket import WSMessageClose
from aiohttp.streams import EofStream
Expand Down Expand Up @@ -380,9 +381,10 @@ async def test_close(
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(mresp)
writer = mock.Mock()
writer = mock.create_autospec(
RealWebSocketWriter, instance=True, spec_set=True
)
WebSocketWriter.return_value = writer
writer.close = mock.AsyncMock()

session = aiohttp.ClientSession()
resp = await session.ws_connect("http://test.org")
Expand Down Expand Up @@ -489,9 +491,10 @@ async def test_close_exc(
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(mresp)
writer = mock.Mock()
writer = mock.create_autospec(
RealWebSocketWriter, instance=True, spec_set=True
)
WebSocketWriter.return_value = writer
writer.close = mock.AsyncMock()

session = aiohttp.ClientSession()
resp = await session.ws_connect("http://test.org")
Expand Down Expand Up @@ -625,9 +628,10 @@ async def test_reader_read_exception(
m_req.return_value = loop.create_future()
m_req.return_value.set_result(hresp)

writer = mock.Mock()
writer = mock.create_autospec(
RealWebSocketWriter, instance=True, spec_set=True
)
WebSocketWriter.return_value = writer
writer.close = mock.AsyncMock()

session = aiohttp.ClientSession()
resp = await session.ws_connect("http://test.org")
Expand Down Expand Up @@ -778,29 +782,34 @@ async def test_ws_connect_deflate_per_message(
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(mresp)
writer = WebSocketWriter.return_value = mock.Mock()
send_frame = writer.send_frame = mock.AsyncMock()
writer = mock.create_autospec(
RealWebSocketWriter, instance=True, spec_set=True
)

WebSocketWriter.return_value = writer

session = aiohttp.ClientSession()
resp = await session.ws_connect("http://test.org")

await resp.send_str("string", compress=-1)
send_frame.assert_called_with(
writer.send_frame.assert_called_with(
b"string", aiohttp.WSMsgType.TEXT, compress=-1
)

await resp.send_bytes(b"bytes", compress=15)
send_frame.assert_called_with(
writer.send_frame.assert_called_with(
b"bytes", aiohttp.WSMsgType.BINARY, compress=15
)

await resp.send_json([{}], compress=-9)
send_frame.assert_called_with(
writer.send_frame.assert_called_with(
b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9
)

await resp.send_frame(b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9)
send_frame.assert_called_with(b"[{}]", aiohttp.WSMsgType.TEXT, -9)
writer.send_frame.assert_called_with(
b"[{}]", aiohttp.WSMsgType.TEXT, -9
)

await session.close()

Expand Down
Loading
Loading