Commit 7fcb5b6

Implementation of automatic batching for async (#554)
1 parent 77a3a40 commit 7fcb5b6

10 files changed: +573 -90 lines


README.md

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ The complete documentation for GQL can be found at
 * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage)
 * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html)
 * Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html)
+* Supports [Batching requests](https://gql.readthedocs.io/en/latest/advanced/batching_requests.html)
 * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries or download schemas from the command line
 * [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically

docs/advanced/batching_requests.rst

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
.. _batching_requests:

Batching requests
=================

If you need to send multiple GraphQL queries to a backend,
and if the backend supports batch requests,
then you might want to send those requests in a batch instead of
making multiple execution requests.

.. warning::
    - Some backends do not support batch requests
    - File uploads and subscriptions are not supported with batch requests

Batching requests manually
^^^^^^^^^^^^^^^^^^^^^^^^^^

To execute a batch of requests manually:

- First, make a list of :class:`GraphQLRequest <gql.GraphQLRequest>` objects, containing:

  * your GraphQL query
  * optional variable_values
  * optional operation_name

.. code-block:: python

    request1 = GraphQLRequest("""
        query getContinents {
          continents {
            code
            name
          }
        }
        """
    )

    request2 = GraphQLRequest("""
        query getContinentName ($code: ID!) {
          continent (code: $code) {
            name
          }
        }
        """,
        variable_values={
            "code": "AF",
        },
    )

    requests = [request1, request2]

- Then use one of the ``execute_batch`` methods, either on the Client,
  or in a sync or async session

**Sync**:

.. code-block:: python

    transport = RequestsHTTPTransport(url=url)
    # Or transport = HTTPXTransport(url=url)

    with Client(transport=transport) as session:

        results = session.execute_batch(requests)

        result1 = results[0]
        result2 = results[1]

**Async**:

.. code-block:: python

    transport = AIOHTTPTransport(url=url)
    # Or transport = HTTPXAsyncTransport(url=url)

    async with Client(transport=transport) as session:

        results = await session.execute_batch(requests)

        result1 = results[0]
        result2 = results[1]

.. note::
    If any request in the batch returns an error, then a TransportQueryError
    will be raised with the first error found.

Automatic batching of requests
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your code executes multiple requests independently within a short time
(either from different threads in sync code, or from different asyncio tasks in async code),
then you can use gql's automatic request batching functionality.

You define a :code:`batch_interval` in your :class:`Client <gql.Client>`
and each time a new execution request is received through an ``execute`` method,
we will wait that interval (in seconds) for other requests to arrive
before sending all the requests received in that interval in a single batch.
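
The sketch below shows what this looks like in practice. It assumes the
``batch_interval`` and ``batch_max`` settings referenced by the client code in
this commit are accepted by the :class:`Client <gql.Client>` constructor, and
it uses the public countries API as a stand-in backend.

.. code-block:: python

    import asyncio

    from gql import Client, gql
    from gql.transport.aiohttp import AIOHTTPTransport

    transport = AIOHTTPTransport(url="https://countries.trevorblades.com/")

    client = Client(
        transport=transport,
        batch_interval=0.05,  # wait up to 50 ms for other requests to arrive
        batch_max=10,         # never put more than 10 requests in one batch
    )

    query1 = gql("query getContinents { continents { code } }")
    query2 = gql('query getContinentName { continent (code: "AF") { name } }')

    async def main():
        async with client as session:
            # Both executions start within the batch interval,
            # so they are sent to the backend as a single batch request
            results = await asyncio.gather(
                session.execute(query1),
                session.execute(query2),
            )
            print(results)

    asyncio.run(main())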

docs/advanced/index.rst

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ Advanced
 
     async_advanced_usage
     async_permanent_session
+    batching_requests
     logging
     error_handling
     local_schema

gql/client.py

Lines changed: 160 additions & 17 deletions
@@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs):
 
         if reconnecting:
             self.session = ReconnectingAsyncClientSession(client=self, **kwargs)
-            await self.session.start_connecting_task()
         else:
-            try:
-                await self.transport.connect()
-            except Exception as e:
-                await self.transport.close()
-                raise e
             self.session = AsyncClientSession(client=self)
 
+        await self.session.connect()
+
         # Get schema from transport if needed
         try:
             if self.fetch_schema_from_transport and not self.schema:
@@ -846,18 +842,15 @@ async def connect_async(self, reconnecting=False, **kwargs):
             # we don't know what type of exception is thrown here because it
             # depends on the underlying transport; we just make sure that the
             # transport is closed and re-raise the exception
-            await self.transport.close()
+            await self.session.close()
             raise
 
         return self.session
 
     async def close_async(self):
         """Close the async transport and stop the optional reconnecting task."""
 
-        if isinstance(self.session, ReconnectingAsyncClientSession):
-            await self.session.stop_connecting_task()
-
-        await self.transport.close()
+        await self.session.close()
 
     async def __aenter__(self):
         return await self.connect_async()
@@ -1564,12 +1557,17 @@ async def _execute(
         ):
             request = request.serialize_variable_values(self.client.schema)
 
-        # Execute the query with the transport with a timeout
-        with fail_after(self.client.execute_timeout):
-            result = await self.transport.execute(
-                request,
-                **kwargs,
-            )
+        # Check if batching is enabled
+        if self.client.batching_enabled:
+            future_result = await self._execute_future(request)
+            result = await future_result
+        else:
+            # Execute the query with the transport with a timeout
+            with fail_after(self.client.execute_timeout):
+                result = await self.transport.execute(
+                    request,
+                    **kwargs,
+                )
 
         # Unserialize the result if requested
         if self.client.schema:
@@ -1828,6 +1826,134 @@ async def execute_batch(
 
         return cast(List[Dict[str, Any]], [result.data for result in results])
 
+    async def _batch_loop(self) -> None:
+        """Main loop of the task used to wait for requests
+        to execute them in a batch"""
+
+        stop_loop = False
+
+        while not stop_loop:
+            # First, wait for a request to arrive on the batch queue
+            requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = []
+
+            # Wait for the first request
+            request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = (
+                await self.batch_queue.get()
+            )
+
+            if request_and_future is None:
+                # None is our sentinel value to stop the loop
+                break
+
+            requests_and_futures.append(request_and_future)
+
+            # Then wait for the requested batch interval, except if we already
+            # have the maximum number of requests in the queue
+            if self.batch_queue.qsize() < self.client.batch_max - 1:
+                # Wait for the batch interval
+                await asyncio.sleep(self.client.batch_interval)
+
+            # Then get the requests which were made during that wait interval
+            for _ in range(self.client.batch_max - 1):
+                try:
+                    # Use get_nowait since we don't want to wait here
+                    request_and_future = self.batch_queue.get_nowait()
+
+                    if request_and_future is None:
+                        # Sentinel value - stop after processing current batch
+                        stop_loop = True
+                        break
+
+                    requests_and_futures.append(request_and_future)
+
+                except asyncio.QueueEmpty:
+                    # No more requests in queue, that's fine
+                    break
+
+            # Extract requests and futures
+            requests = [request for request, _ in requests_and_futures]
+            futures = [future for _, future in requests_and_futures]
+
+            # Execute the batch
+            try:
+                results: List[ExecutionResult] = await self._execute_batch(
+                    requests,
+                    serialize_variables=False,  # already done
+                    parse_result=False,  # will be done later
+                    validate_document=False,  # already validated
+                )
+
+                # Set the result for each future
+                for result, future in zip(results, futures):
+                    if not future.cancelled():
+                        future.set_result(result)
+
+            except Exception as exc:
+                # If batch execution fails, propagate the error to all futures
+                for future in futures:
+                    if not future.cancelled():
+                        future.set_exception(exc)
+
+        # Signal that the task has stopped
+        self._batch_task_stopped_event.set()
+
+    async def _execute_future(
+        self,
+        request: GraphQLRequest,
+    ) -> asyncio.Future:
+        """If batching is enabled, this method will put a request in the batching queue
+        instead of executing it directly so that the requests could be put in a batch.
+        """
+
+        assert hasattr(self, "batch_queue"), "Batching is not enabled"
+        assert not self._batch_task_stop_requested, "Batching task has been stopped"
+
+        future: asyncio.Future = asyncio.Future()
+        await self.batch_queue.put((request, future))
+
+        return future
+
+    async def _batch_init(self):
+        """Initialize the batch task loop if batching is enabled."""
+        if self.client.batching_enabled:
+            self.batch_queue: asyncio.Queue = asyncio.Queue()
+            self._batch_task_stop_requested = False
+            self._batch_task_stopped_event = asyncio.Event()
+            self._batch_task = asyncio.create_task(self._batch_loop())
+
+    async def _batch_cleanup(self):
+        """Cleanup the batching task if batching is enabled."""
+        if hasattr(self, "_batch_task_stopped_event"):
+            # Send a None in the queue to indicate that the batching task must stop
+            # after having processed the remaining requests in the queue
+            self._batch_task_stop_requested = True
+            await self.batch_queue.put(None)
+
+            # Wait for the task to process remaining requests and stop
+            await self._batch_task_stopped_event.wait()
+
+    async def connect(self):
+        """Connect the transport and initialize the batch task loop if batching
+        is enabled."""
+
+        await self._batch_init()
+
+        try:
+            await self.transport.connect()
+        except Exception as e:
+            await self.transport.close()
+            raise e
+
+    async def close(self):
+        """Close the transport and cleanup the batching task if batching is enabled.
+
+        Will wait until all the remaining requests in the batch processing queue
+        have been executed.
+        """
+        await self._batch_cleanup()
+
+        await self.transport.close()
+
     async def fetch_schema(self) -> None:
         """Fetch the GraphQL schema explicitly using introspection.
 
@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self):
             self._connect_task.cancel()
             self._connect_task = None
 
+    async def connect(self):
+        """Start the connect task and initialize the batch task loop if batching
+        is enabled."""
+
+        await self._batch_init()
+
+        await self.start_connecting_task()
+
+    async def close(self):
+        """Stop the connect task and cleanup the batching task
+        if batching is enabled."""
+        await self._batch_cleanup()
+
+        await self.stop_connecting_task()
+
+        await self.transport.close()
+
     async def _execute_once(
         self,
         request: GraphQLRequest,
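
The batching machinery added above is a queue-plus-futures producer/consumer
pattern: callers enqueue a (request, future) pair and await the future, while a
background task drains the queue for the batch interval and resolves every
future at once. The standalone toy below sketches that pattern in isolation; it
assumes a single backend call can serve a whole batch, and the names
(batch_interval, batch_max, fake_backend) are illustrative, not part of the
gql API.

    import asyncio
    from typing import Any, List, Optional, Tuple

    batch_interval = 0.05  # seconds to wait for more requests
    batch_max = 10  # hard cap on the batch size

    queue: "asyncio.Queue[Optional[Tuple[Any, asyncio.Future]]]" = asyncio.Queue()

    async def fake_backend(items: List[Any]) -> List[Any]:
        # Stand-in for a single HTTP call carrying the whole batch
        return [f"result for {item}" for item in items]

    async def batch_loop() -> None:
        stop = False
        while not stop:
            first = await queue.get()
            if first is None:  # sentinel: shut down
                break
            pairs = [first]
            # Wait only if the queue does not already hold a full batch
            if queue.qsize() < batch_max - 1:
                await asyncio.sleep(batch_interval)
            # Drain whatever arrived during the interval, without blocking
            while len(pairs) < batch_max:
                try:
                    nxt = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                if nxt is None:  # sentinel: finish this batch, then stop
                    stop = True
                    break
                pairs.append(nxt)
            results = await fake_backend([item for item, _ in pairs])
            for (_, fut), res in zip(pairs, results):
                fut.set_result(res)  # wake up the waiting caller

    async def execute(item: Any) -> Any:
        # Enqueue the request with a future and await the future,
        # analogous to _execute_future/_execute above
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        await queue.put((item, fut))
        return await fut

    async def main() -> None:
        task = asyncio.create_task(batch_loop())
        # Two concurrent calls land in the same batch
        print(await asyncio.gather(execute("query A"), execute("query B")))
        await queue.put(None)  # ask the loop to stop
        await task

    asyncio.run(main())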

gql/graphql_request.py

Lines changed: 27 additions & 12 deletions
@@ -1,26 +1,38 @@
-from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 from graphql import DocumentNode, GraphQLSchema, print_ast
 
+from .gql import gql
 from .utilities import serialize_variable_values
 
 
-@dataclass(frozen=True)
 class GraphQLRequest:
     """GraphQL Request to be executed."""
 
-    document: DocumentNode
-    """GraphQL query as AST Node object."""
+    def __init__(
+        self,
+        document: Union[DocumentNode, str],
+        *,
+        variable_values: Optional[Dict[str, Any]] = None,
+        operation_name: Optional[str] = None,
+    ):
+        """
+        Initialize a GraphQL request.
 
-    variable_values: Optional[Dict[str, Any]] = None
-    """Dictionary of input parameters (Default: None)."""
+        Args:
+            document: GraphQL query as AST Node object or as a string.
+                If string, it will be converted to DocumentNode using gql().
+            variable_values: Dictionary of input parameters (Default: None).
+            operation_name: Name of the operation that shall be executed.
+                Only required in multi-operation documents (Default: None).
+        """
+        if isinstance(document, str):
+            self.document = gql(document)
+        else:
+            self.document = document
 
-    operation_name: Optional[str] = None
-    """
-    Name of the operation that shall be executed.
-    Only required in multi-operation documents (Default: None).
-    """
+        self.variable_values = variable_values
+        self.operation_name = operation_name
 
     def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
         assert self.variable_values
@@ -48,3 +60,6 @@ def payload(self) -> Dict[str, Any]:
             payload["variables"] = self.variable_values
 
         return payload
+
+    def __str__(self):
+        return str(self.payload)
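
A quick sketch of what the reworked class allows, assuming GraphQLRequest is
importable from the top-level gql package; the printed payload is shown only
approximately.

    from gql import GraphQLRequest

    # A plain string is now accepted and parsed with gql() automatically,
    # where the frozen dataclass required a pre-parsed DocumentNode
    request = GraphQLRequest(
        "query getContinentName ($code: ID!) { continent (code: $code) { name } }",
        variable_values={"code": "AF"},
    )

    # __str__ now renders the serialized payload
    print(request)
    # -> {'query': 'query getContinentName ...', 'variables': {'code': 'AF'}}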
