Skip to content

Commit 944e499

Browse files
enabling automatic token refresh
1 parent 1c30528 commit 944e499

File tree

1 file changed

+102
-15
lines changed

1 file changed

+102
-15
lines changed

application_sdk/clients/temporal.py

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import uuid
22
from concurrent.futures import ThreadPoolExecutor
3-
from typing import Any, Dict, Optional, Sequence, Type
3+
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Type
44

55
from temporalio import activity, workflow
66
from temporalio.client import Client, WorkflowExecutionStatus, WorkflowFailureError
7+
from temporalio.service import ServiceCallRequest, ServiceCallResponse
78
from temporalio.types import CallableType, ClassType
89
from temporalio.worker import (
910
ActivityInboundInterceptor,
@@ -51,6 +52,52 @@
5152
)
5253

5354

55+
class TokenRefreshInterceptor:
56+
"""RPC-level interceptor for automatic token refresh.
57+
58+
This interceptor ensures that all RPC calls to Temporal have a fresh
59+
authentication token, automatically refreshing tokens when needed.
60+
"""
61+
62+
def __init__(self, auth_manager: AuthManager):
63+
"""Initialize the token refresh interceptor.
64+
65+
Args:
66+
auth_manager: The authentication manager to use for token refresh.
67+
"""
68+
self.auth_manager = auth_manager
69+
70+
async def __call__(
71+
self,
72+
request: ServiceCallRequest,
73+
next_call: Callable[[ServiceCallRequest], Awaitable[ServiceCallResponse]],
74+
) -> ServiceCallResponse:
75+
"""Intercept all RPC calls and ensure fresh token.
76+
77+
This method is called for every RPC call to Temporal and ensures
78+
that the request includes a valid authentication token.
79+
80+
Args:
81+
request: The RPC request being made.
82+
next_call: The next interceptor or actual RPC call.
83+
84+
Returns:
85+
ServiceCallResponse: The response from the RPC call.
86+
"""
87+
if self.auth_manager.is_auth_enabled():
88+
try:
89+
# Get fresh token for each RPC call - AuthManager handles caching
90+
token = await self.auth_manager.get_access_token()
91+
if token:
92+
request.rpc_metadata = request.rpc_metadata or {}
93+
request.rpc_metadata["authorization"] = f"Bearer {token}"
94+
except Exception as e:
95+
logger.error(f"Failed to refresh token for RPC call: {e}")
96+
# Continue with existing token - let the call fail if needed
97+
98+
return await next_call(request)
99+
100+
54101
class EventActivityInboundInterceptor(ActivityInboundInterceptor):
55102
"""Interceptor for tracking activity execution events.
56103
@@ -72,7 +119,6 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
72119
event_name=ApplicationEventNames.ACTIVITY_START.value,
73120
data={},
74121
)
75-
76122
EventStore.publish_event(event)
77123

78124
output = None
@@ -121,6 +167,7 @@ async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
121167
data={},
122168
)
123169
)
170+
124171
output = None
125172
try:
126173
output = await super().execute_workflow(input)
@@ -190,11 +237,12 @@ def workflow_interceptor_class(
190237

191238

192239
class TemporalWorkflowClient(WorkflowClient):
193-
"""Temporal-specific implementation of WorkflowClient.
240+
"""Temporal-specific implementation of WorkflowClient with automatic token refresh.
194241
195242
This class provides an implementation of the WorkflowClient interface for
196243
the Temporal workflow engine. It handles connection management, workflow
197-
execution, and worker creation specific to Temporal.
244+
execution, and worker creation specific to Temporal. The client now includes
245+
automatic token refresh capabilities to handle long-running workflows.
198246
199247
Attributes:
200248
client: Temporal client instance.
@@ -301,11 +349,12 @@ def get_namespace(self) -> str:
301349
return self.namespace
302350

303351
async def load(self) -> None:
304-
"""Connect to the Temporal server.
352+
"""Connect to the Temporal server with automatic token refresh.
305353
306354
Establishes a connection to the Temporal server using the configured
307355
connection string and namespace. If authentication is enabled, includes
308-
the OAuth2 access token in the connection.
356+
the OAuth2 access token in the connection and sets up automatic token
357+
refresh for all RPC calls.
309358
310359
Raises:
311360
ConnectionError: If connection to the Temporal server fails.
@@ -322,13 +371,20 @@ async def load(self) -> None:
322371

323372
if self.auth_enabled:
324373
try:
374+
# Get initial token
325375
token = await self.auth_manager.get_access_token()
326376
if token:
327377
connection_options["rpc_metadata"] = {
328378
"authorization": f"Bearer {token}"
329379
}
380+
381+
# Add RPC-level interceptor for automatic token refresh
382+
connection_options["rpc_interceptors"] = [
383+
TokenRefreshInterceptor(self.auth_manager)
384+
]
385+
330386
except Exception as e:
331-
logger.error(f"Failed to get authentication headers: {e}")
387+
logger.error(f"Failed to get authentication token: {e}")
332388
raise
333389

334390
self.client = await Client.connect(**connection_options)
@@ -340,14 +396,39 @@ async def close(self) -> None:
340396
any authentication tokens. This is a no-op if the connection is
341397
already closed.
342398
"""
399+
if self.client:
400+
try:
401+
await self.client.close()
402+
except Exception as e:
403+
logger.warning(f"Error closing client connection: {e}")
404+
343405
if hasattr(self, "auth_manager"):
344406
self.auth_manager.clear_cache()
345-
return
407+
408+
async def ensure_valid_connection(self) -> None:
409+
"""Ensure the client has a valid connection with fresh token.
410+
411+
This method checks if the current authentication token is still valid
412+
and reconnects if the token has expired. This is used as a safeguard
413+
for client operations to prevent failures due to expired tokens.
414+
415+
Raises:
416+
ConnectionError: If reconnection fails.
417+
ValueError: If authentication is enabled but credentials are missing.
418+
"""
419+
if not self.auth_enabled:
420+
return
421+
422+
# Check if we need to reconnect due to token expiry
423+
if not await self.auth_manager.is_token_valid():
424+
logger.info("Token expired, reconnecting...")
425+
await self.close()
426+
await self.load()
346427

