From e49f8515a890be7c578b7d45720bf1a3f16c4554 Mon Sep 17 00:00:00 2001
From: Syrus Akbary <me@syrusakbary.com>
Date: Thu, 6 Sep 2018 18:24:16 +0200
Subject: [PATCH] Added excution_context_class for custom ExecutionContext

---
 graphql/execution/execute.py | 87 ++++++++++++++++++------------------
 graphql/graphql.py           | 71 +++++++++++++++++------------
 2 files changed, 86 insertions(+), 72 deletions(-)

diff --git a/graphql/execution/execute.py b/graphql/execution/execute.py
index e8808de1..72875212 100644
--- a/graphql/execution/execute.py
+++ b/graphql/execution/execute.py
@@ -1,7 +1,7 @@
 from inspect import isawaitable
 from typing import (
     Any, Awaitable, Dict, Iterable, List, NamedTuple, Optional, Set, Union,
-    Tuple, cast)
+    Tuple, Type, cast)
 
 from ..error import GraphQLError, INVALID, located_error
 from ..language import (
@@ -60,48 +60,6 @@ class ExecutionResult(NamedTuple):
 ExecutionResult.__new__.__defaults__ = (None, None)  # type: ignore
 
 
-def execute(
-        schema: GraphQLSchema, document: DocumentNode,
-        root_value: Any=None, context_value: Any=None,
-        variable_values: Dict[str, Any]=None,
-        operation_name: str=None, field_resolver: GraphQLFieldResolver=None
-        ) -> MaybeAwaitable[ExecutionResult]:
-    """Execute a GraphQL operation.
-
-    Implements the "Evaluating requests" section of the GraphQL specification.
-
-    Returns an ExecutionResult (if all encountered resolvers are synchronous),
-    or a coroutine object eventually yielding an ExecutionResult.
-
-    If the arguments to this function do not result in a legal execution
-    context, a GraphQLError will be thrown immediately explaining the invalid
-    input.
-    """
-    # If arguments are missing or incorrect, throw an error.
-    assert_valid_execution_arguments(schema, document, variable_values)
-
-    # If a valid execution context cannot be created due to incorrect
-    #  arguments, a "Response" with only errors is returned.
-    exe_context = ExecutionContext.build(
-        schema, document, root_value, context_value,
-        variable_values, operation_name, field_resolver)
-
-    # Return early errors if execution context failed.
-    if isinstance(exe_context, list):
-        return ExecutionResult(data=None, errors=exe_context)
-
-    # Return a possible coroutine object that will eventually yield the data
-    # described by the "Response" section of the GraphQL specification.
-    #
-    # If errors are encountered while executing a GraphQL field, only that
-    # field and its descendants will be omitted, and sibling fields will still
-    # be executed. An execution which encounters errors will still result in a
-    # coroutine object that can be executed without errors.
-
-    data = exe_context.execute_operation(exe_context.operation, root_value)
-    return exe_context.build_response(data)
-
-
 class ExecutionContext:
     """Data that must be available at all points during query execution.
 
@@ -794,6 +752,49 @@ def collect_subfields(
         return sub_field_nodes
 
 
+def execute(
+        schema: GraphQLSchema, document: DocumentNode,
+        root_value: Any=None, context_value: Any=None,
+        variable_values: Dict[str, Any]=None,
+        operation_name: str=None, field_resolver: GraphQLFieldResolver=None,
+        execution_context_class: Type[ExecutionContext]=ExecutionContext,
+        ) -> MaybeAwaitable[ExecutionResult]:
+    """Execute a GraphQL operation.
+
+    Implements the "Evaluating requests" section of the GraphQL specification.
+
+    Returns an ExecutionResult (if all encountered resolvers are synchronous),
+    or a coroutine object eventually yielding an ExecutionResult.
+
+    If the arguments to this function do not result in a legal execution
+    context, a GraphQLError will be thrown immediately explaining the invalid
+    input.
+    """
+    # If arguments are missing or incorrect, throw an error.
+    assert_valid_execution_arguments(schema, document, variable_values)
+
+    # If a valid execution context cannot be created due to incorrect
+    #  arguments, a "Response" with only errors is returned.
+    exe_context = execution_context_class.build(
+        schema, document, root_value, context_value,
+        variable_values, operation_name, field_resolver)
+
+    # Return early errors if execution context failed.
+    if isinstance(exe_context, list):
+        return ExecutionResult(data=None, errors=exe_context)
+
+    # Return a possible coroutine object that will eventually yield the data
+    # described by the "Response" section of the GraphQL specification.
+    #
+    # If errors are encountered while executing a GraphQL field, only that
+    # field and its descendants will be omitted, and sibling fields will still
+    # be executed. An execution which encounters errors will still result in a
+    # coroutine object that can be executed without errors.
+
+    data = exe_context.execute_operation(exe_context.operation, root_value)
+    return exe_context.build_response(data)
+
+
 def assert_valid_execution_arguments(
         schema: GraphQLSchema, document: DocumentNode,
         raw_variable_values: Dict[str, Any]=None) -> None:
diff --git a/graphql/graphql.py b/graphql/graphql.py
index a5de20f6..3fda1d18 100644
--- a/graphql/graphql.py
+++ b/graphql/graphql.py
@@ -1,25 +1,27 @@
 from asyncio import ensure_future
 from inspect import isawaitable
-from typing import Any, Awaitable, Callable, Dict, Union, cast
+from typing import Any, Awaitable, Callable, Dict, Union, Type, cast
 
 from .error import GraphQLError
 from .execution import execute
 from .language import parse, Source
 from .pyutils import MaybeAwaitable
 from .type import GraphQLSchema, validate_schema
-from .execution.execute import ExecutionResult
+from .execution.execute import ExecutionResult, ExecutionContext
 
-__all__ = ['graphql', 'graphql_sync']
+__all__ = ["graphql", "graphql_sync"]
 
 
 async def graphql(
-        schema: GraphQLSchema,
-        source: Union[str, Source],
-        root_value: Any=None,
-        context_value: Any=None,
-        variable_values: Dict[str, Any]=None,
-        operation_name: str=None,
-        field_resolver: Callable=None) -> ExecutionResult:
+    schema: GraphQLSchema,
+    source: Union[str, Source],
+    root_value: Any = None,
+    context_value: Any = None,
+    variable_values: Dict[str, Any] = None,
+    operation_name: str = None,
+    field_resolver: Callable = None,
+    execution_context_class: Type[ExecutionContext] = ExecutionContext,
+) -> ExecutionResult:
     """Execute a GraphQL operation asynchronously.
 
     This is the primary entry point function for fulfilling GraphQL operations
@@ -65,7 +67,9 @@ async def graphql(
         context_value,
         variable_values,
         operation_name,
-        field_resolver)
+        field_resolver,
+        execution_context_class,
+    )
 
     if isawaitable(result):
         return await cast(Awaitable[ExecutionResult], result)
