@@ -216,13 +216,20 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
216
216
await self .websocket .close (code = 4401 , reason = "Unauthorized" )
217
217
return
218
218
219
- try :
220
- graphql_document = parse (message ["payload" ]["query" ])
221
- except GraphQLSyntaxError as exc :
222
- await self .websocket .close (code = 4400 , reason = exc .message )
223
- return
219
+ request_data = await self .view .get_graphql_request_data (
220
+ self .websocket , self .context , message ["payload" ], "subscription"
221
+ )
224
222
225
- operation_name = message ["payload" ].get ("operationName" )
223
+ if request_data .document is not None :
224
+ graphql_document = request_data .document
225
+ else :
226
+ try :
227
+ graphql_document = parse (request_data .query )
228
+ except GraphQLSyntaxError as exc :
229
+ await self .websocket .close (code = 4400 , reason = exc .message )
230
+ return
231
+
232
+ operation_name = request_data .operation_name
226
233
227
234
try :
228
235
operation_type = get_operation_type (graphql_document , operation_name )
@@ -248,15 +255,11 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
248
255
249
256
if self .debug : # pragma: no cover
250
257
pretty_print_graphql_operation (
251
- message [ "payload" ]. get ( "operationName" ) ,
252
- message [ "payload" ][ " query" ] ,
253
- message [ "payload" ]. get ( " variables" ) ,
258
+ request_data . operation_name ,
259
+ request_data . query ,
260
+ request_data . variables ,
254
261
)
255
262
256
- request_data = await self .view .get_graphql_request_data (
257
- self .websocket , self .context , message ["payload" ], "subscription"
258
- )
259
-
260
263
operation = Operation (
261
264
self ,
262
265
message ["id" ],
0 commit comments