347428
async def start_workflow(
348429
self, workflow_args: Dict[str, Any], workflow_class: Type[WorkflowInterface]
349430
) -> Dict[str, Any]:
350-
"""Start a workflow execution.
431+
"""Start a workflow execution with token validation.
351432
352433
Args:
353434
workflow_args (Dict[str, Any]): Arguments for the workflow.
@@ -362,6 +443,9 @@ async def start_workflow(
362443
WorkflowFailureError: If the workflow fails to start.
363444
ValueError: If the client is not loaded.
364445
"""
446+
# Ensure we have a valid connection before starting
447+
await self.ensure_valid_connection()
448+
365449
# Check if credentials should be stored based on credentialSource
366450
should_store_credentials = False
367451
if "credentials" in workflow_args:
@@ -386,7 +470,6 @@ async def start_workflow(
386470
if not workflow_id:
387471
# if workflow_id is not provided, create a new one
388472
workflow_id = workflow_args.get("argo_workflow_name", str(uuid.uuid4()))
389-
390473
workflow_args.update(
391474
{
392475
"application_name": self.application_name,
@@ -400,14 +483,12 @@ async def start_workflow(
400483
# StateStore approach - store configuration and pass only workflow_id with flag
401484
StateStoreOutput.store_configuration(workflow_id, workflow_args)
402485
args = [{"workflow_id": workflow_id, "_use_statestore": True}]
403-
logger.info(
404-
f"Created workflow config with ID: {workflow_id} (StateStore approach)"
405-
)
486+
logger.info(f"Created workflow config with ID: {workflow_id} (StateStore)")
406487
else:
407488
# Direct approach - pass full configuration with flag
408489
workflow_args["_use_statestore"] = False
409490
args = [workflow_args]
410-
logger.info(f"Created workflow with ID: {workflow_id} (direct approach)")
491+
logger.info(f"Created workflow with ID: {workflow_id} (direct)")
411492

412493
try:
413494
# Get task_queue from workflow_args or credentials or use default
@@ -420,6 +501,7 @@ async def start_workflow(
420501
# Pass the conditional args to the workflow
421502
if not self.client:
422503
raise ValueError("Client is not loaded")
504+
423505
handle = await self.client.start_workflow(
424506
workflow_class, # type: ignore
425507
args=args,
@@ -428,8 +510,8 @@ async def start_workflow(
428510
cron_schedule=workflow_args.get("cron_schedule", ""),
429511
execution_timeout=WORKFLOW_MAX_TIMEOUT_HOURS,
430512
)
431-
logger.info(f"Workflow started: {handle.id} {handle.result_run_id}")
432513

514+
logger.info(f"Workflow started: {handle.id} {handle.result_run_id}")
433515
return {
434516
"workflow_id": handle.id,
435517
"run_id": handle.result_run_id,
@@ -449,8 +531,11 @@ async def stop_workflow(self, workflow_id: str, run_id: str) -> None:
449531
Raises:
450532
ValueError: If the client is not loaded.
451533
"""
534+
await self.ensure_valid_connection()
535+
452536
if not self.client:
453537
raise ValueError("Client is not loaded")
538+
454539
try:
455540
workflow_handle = self.client.get_workflow_handle(
456541
workflow_id, run_id=run_id
@@ -530,6 +615,8 @@ async def get_workflow_run_status(
530615
ValueError: If the client is not loaded.
531616
Exception: If there's an error getting the workflow status.
532617
"""
618+
await self.ensure_valid_connection()
619+
533620
if not self.client:
534621
raise ValueError("Client is not loaded")
535622

0 commit comments

Comments
 (0)