11import uuid
22from concurrent .futures import ThreadPoolExecutor
3- from typing import Any , Dict , Optional , Sequence , Type
3+ from typing import Any , Awaitable , Callable , Dict , Optional , Sequence , Type
44
55from temporalio import activity , workflow
66from temporalio .client import Client , WorkflowExecutionStatus , WorkflowFailureError
7+ from temporalio .service import ServiceCallRequest , ServiceCallResponse
78from temporalio .types import CallableType , ClassType
89from temporalio .worker import (
910 ActivityInboundInterceptor ,
5152)
5253
5354
55+ class TokenRefreshInterceptor :
56+ """RPC-level interceptor for automatic token refresh.
57+
58+ This interceptor ensures that all RPC calls to Temporal have a fresh
59+ authentication token, automatically refreshing tokens when needed.
60+ """
61+
62+ def __init__ (self , auth_manager : AuthManager ):
63+ """Initialize the token refresh interceptor.
64+
65+ Args:
66+ auth_manager: The authentication manager to use for token refresh.
67+ """
68+ self .auth_manager = auth_manager
69+
70+ async def __call__ (
71+ self ,
72+ request : ServiceCallRequest ,
73+ next_call : Callable [[ServiceCallRequest ], Awaitable [ServiceCallResponse ]],
74+ ) -> ServiceCallResponse :
75+ """Intercept all RPC calls and ensure fresh token.
76+
77+ This method is called for every RPC call to Temporal and ensures
78+ that the request includes a valid authentication token.
79+
80+ Args:
81+ request: The RPC request being made.
82+ next_call: The next interceptor or actual RPC call.
83+
84+ Returns:
85+ ServiceCallResponse: The response from the RPC call.
86+ """
87+ if self .auth_manager .is_auth_enabled ():
88+ try :
89+ # Get fresh token for each RPC call - AuthManager handles caching
90+ token = await self .auth_manager .get_access_token ()
91+ if token :
92+ request .rpc_metadata = request .rpc_metadata or {}
93+ request .rpc_metadata ["authorization" ] = f"Bearer { token } "
94+ except Exception as e :
95+ logger .error (f"Failed to refresh token for RPC call: { e } " )
96+ # Continue with existing token - let the call fail if needed
97+
98+ return await next_call (request )
99+
100+
54101class EventActivityInboundInterceptor (ActivityInboundInterceptor ):
55102 """Interceptor for tracking activity execution events.
56103
@@ -72,7 +119,6 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
72119 event_name = ApplicationEventNames .ACTIVITY_START .value ,
73120 data = {},
74121 )
75-
76122 EventStore .publish_event (event )
77123
78124 output = None
@@ -121,6 +167,7 @@ async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
121167 data = {},
122168 )
123169 )
170+
124171 output = None
125172 try :
126173 output = await super ().execute_workflow (input )
@@ -190,11 +237,12 @@ def workflow_interceptor_class(
190237
191238
192239class TemporalWorkflowClient (WorkflowClient ):
193- """Temporal-specific implementation of WorkflowClient.
240+ """Temporal-specific implementation of WorkflowClient with automatic token refresh .
194241
195242 This class provides an implementation of the WorkflowClient interface for
196243 the Temporal workflow engine. It handles connection management, workflow
197- execution, and worker creation specific to Temporal.
244+ execution, and worker creation specific to Temporal. The client now includes
245+ automatic token refresh capabilities to handle long-running workflows.
198246
199247 Attributes:
200248 client: Temporal client instance.
@@ -301,11 +349,12 @@ def get_namespace(self) -> str:
301349 return self .namespace
302350
303351 async def load (self ) -> None :
304- """Connect to the Temporal server.
352+ """Connect to the Temporal server with automatic token refresh .
305353
306354 Establishes a connection to the Temporal server using the configured
307355 connection string and namespace. If authentication is enabled, includes
308- the OAuth2 access token in the connection.
356+ the OAuth2 access token in the connection and sets up automatic token
357+ refresh for all RPC calls.
309358
310359 Raises:
311360 ConnectionError: If connection to the Temporal server fails.
@@ -322,13 +371,20 @@ async def load(self) -> None:
322371
323372 if self .auth_enabled :
324373 try :
374+ # Get initial token
325375 token = await self .auth_manager .get_access_token ()
326376 if token :
327377 connection_options ["rpc_metadata" ] = {
328378 "authorization" : f"Bearer { token } "
329379 }
380+
381+ # Add RPC-level interceptor for automatic token refresh
382+ connection_options ["rpc_interceptors" ] = [
383+ TokenRefreshInterceptor (self .auth_manager )
384+ ]
385+
330386 except Exception as e :
331- logger .error (f"Failed to get authentication headers : { e } " )
387+ logger .error (f"Failed to get authentication token : { e } " )
332388 raise
333389
334390 self .client = await Client .connect (** connection_options )
@@ -340,14 +396,39 @@ async def close(self) -> None:
340396 any authentication tokens. This is a no-op if the connection is
341397 already closed.
342398 """
399+ if self .client :
400+ try :
401+ await self .client .close ()
402+ except Exception as e :
403+ logger .warning (f"Error closing client connection: { e } " )
404+
343405 if hasattr (self , "auth_manager" ):
344406 self .auth_manager .clear_cache ()
345- return
407+
408+ async def ensure_valid_connection (self ) -> None :
409+ """Ensure the client has a valid connection with fresh token.
410+
411+ This method checks if the current authentication token is still valid
412+ and reconnects if the token has expired. This is used as a safeguard
413+ for client operations to prevent failures due to expired tokens.
414+
415+ Raises:
416+ ConnectionError: If reconnection fails.
417+ ValueError: If authentication is enabled but credentials are missing.
418+ """
419+ if not self .auth_enabled :
420+ return
421+
422+ # Check if we need to reconnect due to token expiry
423+ if not await self .auth_manager .is_token_valid ():
424+ logger .info ("Token expired, reconnecting..." )
425+ await self .close ()
426+ await self .load ()
346427
347428 async def start_workflow (
348429 self , workflow_args : Dict [str , Any ], workflow_class : Type [WorkflowInterface ]
349430 ) -> Dict [str , Any ]:
350- """Start a workflow execution.
431+ """Start a workflow execution with token validation .
351432
352433 Args:
353434 workflow_args (Dict[str, Any]): Arguments for the workflow.
@@ -362,6 +443,9 @@ async def start_workflow(
362443 WorkflowFailureError: If the workflow fails to start.
363444 ValueError: If the client is not loaded.
364445 """
446+ # Ensure we have a valid connection before starting
447+ await self .ensure_valid_connection ()
448+
365449 # Check if credentials should be stored based on credentialSource
366450 should_store_credentials = False
367451 if "credentials" in workflow_args :
@@ -386,7 +470,6 @@ async def start_workflow(
386470 if not workflow_id :
387471 # if workflow_id is not provided, create a new one
388472 workflow_id = workflow_args .get ("argo_workflow_name" , str (uuid .uuid4 ()))
389-
390473 workflow_args .update (
391474 {
392475 "application_name" : self .application_name ,
@@ -400,14 +483,12 @@ async def start_workflow(
400483 # StateStore approach - store configuration and pass only workflow_id with flag
401484 StateStoreOutput .store_configuration (workflow_id , workflow_args )
402485 args = [{"workflow_id" : workflow_id , "_use_statestore" : True }]
403- logger .info (
404- f"Created workflow config with ID: { workflow_id } (StateStore approach)"
405- )
486+ logger .info (f"Created workflow config with ID: { workflow_id } (StateStore)" )
406487 else :
407488 # Direct approach - pass full configuration with flag
408489 workflow_args ["_use_statestore" ] = False
409490 args = [workflow_args ]
410- logger .info (f"Created workflow with ID: { workflow_id } (direct approach )" )
491+ logger .info (f"Created workflow with ID: { workflow_id } (direct)" )
411492
412493 try :
413494 # Get task_queue from workflow_args or credentials or use default
@@ -420,6 +501,7 @@ async def start_workflow(
420501 # Pass the conditional args to the workflow
421502 if not self .client :
422503 raise ValueError ("Client is not loaded" )
504+
423505 handle = await self .client .start_workflow (
424506 workflow_class , # type: ignore
425507 args = args ,
@@ -428,8 +510,8 @@ async def start_workflow(
428510 cron_schedule = workflow_args .get ("cron_schedule" , "" ),
429511 execution_timeout = WORKFLOW_MAX_TIMEOUT_HOURS ,
430512 )
431- logger .info (f"Workflow started: { handle .id } { handle .result_run_id } " )
432513
514+ logger .info (f"Workflow started: { handle .id } { handle .result_run_id } " )
433515 return {
434516 "workflow_id" : handle .id ,
435517 "run_id" : handle .result_run_id ,
@@ -449,8 +531,11 @@ async def stop_workflow(self, workflow_id: str, run_id: str) -> None:
449531 Raises:
450532 ValueError: If the client is not loaded.
451533 """
534+ await self .ensure_valid_connection ()
535+
452536 if not self .client :
453537 raise ValueError ("Client is not loaded" )
538+
454539 try :
455540 workflow_handle = self .client .get_workflow_handle (
456541 workflow_id , run_id = run_id
@@ -530,6 +615,8 @@ async def get_workflow_run_status(
530615 ValueError: If the client is not loaded.
531616 Exception: If there's an error getting the workflow status.
532617 """
618+ await self .ensure_valid_connection ()
619+
533620 if not self .client :
534621 raise ValueError ("Client is not loaded" )
535622
0 commit comments