Skip to content

Refactor transports #557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
@@ -127,7 +127,7 @@ async def connect(self) -> None:

# Adding custom parameters passed from init
if self.client_session_args:
client_session_args.update(self.client_session_args) # type: ignore
client_session_args.update(self.client_session_args)

log.debug("Connecting transport")

@@ -164,36 +164,22 @@ async def close(self) -> None:

self.session = None

def _prepare_batch_request(
self,
reqs: List[GraphQLRequest],
extra_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:

payload = [req.payload for req in reqs]

post_args = {"json": payload}

# Log the payload
if log.isEnabledFor(logging.DEBUG):
log.debug(">>> %s", self.json_serialize(post_args["json"]))

# Pass post_args to aiohttp post method
if extra_args:
post_args.update(extra_args)

return post_args

def _prepare_request(
self,
request: GraphQLRequest,
request: Union[GraphQLRequest, List[GraphQLRequest]],
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> Dict[str, Any]:

payload = request.payload
payload: Dict | List
if isinstance(request, GraphQLRequest):
payload = request.payload
else:
payload = [req.payload for req in request]

if upload_files:
assert isinstance(payload, Dict)
assert isinstance(request, GraphQLRequest)
post_args = self._prepare_file_uploads(request, payload)
else:
post_args = {"json": payload}
@@ -379,15 +365,15 @@ async def execute(
:returns: an ExecutionResult object.
"""

if self.session is None:
raise TransportClosed("Transport is not connected")

post_args = self._prepare_request(
request,
extra_args,
upload_files,
)

if self.session is None:
raise TransportClosed("Transport is not connected")

try:
async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:
return await self._prepare_result(resp)
@@ -413,14 +399,14 @@ async def execute_batch(
if an error occurred.
"""

post_args = self._prepare_batch_request(
if self.session is None:
raise TransportClosed("Transport is not connected")

post_args = self._prepare_request(
reqs,
extra_args,
)

if self.session is None:
raise TransportClosed("Transport is not connected")

async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:
return await self._prepare_batch_result(reqs, resp)

55 changes: 21 additions & 34 deletions gql/transport/httpx.py
Original file line number Diff line number Diff line change
@@ -59,15 +59,22 @@ def __init__(

def _prepare_request(
self,
req: GraphQLRequest,
request: Union[GraphQLRequest, List[GraphQLRequest]],
*,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> Dict[str, Any]:

payload = req.payload
payload: Dict | List
if isinstance(request, GraphQLRequest):
payload = request.payload
else:
payload = [req.payload for req in request]

if upload_files:
post_args = self._prepare_file_uploads(req, payload)
assert isinstance(payload, Dict)
assert isinstance(request, GraphQLRequest)
post_args = self._prepare_file_uploads(request, payload)
else:
post_args = {"json": payload}

@@ -81,26 +88,6 @@ def _prepare_request(

return post_args

def _prepare_batch_request(
self,
reqs: List[GraphQLRequest],
extra_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:

payload = [req.payload for req in reqs]

post_args = {"json": payload}

# Log the payload
if log.isEnabledFor(logging.DEBUG):
log.debug(">>> %s", self.json_serialize(payload))

# Pass post_args to aiohttp post method
if extra_args:
post_args.update(extra_args)

return post_args

def _prepare_file_uploads(
self,
request: GraphQLRequest,
@@ -244,7 +231,7 @@ def connect(self):

self.client = httpx.Client(**self.kwargs)

def execute( # type: ignore
def execute(
self,
request: GraphQLRequest,
*,
@@ -269,8 +256,8 @@ def execute( # type: ignore

post_args = self._prepare_request(
request,
extra_args,
upload_files,
extra_args=extra_args,
upload_files=upload_files,
)

try:
@@ -292,7 +279,7 @@ def execute_batch(
:code:`execute_batch` on a client or a session.
:param reqs: GraphQL requests as a list of GraphQLRequest objects.
:param extra_args: additional arguments to send to the aiohttp post method
:param extra_args: additional arguments to send to the httpx post method
:return: A list of results of execution.
For every result `data` is the result of executing the query,
`errors` is null if no errors occurred, and is a non-empty array
@@ -302,9 +289,9 @@ def execute_batch(
if not self.client:
raise TransportClosed("Transport is not connected")

post_args = self._prepare_batch_request(
post_args = self._prepare_request(
reqs,
extra_args,
extra_args=extra_args,
)

response = self.client.post(self.url, **post_args)
@@ -361,8 +348,8 @@ async def execute(

post_args = self._prepare_request(
request,
extra_args,
upload_files,
extra_args=extra_args,
upload_files=upload_files,
)

try:
@@ -384,7 +371,7 @@ async def execute_batch(
:code:`execute_batch` on a client or a session.
:param reqs: GraphQL requests as a list of GraphQLRequest objects.
:param extra_args: additional arguments to send to the aiohttp post method
:param extra_args: additional arguments to send to the httpx post method
:return: A list of results of execution.
For every result `data` is the result of executing the query,
`errors` is null if no errors occurred, and is a non-empty array
@@ -394,9 +381,9 @@ async def execute_batch(
if not self.client:
raise TransportClosed("Transport is not connected")

post_args = self._prepare_batch_request(
post_args = self._prepare_request(
reqs,
extra_args,
extra_args=extra_args,
)

response = await self.client.post(self.url, **post_args)
287 changes: 151 additions & 136 deletions gql/transport/requests.py
Original file line number Diff line number Diff line change
@@ -137,32 +137,20 @@ def connect(self):
else:
raise TransportAlreadyConnected("Transport is already connected")

def execute( # type: ignore
def _prepare_request(
self,
request: GraphQLRequest,
request: Union[GraphQLRequest, List[GraphQLRequest]],
*,
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided request against the configured remote server. This
uses the requests library to perform a HTTP POST request to the remote server.
:param request: GraphQL request as a
:class:`GraphQLRequest <gql.GraphQLRequest>` object.
:param timeout: Specifies a default timeout for requests (Default: None).
:param extra_args: additional arguments to send to the requests post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""

if not self.session:
raise TransportClosed("Transport is not connected")
) -> Dict[str, Any]:

payload = request.payload
payload: Dict | List
if isinstance(request, GraphQLRequest):
payload = request.payload
else:
payload = [req.payload for req in request]

post_args: Dict[str, Any] = {
"headers": self.headers,
@@ -173,111 +161,139 @@ def execute( # type: ignore
}

if upload_files:
# If the upload_files flag is set, then we need variable_values
assert request.variable_values is not None

# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=request.variable_values,
file_classes=self.file_classes,
assert isinstance(payload, Dict)
assert isinstance(request, GraphQLRequest)
post_args = self._prepare_file_uploads(
request=request,
payload=payload,
post_args=post_args,
)

# Opening the files using the FileVar parameters
open_files(list(files.values()))
self.files = files
else:
data_key = "json" if self.use_json else "data"
post_args[data_key] = payload

# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values
# Log the payload
if log.isEnabledFor(logging.DEBUG):
log.debug(">>> %s", self.json_serialize(payload))

# Add the payload to the operations field
operations_str = self.json_serialize(payload)
log.debug("operations %s", operations_str)
# Pass kwargs to requests post method
post_args.update(self.kwargs)

# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map = {str(i): [path] for i, path in enumerate(files)}
# Pass post_args to requests post method
if extra_args:
post_args.update(extra_args)

return post_args

# Enumerate the file streams
# Will generate something like {'0': FileVar object}
file_vars = {str(i): files[path] for i, path in enumerate(files)}
def _prepare_file_uploads(
self,
request: GraphQLRequest,
*,
payload: Dict[str, Any],
post_args: Dict[str, Any],
) -> Dict[str, Any]:
# If the upload_files flag is set, then we need variable_values
assert request.variable_values is not None

# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=request.variable_values,
file_classes=self.file_classes,
)

# Add the file map field
file_map_str = self.json_serialize(file_map)
log.debug("file_map %s", file_map_str)
# Opening the files using the FileVar parameters
open_files(list(files.values()))
self.files = files

fields = {"operations": operations_str, "map": file_map_str}
# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values

# Add the extracted files as remaining fields
for k, file_var in file_vars.items():
assert isinstance(file_var, FileVar)
name = k if file_var.filename is None else file_var.filename
# Add the payload to the operations field
operations_str = self.json_serialize(payload)
log.debug("operations %s", operations_str)

if file_var.content_type is None:
fields[k] = (name, file_var.f)
else:
fields[k] = (name, file_var.f, file_var.content_type)
# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map = {str(i): [path] for i, path in enumerate(files)}

# Prepare requests http to send multipart-encoded data
data = MultipartEncoder(fields=fields)
# Enumerate the file streams
# Will generate something like {'0': FileVar object}
file_vars = {str(i): files[path] for i, path in enumerate(files)}

post_args["data"] = data
# Add the file map field
file_map_str = self.json_serialize(file_map)
log.debug("file_map %s", file_map_str)

if post_args["headers"] is None:
post_args["headers"] = {}
fields = {"operations": operations_str, "map": file_map_str}

# Add the extracted files as remaining fields
for k, file_var in file_vars.items():
assert isinstance(file_var, FileVar)
name = k if file_var.filename is None else file_var.filename

if file_var.content_type is None:
fields[k] = (name, file_var.f)
else:
post_args["headers"] = dict(post_args["headers"])
fields[k] = (name, file_var.f, file_var.content_type)

# Prepare requests http to send multipart-encoded data
data = MultipartEncoder(fields=fields)

post_args["headers"]["Content-Type"] = data.content_type
post_args["data"] = data

if post_args["headers"] is None:
post_args["headers"] = {}
else:
data_key = "json" if self.use_json else "data"
post_args[data_key] = payload
post_args["headers"] = dict(post_args["headers"])

# Log the payload
if log.isEnabledFor(logging.DEBUG):
log.debug(">>> %s", self.json_serialize(payload))
post_args["headers"]["Content-Type"] = data.content_type

# Pass kwargs to requests post method
post_args.update(self.kwargs)
return post_args

# Pass post_args to requests post method
if extra_args:
post_args.update(extra_args)
def execute(
self,
request: GraphQLRequest,
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided request against the configured remote server. This
uses the requests library to perform a HTTP POST request to the remote server.
:param request: GraphQL request as a
:class:`GraphQLRequest <gql.GraphQLRequest>` object.
:param timeout: Specifies a default timeout for requests (Default: None).
:param extra_args: additional arguments to send to the requests post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""

if not self.session:
raise TransportClosed("Transport is not connected")

post_args = self._prepare_request(
request,
timeout=timeout,
extra_args=extra_args,
upload_files=upload_files,
)

# Using the created session to perform requests
try:
response = self.session.request(
self.method, self.url, **post_args # type: ignore
)
response = self.session.request(self.method, self.url, **post_args)
finally:
if upload_files:
close_files(list(self.files.values()))

self.response_headers = response.headers

try:
if self.json_deserialize == json.loads:
result = response.json()
else:
result = self.json_deserialize(response.text)

if log.isEnabledFor(logging.DEBUG):
log.debug("<<< %s", response.text)

except Exception:
self._raise_response_error(response, "Not a JSON answer")

if "errors" not in result and "data" not in result:
self._raise_response_error(response, 'No "data" or "errors" keys in answer')

return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
return self._prepare_result(response)

@staticmethod
def _raise_transport_server_error_if_status_more_than_400(
@@ -327,27 +343,27 @@ def execute_batch(
if not self.session:
raise TransportClosed("Transport is not connected")

# Using the created session to perform requests
post_args = self._prepare_request(
reqs,
timeout=timeout,
extra_args=extra_args,
)

response = self.session.request(
self.method,
self.url,
**self._build_batch_post_args(reqs, timeout, extra_args),
**post_args,
)
self.response_headers = response.headers

answers = self._extract_response(response)
return self._prepare_batch_result(reqs, response)

try:
return get_batch_execution_result_list(reqs, answers)
except TransportProtocolError:
# Raise a TransportServerError if status > 400
self._raise_transport_server_error_if_status_more_than_400(response)
# In other cases, raise a TransportProtocolError
raise
def _get_json_result(self, response: requests.Response) -> Any:

# Saving latest response headers in the transport
self.response_headers = response.headers

def _extract_response(self, response: requests.Response) -> Any:
try:
result = response.json()
result = self.json_deserialize(response.text)

if log.isEnabledFor(logging.DEBUG):
log.debug("<<< %s", response.text)
@@ -357,35 +373,34 @@ def _extract_response(self, response: requests.Response) -> Any:

return result

def _build_batch_post_args(
self,
reqs: List[GraphQLRequest],
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
post_args: Dict[str, Any] = {
"headers": self.headers,
"auth": self.auth,
"cookies": self.cookies,
"timeout": timeout or self.default_timeout,
"verify": self.verify,
}
def _prepare_result(self, response: requests.Response) -> ExecutionResult:

data_key = "json" if self.use_json else "data"
post_args[data_key] = [req.payload for req in reqs]
result = self._get_json_result(response)

# Log the payload
if log.isEnabledFor(logging.DEBUG):
log.debug(">>> %s", self.json_serialize(post_args[data_key]))
if "errors" not in result and "data" not in result:
self._raise_response_error(response, 'No "data" or "errors" keys in answer')

# Pass kwargs to requests post method
post_args.update(self.kwargs)
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)

# Pass post_args to requests post method
if extra_args:
post_args.update(extra_args)
def _prepare_batch_result(
self,
reqs: List[GraphQLRequest],
response: requests.Response,
) -> List[ExecutionResult]:

return post_args
answers = self._get_json_result(response)

try:
return get_batch_execution_result_list(reqs, answers)
except TransportProtocolError:
# Raise a TransportServerError if status > 400
self._raise_transport_server_error_if_status_more_than_400(response)
# In other cases, raise a TransportProtocolError
raise

def close(self):
"""Closing the transport by closing the inner session"""