Skip to content

Commit b0adaae

Browse files
Nexus interceptors (#1218)
* use middleware branch of nexus-rpc * First draft of nexus interceptors and otel support * Add docstrings * Update tests to use @workflow_run_operation to avoid hacky cancel awaiting * PR feedback * remove OTel attribute that other SDKs are not sending * Use uv to update to the head of target nexus-rpc branch. * use nexus-rpc 1.3.0 * fix reference to nexus-rpc 1.3.0 in lock * fix indentation in docstring
1 parent fd17cdf commit b0adaae

File tree

8 files changed

+450
-38
lines changed

8 files changed

+450
-38
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ license = "MIT"
99
license-files = ["LICENSE"]
1010
keywords = ["temporal", "workflow"]
1111
dependencies = [
12-
"nexus-rpc==1.2.0",
12+
"nexus-rpc==1.3.0",
1313
"protobuf>=3.20,<7.0.0",
1414
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
1515
"types-protobuf>=3.20",

temporalio/contrib/opentelemetry.py

Lines changed: 120 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,20 @@
22

33
from __future__ import annotations
44

5+
import dataclasses
56
from collections.abc import Callable, Iterator, Mapping, Sequence
67
from contextlib import contextmanager
78
from dataclasses import dataclass
89
from typing import (
910
Any,
10-
Dict,
11+
Generic,
1112
NoReturn,
12-
Optional,
13-
Type,
1413
TypeAlias,
14+
TypeVar,
1515
cast,
1616
)
1717

18+
import nexusrpc.handler
1819
import opentelemetry.baggage.propagation
1920
import opentelemetry.context
2021
import opentelemetry.context.context
@@ -54,6 +55,8 @@
5455

5556
_CarrierDict: TypeAlias = dict[str, opentelemetry.propagators.textmap.CarrierValT]
5657

58+
_ContextT = TypeVar("_ContextT", bound=nexusrpc.handler.OperationContext)
59+
5760

5861
class TracingInterceptor(temporalio.client.Interceptor, temporalio.worker.Interceptor):
5962
"""Interceptor that supports client and worker OpenTelemetry span creation
@@ -133,6 +136,14 @@ def workflow_interceptor_class(
133136
)
134137
return TracingWorkflowInboundInterceptor
135138

139+
def intercept_nexus_operation(
140+
self, next: temporalio.worker.NexusOperationInboundInterceptor
141+
) -> temporalio.worker.NexusOperationInboundInterceptor:
142+
"""Implementation of
143+
:py:meth:`temporalio.worker.Interceptor.intercept_nexus_operation`.
144+
"""
145+
return _TracingNexusOperationInboundInterceptor(next, self)
146+
136147
def _context_to_headers(
137148
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
138149
) -> Mapping[str, temporalio.api.common.v1.Payload]:
@@ -166,7 +177,8 @@ def _start_as_current_span(
166177
name: str,
167178
*,
168179
attributes: opentelemetry.util.types.Attributes,
169-
input: _InputWithHeaders | None = None,
180+
input_with_headers: _InputWithHeaders | None = None,
181+
input_with_ctx: _InputWithOperationContext | None = None,
170182
kind: opentelemetry.trace.SpanKind,
171183
context: Context | None = None,
172184
) -> Iterator[None]:
@@ -179,8 +191,19 @@ def _start_as_current_span(
179191
context=context,
180192
set_status_on_exception=False,
181193
) as span:
182-
if input:
183-
input.headers = self._context_to_headers(input.headers)
194+
if input_with_headers:
195+
input_with_headers.headers = self._context_to_headers(
196+
input_with_headers.headers
197+
)
198+
if input_with_ctx:
199+
carrier: _CarrierDict = {}
200+
self.text_map_propagator.inject(carrier)
201+
input_with_ctx.ctx = dataclasses.replace(
202+
input_with_ctx.ctx,
203+
headers=_carrier_to_nexus_headers(
204+
carrier, input_with_ctx.ctx.headers
205+
),
206+
)
184207
try:
185208
yield None
186209
except Exception as exc:
@@ -258,7 +281,7 @@ async def start_workflow(
258281
with self.root._start_as_current_span(
259282
f"{prefix}:{input.workflow}",
260283
attributes={"temporalWorkflowID": input.id},
261-
input=input,
284+
input_with_headers=input,
262285
kind=opentelemetry.trace.SpanKind.CLIENT,
263286
):
264287
return await super().start_workflow(input)
@@ -267,7 +290,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A
267290
with self.root._start_as_current_span(
268291
f"QueryWorkflow:{input.query}",
269292
attributes={"temporalWorkflowID": input.id},
270-
input=input,
293+
input_with_headers=input,
271294
kind=opentelemetry.trace.SpanKind.CLIENT,
272295
):
273296
return await super().query_workflow(input)
@@ -278,7 +301,7 @@ async def signal_workflow(
278301
with self.root._start_as_current_span(
279302
f"SignalWorkflow:{input.signal}",
280303
attributes={"temporalWorkflowID": input.id},
281-
input=input,
304+
input_with_headers=input,
282305
kind=opentelemetry.trace.SpanKind.CLIENT,
283306
):
284307
return await super().signal_workflow(input)
@@ -289,7 +312,7 @@ async def start_workflow_update(
289312
with self.root._start_as_current_span(
290313
f"StartWorkflowUpdate:{input.update}",
291314
attributes={"temporalWorkflowID": input.id},
292-
input=input,
315+
input_with_headers=input,
293316
kind=opentelemetry.trace.SpanKind.CLIENT,
294317
):
295318
return await super().start_workflow_update(input)
@@ -306,7 +329,7 @@ async def start_update_with_start_workflow(
306329
with self.root._start_as_current_span(
307330
f"StartUpdateWithStartWorkflow:{input.start_workflow_input.workflow}",
308331
attributes=attrs,
309-
input=input.start_workflow_input,
332+
input_with_headers=input.start_workflow_input,
310333
kind=opentelemetry.trace.SpanKind.CLIENT,
311334
):
312335
otel_header = input.start_workflow_input.headers.get(self.root.header_key)
@@ -345,10 +368,60 @@ async def execute_activity(
345368
return await super().execute_activity(input)
346369

347370

371+
class _TracingNexusOperationInboundInterceptor(
372+
temporalio.worker.NexusOperationInboundInterceptor
373+
):
374+
def __init__(
375+
self,
376+
next: temporalio.worker.NexusOperationInboundInterceptor,
377+
root: TracingInterceptor,
378+
) -> None:
379+
super().__init__(next)
380+
self._root = root
381+
382+
def _context_from_nexus_headers(self, headers: Mapping[str, str]):
383+
return self._root.text_map_propagator.extract(headers)
384+
385+
async def execute_nexus_operation_start(
386+
self, input: temporalio.worker.ExecuteNexusOperationStartInput
387+
) -> (
388+
nexusrpc.handler.StartOperationResultSync[Any]
389+
| nexusrpc.handler.StartOperationResultAsync
390+
):
391+
with self._root._start_as_current_span(
392+
f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
393+
context=self._context_from_nexus_headers(input.ctx.headers),
394+
attributes={},
395+
input_with_ctx=input,
396+
kind=opentelemetry.trace.SpanKind.SERVER,
397+
):
398+
return await self.next.execute_nexus_operation_start(input)
399+
400+
async def execute_nexus_operation_cancel(
401+
self, input: temporalio.worker.ExecuteNexusOperationCancelInput
402+
) -> None:
403+
with self._root._start_as_current_span(
404+
f"RunCancelNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
405+
context=self._context_from_nexus_headers(input.ctx.headers),
406+
attributes={},
407+
input_with_ctx=input,
408+
kind=opentelemetry.trace.SpanKind.SERVER,
409+
):
410+
return await self.next.execute_nexus_operation_cancel(input)
411+
412+
348413
class _InputWithHeaders(Protocol):
349414
headers: Mapping[str, temporalio.api.common.v1.Payload]
350415

351416

417+
class _InputWithStringHeaders(Protocol):
418+
headers: Mapping[str, str] | None
419+
420+
421+
class _InputWithOperationContext(Generic[_ContextT], Protocol):
422+
ctx: _ContextT
423+
424+
352425
class _WorkflowExternFunctions(TypedDict):
353426
__temporal_opentelemetry_completed_span: Callable[
354427
[_CompletedWorkflowSpanParams], _CarrierDict | None
@@ -602,6 +675,7 @@ def _completed_span(
602675
*,
603676
link_context_carrier: _CarrierDict | None = None,
604677
add_to_outbound: _InputWithHeaders | None = None,
678+
add_to_outbound_str: _InputWithStringHeaders | None = None,
605679
new_span_even_on_replay: bool = False,
606680
additional_attributes: opentelemetry.util.types.Attributes = None,
607681
exception: Exception | None = None,
@@ -614,12 +688,14 @@ def _completed_span(
614688
# Create the span. First serialize current context to carrier.
615689
new_context_carrier: _CarrierDict = {}
616690
self.text_map_propagator.inject(new_context_carrier)
691+
617692
# Invoke
618693
info = temporalio.workflow.info()
619694
attributes: dict[str, opentelemetry.util.types.AttributeValue] = {
620695
"temporalWorkflowID": info.workflow_id,
621696
"temporalRunID": info.run_id,
622697
}
698+
623699
if additional_attributes:
624700
attributes.update(additional_attributes)
625701
updated_context_carrier = self._extern_functions[
@@ -640,10 +716,16 @@ def _completed_span(
640716
)
641717

642718
# Add to outbound if needed
643-
if add_to_outbound and updated_context_carrier:
644-
add_to_outbound.headers = self._context_carrier_to_headers(
645-
updated_context_carrier, add_to_outbound.headers
646-
)
719+
if updated_context_carrier:
720+
if add_to_outbound:
721+
add_to_outbound.headers = self._context_carrier_to_headers(
722+
updated_context_carrier, add_to_outbound.headers
723+
)
724+
725+
if add_to_outbound_str:
726+
add_to_outbound_str.headers = _carrier_to_nexus_headers(
727+
updated_context_carrier, add_to_outbound_str.headers
728+
)
647729

648730
def _set_on_context(
649731
self, context: opentelemetry.context.Context
@@ -722,6 +804,29 @@ def start_local_activity(
722804
)
723805
return super().start_local_activity(input)
724806

807+
async def start_nexus_operation(
808+
self, input: temporalio.worker.StartNexusOperationInput[Any, Any]
809+
) -> temporalio.workflow.NexusOperationHandle[Any]:
810+
self.root._completed_span(
811+
f"StartNexusOperation:{input.service}/{input.operation_name}",
812+
kind=opentelemetry.trace.SpanKind.CLIENT,
813+
add_to_outbound_str=input,
814+
)
815+
816+
return await super().start_nexus_operation(input)
817+
818+
819+
def _carrier_to_nexus_headers(
820+
carrier: _CarrierDict, initial: Mapping[str, str] | None = None
821+
) -> Mapping[str, str]:
822+
out = {**initial} if initial else {}
823+
for k, v in carrier.items():
824+
if isinstance(v, list):
825+
out[k] = ",".join(v)
826+
else:
827+
out[k] = v
828+
return out
829+
725830

726831
class workflow:
727832
"""Contains static methods that are safe to call from within a workflow.

temporalio/worker/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
ActivityOutboundInterceptor,
77
ContinueAsNewInput,
88
ExecuteActivityInput,
9+
ExecuteNexusOperationCancelInput,
10+
ExecuteNexusOperationStartInput,
911
ExecuteWorkflowInput,
1012
HandleQueryInput,
1113
HandleSignalInput,
1214
HandleUpdateInput,
1315
Interceptor,
16+
NexusOperationInboundInterceptor,
1417
SignalChildWorkflowInput,
1518
SignalExternalWorkflowInput,
1619
StartActivityInput,
@@ -80,6 +83,7 @@
8083
"ActivityOutboundInterceptor",
8184
"WorkflowInboundInterceptor",
8285
"WorkflowOutboundInterceptor",
86+
"NexusOperationInboundInterceptor",
8387
"Plugin",
8488
# Interceptor input
8589
"ContinueAsNewInput",
@@ -95,6 +99,8 @@
9599
"StartLocalActivityInput",
96100
"StartNexusOperationInput",
97101
"WorkflowInterceptorClassInput",
102+
"ExecuteNexusOperationStartInput",
103+
"ExecuteNexusOperationCancelInput",
98104
# Advanced activity classes
99105
"SharedStateManager",
100106
"SharedHeartbeatSender",

temporalio/worker/_interceptor.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ def workflow_interceptor_class(
6666
"""
6767
return None
6868

69+
def intercept_nexus_operation(
70+
self, next: NexusOperationInboundInterceptor
71+
) -> NexusOperationInboundInterceptor:
72+
"""Method called for intercepting a Nexus operation.
73+
74+
Args:
75+
next: The underlying inbound this interceptor
76+
should delegate to.
77+
78+
Returns:
79+
The new interceptor that should be used for the Nexus operation.
80+
"""
81+
return next
82+
6983

7084
@dataclass(frozen=True)
7185
class WorkflowInterceptorClassInput:
@@ -465,3 +479,50 @@ async def start_nexus_operation(
465479
) -> temporalio.workflow.NexusOperationHandle[OutputT]:
466480
"""Called for every :py:func:`temporalio.workflow.NexusClient.start_operation` call."""
467481
return await self.next.start_nexus_operation(input)
482+
483+
484+
@dataclass
485+
class ExecuteNexusOperationStartInput:
486+
"""Input for :pyt:meth:`NexusOperationInboundInterceptor.start_operation"""
487+
488+
ctx: nexusrpc.handler.StartOperationContext
489+
input: Any
490+
491+
492+
@dataclass
493+
class ExecuteNexusOperationCancelInput:
494+
"""Input for :pyt:meth:`NexusOperationInboundInterceptor.cancel_operation"""
495+
496+
ctx: nexusrpc.handler.CancelOperationContext
497+
token: str
498+
499+
500+
class NexusOperationInboundInterceptor:
501+
"""Inbound interceptor to wrap Nexus operation starting and cancelling.
502+
503+
This should be extended by any Nexus operation inbound interceptors.
504+
"""
505+
506+
def __init__(self, next: NexusOperationInboundInterceptor) -> None:
507+
"""Create the inbound interceptor.
508+
509+
Args:
510+
next: The next interceptor in the chain. The default implementation
511+
of all calls is to delegate to the next interceptor.
512+
"""
513+
self.next = next
514+
515+
async def execute_nexus_operation_start(
516+
self, input: ExecuteNexusOperationStartInput
517+
) -> (
518+
nexusrpc.handler.StartOperationResultSync[Any]
519+
| nexusrpc.handler.StartOperationResultAsync
520+
):
521+
"""Called to start a Nexus operation"""
522+
return await self.next.execute_nexus_operation_start(input)
523+
524+
async def execute_nexus_operation_cancel(
525+
self, input: ExecuteNexusOperationCancelInput
526+
) -> None:
527+
"""Called to cancel an in progress Nexus operation"""
528+
return await self.next.execute_nexus_operation_cancel(input)

0 commit comments

Comments
 (0)