22
33from __future__ import annotations
44
5+ import dataclasses
56from collections .abc import Callable , Iterator , Mapping , Sequence
67from contextlib import contextmanager
78from dataclasses import dataclass
89from typing import (
910 Any ,
10- Dict ,
11+ Generic ,
1112 NoReturn ,
12- Optional ,
13- Type ,
1413 TypeAlias ,
14+ TypeVar ,
1515 cast ,
1616)
1717
18+ import nexusrpc .handler
1819import opentelemetry .baggage .propagation
1920import opentelemetry .context
2021import opentelemetry .context .context
5455
5556_CarrierDict : TypeAlias = dict [str , opentelemetry .propagators .textmap .CarrierValT ]
5657
58+ _ContextT = TypeVar ("_ContextT" , bound = nexusrpc .handler .OperationContext )
59+
5760
5861class TracingInterceptor (temporalio .client .Interceptor , temporalio .worker .Interceptor ):
5962 """Interceptor that supports client and worker OpenTelemetry span creation
@@ -133,6 +136,14 @@ def workflow_interceptor_class(
133136 )
134137 return TracingWorkflowInboundInterceptor
135138
139+ def intercept_nexus_operation (
140+ self , next : temporalio .worker .NexusOperationInboundInterceptor
141+ ) -> temporalio .worker .NexusOperationInboundInterceptor :
142+ """Implementation of
143+ :py:meth:`temporalio.worker.Interceptor.intercept_nexus_operation`.
144+ """
145+ return _TracingNexusOperationInboundInterceptor (next , self )
146+
136147 def _context_to_headers (
137148 self , headers : Mapping [str , temporalio .api .common .v1 .Payload ]
138149 ) -> Mapping [str , temporalio .api .common .v1 .Payload ]:
@@ -166,7 +177,8 @@ def _start_as_current_span(
166177 name : str ,
167178 * ,
168179 attributes : opentelemetry .util .types .Attributes ,
169- input : _InputWithHeaders | None = None ,
180+ input_with_headers : _InputWithHeaders | None = None ,
181+ input_with_ctx : _InputWithOperationContext | None = None ,
170182 kind : opentelemetry .trace .SpanKind ,
171183 context : Context | None = None ,
172184 ) -> Iterator [None ]:
@@ -179,8 +191,19 @@ def _start_as_current_span(
179191 context = context ,
180192 set_status_on_exception = False ,
181193 ) as span :
182- if input :
183- input .headers = self ._context_to_headers (input .headers )
194+ if input_with_headers :
195+ input_with_headers .headers = self ._context_to_headers (
196+ input_with_headers .headers
197+ )
198+ if input_with_ctx :
199+ carrier : _CarrierDict = {}
200+ self .text_map_propagator .inject (carrier )
201+ input_with_ctx .ctx = dataclasses .replace (
202+ input_with_ctx .ctx ,
203+ headers = _carrier_to_nexus_headers (
204+ carrier , input_with_ctx .ctx .headers
205+ ),
206+ )
184207 try :
185208 yield None
186209 except Exception as exc :
@@ -258,7 +281,7 @@ async def start_workflow(
258281 with self .root ._start_as_current_span (
259282 f"{ prefix } :{ input .workflow } " ,
260283 attributes = {"temporalWorkflowID" : input .id },
261- input = input ,
284+ input_with_headers = input ,
262285 kind = opentelemetry .trace .SpanKind .CLIENT ,
263286 ):
264287 return await super ().start_workflow (input )
@@ -267,7 +290,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A
267290 with self .root ._start_as_current_span (
268291 f"QueryWorkflow:{ input .query } " ,
269292 attributes = {"temporalWorkflowID" : input .id },
270- input = input ,
293+ input_with_headers = input ,
271294 kind = opentelemetry .trace .SpanKind .CLIENT ,
272295 ):
273296 return await super ().query_workflow (input )
@@ -278,7 +301,7 @@ async def signal_workflow(
278301 with self .root ._start_as_current_span (
279302 f"SignalWorkflow:{ input .signal } " ,
280303 attributes = {"temporalWorkflowID" : input .id },
281- input = input ,
304+ input_with_headers = input ,
282305 kind = opentelemetry .trace .SpanKind .CLIENT ,
283306 ):
284307 return await super ().signal_workflow (input )
@@ -289,7 +312,7 @@ async def start_workflow_update(
289312 with self .root ._start_as_current_span (
290313 f"StartWorkflowUpdate:{ input .update } " ,
291314 attributes = {"temporalWorkflowID" : input .id },
292- input = input ,
315+ input_with_headers = input ,
293316 kind = opentelemetry .trace .SpanKind .CLIENT ,
294317 ):
295318 return await super ().start_workflow_update (input )
@@ -306,7 +329,7 @@ async def start_update_with_start_workflow(
306329 with self .root ._start_as_current_span (
307330 f"StartUpdateWithStartWorkflow:{ input .start_workflow_input .workflow } " ,
308331 attributes = attrs ,
309- input = input .start_workflow_input ,
332+ input_with_headers = input .start_workflow_input ,
310333 kind = opentelemetry .trace .SpanKind .CLIENT ,
311334 ):
312335 otel_header = input .start_workflow_input .headers .get (self .root .header_key )
@@ -345,10 +368,60 @@ async def execute_activity(
345368 return await super ().execute_activity (input )
346369
347370
371+ class _TracingNexusOperationInboundInterceptor (
372+ temporalio .worker .NexusOperationInboundInterceptor
373+ ):
374+ def __init__ (
375+ self ,
376+ next : temporalio .worker .NexusOperationInboundInterceptor ,
377+ root : TracingInterceptor ,
378+ ) -> None :
379+ super ().__init__ (next )
380+ self ._root = root
381+
382+ def _context_from_nexus_headers (self , headers : Mapping [str , str ]):
383+ return self ._root .text_map_propagator .extract (headers )
384+
385+ async def execute_nexus_operation_start (
386+ self , input : temporalio .worker .ExecuteNexusOperationStartInput
387+ ) -> (
388+ nexusrpc .handler .StartOperationResultSync [Any ]
389+ | nexusrpc .handler .StartOperationResultAsync
390+ ):
391+ with self ._root ._start_as_current_span (
392+ f"RunStartNexusOperationHandler:{ input .ctx .service } /{ input .ctx .operation } " ,
393+ context = self ._context_from_nexus_headers (input .ctx .headers ),
394+ attributes = {},
395+ input_with_ctx = input ,
396+ kind = opentelemetry .trace .SpanKind .SERVER ,
397+ ):
398+ return await self .next .execute_nexus_operation_start (input )
399+
400+ async def execute_nexus_operation_cancel (
401+ self , input : temporalio .worker .ExecuteNexusOperationCancelInput
402+ ) -> None :
403+ with self ._root ._start_as_current_span (
404+ f"RunCancelNexusOperationHandler:{ input .ctx .service } /{ input .ctx .operation } " ,
405+ context = self ._context_from_nexus_headers (input .ctx .headers ),
406+ attributes = {},
407+ input_with_ctx = input ,
408+ kind = opentelemetry .trace .SpanKind .SERVER ,
409+ ):
410+ return await self .next .execute_nexus_operation_cancel (input )
411+
412+
348413class _InputWithHeaders (Protocol ):
349414 headers : Mapping [str , temporalio .api .common .v1 .Payload ]
350415
351416
417+ class _InputWithStringHeaders (Protocol ):
418+ headers : Mapping [str , str ] | None
419+
420+
421+ class _InputWithOperationContext (Generic [_ContextT ], Protocol ):
422+ ctx : _ContextT
423+
424+
352425class _WorkflowExternFunctions (TypedDict ):
353426 __temporal_opentelemetry_completed_span : Callable [
354427 [_CompletedWorkflowSpanParams ], _CarrierDict | None
@@ -602,6 +675,7 @@ def _completed_span(
602675 * ,
603676 link_context_carrier : _CarrierDict | None = None ,
604677 add_to_outbound : _InputWithHeaders | None = None ,
678+ add_to_outbound_str : _InputWithStringHeaders | None = None ,
605679 new_span_even_on_replay : bool = False ,
606680 additional_attributes : opentelemetry .util .types .Attributes = None ,
607681 exception : Exception | None = None ,
@@ -614,12 +688,14 @@ def _completed_span(
614688 # Create the span. First serialize current context to carrier.
615689 new_context_carrier : _CarrierDict = {}
616690 self .text_map_propagator .inject (new_context_carrier )
691+
617692 # Invoke
618693 info = temporalio .workflow .info ()
619694 attributes : dict [str , opentelemetry .util .types .AttributeValue ] = {
620695 "temporalWorkflowID" : info .workflow_id ,
621696 "temporalRunID" : info .run_id ,
622697 }
698+
623699 if additional_attributes :
624700 attributes .update (additional_attributes )
625701 updated_context_carrier = self ._extern_functions [
@@ -640,10 +716,16 @@ def _completed_span(
640716 )
641717
642718 # Add to outbound if needed
643- if add_to_outbound and updated_context_carrier :
644- add_to_outbound .headers = self ._context_carrier_to_headers (
645- updated_context_carrier , add_to_outbound .headers
646- )
719+ if updated_context_carrier :
720+ if add_to_outbound :
721+ add_to_outbound .headers = self ._context_carrier_to_headers (
722+ updated_context_carrier , add_to_outbound .headers
723+ )
724+
725+ if add_to_outbound_str :
726+ add_to_outbound_str .headers = _carrier_to_nexus_headers (
727+ updated_context_carrier , add_to_outbound_str .headers
728+ )
647729
648730 def _set_on_context (
649731 self , context : opentelemetry .context .Context
@@ -722,6 +804,29 @@ def start_local_activity(
722804 )
723805 return super ().start_local_activity (input )
724806
807+ async def start_nexus_operation (
808+ self , input : temporalio .worker .StartNexusOperationInput [Any , Any ]
809+ ) -> temporalio .workflow .NexusOperationHandle [Any ]:
810+ self .root ._completed_span (
811+ f"StartNexusOperation:{ input .service } /{ input .operation_name } " ,
812+ kind = opentelemetry .trace .SpanKind .CLIENT ,
813+ add_to_outbound_str = input ,
814+ )
815+
816+ return await super ().start_nexus_operation (input )
817+
818+
819+ def _carrier_to_nexus_headers (
820+ carrier : _CarrierDict , initial : Mapping [str , str ] | None = None
821+ ) -> Mapping [str , str ]:
822+ out = {** initial } if initial else {}
823+ for k , v in carrier .items ():
824+ if isinstance (v , list ):
825+ out [k ] = "," .join (v )
826+ else :
827+ out [k ] = v
828+ return out
829+
725830
726831class workflow :
727832 """Contains static methods that are safe to call from within a workflow.
0 commit comments