@@ -74,13 +78,15 @@ async def graphql(
 
 
 def graphql_sync(
-        schema: GraphQLSchema,
-        source: Union[str, Source],
-        root_value: Any = None,
-        context_value: Any = None,
-        variable_values: Dict[str, Any] = None,
-        operation_name: str = None,
-        field_resolver: Callable = None) -> ExecutionResult:
+    schema: GraphQLSchema,
+    source: Union[str, Source],
+    root_value: Any = None,
+    context_value: Any = None,
+    variable_values: Dict[str, Any] = None,
+    operation_name: str = None,
+    field_resolver: Callable = None,
+    execution_context_class: Type[ExecutionContext] = ExecutionContext,
+) -> ExecutionResult:
     """Execute a GraphQL operation synchronously.
 
     The graphql_sync function also fulfills GraphQL operations by parsing,
@@ -95,26 +101,30 @@ def graphql_sync(
         context_value,
         variable_values,
         operation_name,
-        field_resolver)
+        field_resolver,
+        execution_context_class,
+    )
 
     # Assert that the execution was synchronous.
     if isawaitable(result):
         ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
         raise RuntimeError(
-            'GraphQL execution failed to complete synchronously.')
+            "GraphQL execution failed to complete synchronously."
+        )
 
     return cast(ExecutionResult, result)
 
 
 def graphql_impl(
-        schema,
-        source,
-        root_value,
-        context_value,
-        variable_values,
-        operation_name,
-        field_resolver
-        ) -> MaybeAwaitable[ExecutionResult]:
+    schema,
+    source,
+    root_value,
+    context_value,
+    variable_values,
+    operation_name,
+    field_resolver,
+    execution_context_class,
+) -> MaybeAwaitable[ExecutionResult]:
     """Execute a query, return asynchronously only if necessary."""
     # Validate Schema
     schema_validation_errors = validate_schema(schema)
@@ -132,6 +142,7 @@ def graphql_impl(
 
     # Validate
     from .validation import validate
+
     validation_errors = validate(schema, document)
     if validation_errors:
         return ExecutionResult(data=None, errors=validation_errors)
@@ -144,4 +155,6 @@ def graphql_impl(
         context_value,
         variable_values,
         operation_name,
-        field_resolver)
+        field_resolver,
+        execution_context_class,
+    )