@@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs):
 
         if reconnecting:
             self.session = ReconnectingAsyncClientSession(client=self, **kwargs)
-            await self.session.start_connecting_task()
         else:
-            try:
-                await self.transport.connect()
-            except Exception as e:
-                await self.transport.close()
-                raise e
             self.session = AsyncClientSession(client=self)
 
+        await self.session.connect()
+
         # Get schema from transport if needed
         try:
             if self.fetch_schema_from_transport and not self.schema:
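With this hunk, connect_async only picks the session class; transport startup moves behind a single polymorphic session.connect() call, implemented for both session types further down. For orientation, a minimal usage sketch of the entry point. It assumes gql's AIOHTTPTransport, a placeholder URL, and that session.execute() still accepts a parsed document as in gql 3.x; none of that is shown in this diff.

import asyncio

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport


async def main():
    transport = AIOHTTPTransport(url="https://example.com/graphql")  # placeholder
    client = Client(transport=transport)

    # connect_async() builds an AsyncClientSession, then awaits session.connect()
    session = await client.connect_async()
    try:
        result = await session.execute(gql("{ someField }"))
        print(result)
    finally:
        await client.close_async()


asyncio.run(main())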
@@ -846,18 +842,15 @@ async def connect_async(self, reconnecting=False, **kwargs):
             # we don't know what type of exception is thrown here because it
             # depends on the underlying transport; we just make sure that the
             # transport is closed and re-raise the exception
-            await self.transport.close()
+            await self.session.close()
             raise
 
         return self.session
 
     async def close_async(self):
         """Close the async transport and stop the optional reconnecting task."""
 
-        if isinstance(self.session, ReconnectingAsyncClientSession):
-            await self.session.stop_connecting_task()
-
-        await self.transport.close()
+        await self.session.close()
 
     async def __aenter__(self):
         return await self.connect_async()
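The isinstance check in close_async disappears for the same reason: each session type now owns its shutdown. The reconnecting path keeps the same public entry points; a sketch of how the refactored calls chain together (illustrative only):

async def run_with_reconnect(client):
    # Builds a ReconnectingAsyncClientSession; its connect() starts the
    # background reconnect task instead of connecting the transport directly
    session = await client.connect_async(reconnecting=True)
    try:
        ...  # session.execute(...) can retry across reconnections
    finally:
        # close_async() now simply awaits session.close(), which stops
        # the reconnect task and then closes the transport
        await client.close_async()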
@@ -1564,12 +1557,17 @@ async def _execute(
         ):
             request = request.serialize_variable_values(self.client.schema)
 
-        # Execute the query with the transport with a timeout
-        with fail_after(self.client.execute_timeout):
-            result = await self.transport.execute(
-                request,
-                **kwargs,
-            )
+        # Check if batching is enabled
+        if self.client.batching_enabled:
+            future_result = await self._execute_future(request)
+            result = await future_result
+        else:
+            # Execute the query with the transport with a timeout
+            with fail_after(self.client.execute_timeout):
+                result = await self.transport.execute(
+                    request,
+                    **kwargs,
+                )
 
         # Unserialize the result if requested
         if self.client.schema:
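This is the only change batching needs on the request path: when self.client.batching_enabled is true, the request is parked in the batch queue via _execute_future and the caller awaits the resulting future. Judging from the attributes the new code reads (batch_interval, batch_max), batching is configured on the Client; a sketch reusing the transport and imports from the first example above, and assuming a non-zero batch_interval in seconds is what enables batching (the constructor arguments are not shown in this diff):

client = Client(
    transport=transport,
    batch_interval=0.01,  # assumed: collect requests for up to 10 ms per batch
    batch_max=10,         # assumed: cap each batch at 10 requests
)

async with client as session:
    # Concurrent executes are grouped by the batch loop into batched
    # transport calls instead of one call per request
    results = await asyncio.gather(
        session.execute(gql("{ a }")),
        session.execute(gql("{ b }")),
        session.execute(gql("{ c }")),
    )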
@@ -1828,6 +1826,134 @@ async def execute_batch(
 
         return cast(List[Dict[str, Any]], [result.data for result in results])
 
+    async def _batch_loop(self) -> None:
+        """Main loop of the task used to wait for requests
+        to execute them in a batch"""
+
+        stop_loop = False
+
+        while not stop_loop:
+            # First wait for a request from the batch queue
+            requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = []
+
+            # Wait for the first request
+            request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = (
+                await self.batch_queue.get()
+            )
+
+            if request_and_future is None:
+                # None is our sentinel value to stop the loop
+                break
+
+            requests_and_futures.append(request_and_future)
+
+            # Then wait for the requested batch interval, except if we already
+            # have the maximum number of requests in the queue
+            if self.batch_queue.qsize() < self.client.batch_max - 1:
+                # Wait for the batch interval
+                await asyncio.sleep(self.client.batch_interval)
+
+            # Then get the requests which were made during that wait interval
+            for _ in range(self.client.batch_max - 1):
+                try:
+                    # Use get_nowait since we don't want to wait here
+                    request_and_future = self.batch_queue.get_nowait()
+
+                    if request_and_future is None:
+                        # Sentinel value - stop after processing the current batch
+                        stop_loop = True
+                        break
+
+                    requests_and_futures.append(request_and_future)
+
+                except asyncio.QueueEmpty:
+                    # No more requests in the queue, that's fine
+                    break
+
+            # Extract requests and futures
+            requests = [request for request, _ in requests_and_futures]
+            futures = [future for _, future in requests_and_futures]
+
+            # Execute the batch
+            try:
+                results: List[ExecutionResult] = await self._execute_batch(
+                    requests,
+                    serialize_variables=False,  # already done
+                    parse_result=False,  # will be done later
+                    validate_document=False,  # already validated
+                )
+
+                # Set the result for each future
+                for result, future in zip(results, futures):
+                    if not future.cancelled():
+                        future.set_result(result)
+
+            except Exception as exc:
+                # If batch execution fails, propagate the error to all futures
+                for future in futures:
+                    if not future.cancelled():
+                        future.set_exception(exc)
+
+        # Signal that the task has stopped
+        self._batch_task_stopped_event.set()
+
+    async def _execute_future(
+        self,
+        request: GraphQLRequest,
+    ) -> asyncio.Future:
+        """If batching is enabled, put a request in the batching queue instead of
+        executing it directly, so that requests can be grouped into a batch.
+        """
+
+        assert hasattr(self, "batch_queue"), "Batching is not enabled"
+        assert not self._batch_task_stop_requested, "Batching task has been stopped"
+
+        future: asyncio.Future = asyncio.Future()
+        await self.batch_queue.put((request, future))
+
+        return future
+
+    async def _batch_init(self):
+        """Initialize the batch task loop if batching is enabled."""
+        if self.client.batching_enabled:
+            self.batch_queue: asyncio.Queue = asyncio.Queue()
+            self._batch_task_stop_requested = False
+            self._batch_task_stopped_event = asyncio.Event()
+            self._batch_task = asyncio.create_task(self._batch_loop())
+
+    async def _batch_cleanup(self):
+        """Clean up the batching task if batching is enabled."""
+        if hasattr(self, "_batch_task_stopped_event"):
+            # Put a None in the queue to indicate that the batching task must stop
+            # after having processed the remaining requests in the queue
+            self._batch_task_stop_requested = True
+            await self.batch_queue.put(None)
+
+            # Wait for the task to process the remaining requests and stop
+            await self._batch_task_stopped_event.wait()
+
+    async def connect(self):
+        """Connect the transport and initialize the batch task loop if batching
+        is enabled."""
+
+        await self._batch_init()
+
+        try:
+            await self.transport.connect()
+        except Exception as e:
+            await self.transport.close()
+            raise e
+
+    async def close(self):
+        """Close the transport and clean up the batching task if batching is enabled.
+
+        Will wait until all the remaining requests in the batch processing queue
+        have been executed.
+        """
+        await self._batch_cleanup()
+
+        await self.transport.close()
+
     async def fetch_schema(self) -> None:
         """Fetch the GraphQL schema explicitly using introspection.
 
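Two patterns in the block above are worth calling out: _execute_future hands the caller an asyncio.Future that _batch_loop resolves later, and shutdown uses a None sentinel plus an event so that already-queued requests still complete. A self-contained toy of that handshake (plain asyncio, not gql code):

import asyncio


async def consumer(queue: asyncio.Queue, stopped: asyncio.Event) -> None:
    while True:
        item = await queue.get()
        if item is None:
            # Sentinel: the producer asked us to stop after draining the queue
            break
        request, future = item
        if not future.cancelled():
            future.set_result(f"result for {request}")
    stopped.set()  # mirrors _batch_task_stopped_event.set()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    stopped = asyncio.Event()
    task = asyncio.create_task(consumer(queue, stopped))

    # Mirrors _execute_future: enqueue the work together with its Future
    future: asyncio.Future = asyncio.Future()
    await queue.put(("req-1", future))
    print(await future)

    # Mirrors _batch_cleanup: send the sentinel, then wait for a clean stop
    await queue.put(None)
    await stopped.wait()
    await task


asyncio.run(main())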
@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self):
             self._connect_task.cancel()
             self._connect_task = None
 
+    async def connect(self):
+        """Start the connect task and initialize the batch task loop if batching
+        is enabled."""
+
+        await self._batch_init()
+
+        await self.start_connecting_task()
+
+    async def close(self):
+        """Stop the connect task and clean up the batching task
+        if batching is enabled."""
+        await self._batch_cleanup()
+
+        await self.stop_connecting_task()
+
+        await self.transport.close()
+
     async def _execute_once(
         self,
         request: GraphQLRequest,
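The ordering in the reconnecting close() is deliberate: the batch queue is drained first (queued requests still need a live transport), the reconnect task is stopped next, and only then is the transport closed. Combining both features in one sketch, with the same assumed batching constructor arguments as above:

import asyncio

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport


async def main():
    transport = AIOHTTPTransport(url="https://example.com/graphql")
    client = Client(
        transport=transport,
        batch_interval=0.01,  # assumed batching knobs, as above
        batch_max=10,
    )

    session = await client.connect_async(reconnecting=True)
    try:
        results = await asyncio.gather(
            *(session.execute(gql("{ ping }")) for _ in range(25))
        )
        print(len(results))
    finally:
        # ReconnectingAsyncClientSession.close():
        #   _batch_cleanup() -> stop_connecting_task() -> transport.close()
        await client.close_async()


asyncio.run(main())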