From 94667cfe8a8a1b821c3df390b9b3fde0ad64df14 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 19 Apr 2025 02:51:34 -0400 Subject: [PATCH 001/183] Nexus --- README.md | 2 + pyproject.toml | 16 +- temporalio/bridge/src/worker.rs | 28 +- temporalio/bridge/worker.py | 17 +- temporalio/client.py | 23 +- temporalio/common.py | 33 +- temporalio/converter.py | 26 + temporalio/exceptions.py | 63 ++ temporalio/nexus/__init__.py | 28 + temporalio/nexus/handler.py | 471 +++++++++ temporalio/nexus/token.py | 118 +++ temporalio/types.py | 6 +- temporalio/worker/_activity.py | 2 +- temporalio/worker/_interceptor.py | 70 +- temporalio/worker/_nexus.py | 470 +++++++++ temporalio/worker/_worker.py | 107 ++- tests/conftest.py | 6 + tests/helpers/nexus.py | 37 + ...ynamic_creation_of_user_handler_classes.py | 83 ++ tests/nexus/test_handler.py | 904 ++++++++++++++++++ tests/nexus/test_handler_async_operation.py | 260 +++++ .../test_handler_interface_implementation.py | 64 ++ .../test_handler_operation_definitions.py | 100 ++ 23 files changed, 2888 insertions(+), 46 deletions(-) create mode 100644 temporalio/nexus/__init__.py create mode 100644 temporalio/nexus/handler.py create mode 100644 temporalio/nexus/token.py create mode 100644 temporalio/worker/_nexus.py create mode 100644 tests/helpers/nexus.py create mode 100644 tests/nexus/test_dynamic_creation_of_user_handler_classes.py create mode 100644 tests/nexus/test_handler.py create mode 100644 tests/nexus/test_handler_async_operation.py create mode 100644 tests/nexus/test_handler_interface_implementation.py create mode 100644 tests/nexus/test_handler_operation_definitions.py diff --git a/README.md b/README.md index 5eacd1dec..03ab3c0ce 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ informal introduction to the features and their implementation. - [Heartbeating and Cancellation](#heartbeating-and-cancellation) - [Worker Shutdown](#worker-shutdown) - [Testing](#testing-1) + - [Nexus](#nexus) - [Workflow Replay](#workflow-replay) - [Observability](#observability) - [Metrics](#metrics) @@ -1308,6 +1309,7 @@ affect calls activity code might make to functions on the `temporalio.activity` * `cancel()` can be invoked to simulate a cancellation of the activity * `worker_shutdown()` can be invoked to simulate a worker shutdown during execution of the activity + ### Workflow Replay Given a workflow's history, it can be replayed locally to check for things like non-determinism errors. For example, diff --git a/pyproject.toml b/pyproject.toml index 391ea0abc..4a38ba6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ keywords = [ "workflow", ] dependencies = [ + "nexus-rpc", "protobuf>=3.20,<6", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", @@ -44,7 +45,7 @@ dev = [ "psutil>=5.9.3,<6", "pydocstyle>=6.3.0,<7", "pydoctor>=24.11.1,<25", - "pyright==1.1.377", + "pyright==1.1.400", "pytest~=7.4", "pytest-asyncio>=0.21,<0.22", "pytest-timeout~=2.2", @@ -53,6 +54,8 @@ dev = [ "twine>=4.0.1,<5", "ruff>=0.5.0,<0.6", "maturin>=1.8.2", + "pytest-cov>=6.1.1", + "httpx>=0.28.1", "pytest-pretty>=1.3.0", ] @@ -162,6 +165,7 @@ exclude = [ "tests/worker/workflow_sandbox/testmodules/proto", "temporalio/bridge/worker.py", "temporalio/contrib/opentelemetry.py", + "temporalio/contrib/pydantic.py", "temporalio/converter.py", "temporalio/testing/_workflow.py", "temporalio/worker/_activity.py", @@ -173,6 +177,10 @@ exclude = [ "tests/api/test_grpc_stub.py", "tests/conftest.py", "tests/contrib/test_opentelemetry.py", + "tests/contrib/pydantic/models.py", + "tests/contrib/pydantic/models_2.py", + "tests/contrib/pydantic/test_pydantic.py", + "tests/contrib/pydantic/workflows.py", "tests/test_converter.py", "tests/test_service.py", "tests/test_workflow.py", @@ -192,6 +200,9 @@ exclude = [ [tool.ruff] target-version = "py39" +[tool.ruff.lint] +extend-ignore = ["E741"] # Allow single-letter variable names like I, O + [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" @@ -208,3 +219,6 @@ exclude = [ [tool.uv] # Prevent uv commands from building the package by default package = false + +[tool.uv.sources] +nexus-rpc = { path = "../nexus-sdk-python", editable = true } diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 9dfca82c9..4fb3085ed 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -20,7 +20,7 @@ use temporal_sdk_core_api::worker::{ }; use temporal_sdk_core_api::Worker; use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion; -use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion}; +use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion, nexus::NexusTaskCompletion}; use temporal_sdk_core_protos::temporal::api::history::v1::History; use tokio::sync::mpsc::{channel, Sender}; use tokio_stream::wrappers::ReceiverStream; @@ -565,6 +565,19 @@ impl WorkerRef { }) } + fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let worker = self.worker.as_ref().unwrap().clone(); + self.runtime.future_into_py(py, async move { + let bytes = match worker.poll_nexus_task().await { + Ok(task) => task.encode_to_vec(), + Err(PollError::ShutDown) => return Err(PollShutdownError::new_err(())), + Err(err) => return Err(PyRuntimeError::new_err(format!("Poll failure: {}", err))), + }; + let bytes: &[u8] = &bytes; + Ok(Python::with_gil(|py| bytes.into_py(py))) + }) + } + fn complete_workflow_activation<'p>( &self, py: Python<'p>, @@ -599,6 +612,19 @@ impl WorkerRef { }) } + fn complete_nexus_task<'p>(&self, py: Python<'p>, proto: &PyBytes) -> PyResult<&'p PyAny> { + let worker = self.worker.as_ref().unwrap().clone(); + let completion = NexusTaskCompletion::decode(proto.as_bytes()) + .map_err(|err| PyValueError::new_err(format!("Invalid proto: {}", err)))?; + self.runtime.future_into_py(py, async move { + worker + .complete_nexus_task(completion) + .await + .context("Completion failure") + .map_err(Into::into) + }) + } + fn record_activity_heartbeat(&self, proto: &Bound<'_, PyBytes>) -> PyResult<()> { enter_sync!(self.runtime); let heartbeat = ActivityHeartbeat::decode(proto.as_bytes()) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 74cf55bfd..e98a54470 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -26,6 +26,7 @@ import temporalio.bridge.client import temporalio.bridge.proto import temporalio.bridge.proto.activity_task +import temporalio.bridge.proto.nexus import temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_completion import temporalio.bridge.runtime @@ -35,7 +36,7 @@ from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) -from temporalio.bridge.temporal_sdk_bridge import PollShutdownError +from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore @dataclass @@ -216,6 +217,14 @@ async def poll_activity_task( await self._ref.poll_activity_task() ) + async def poll_nexus_task( + self, + ) -> temporalio.bridge.proto.nexus.NexusTask: + """Poll for a nexus task.""" + return temporalio.bridge.proto.nexus.NexusTask.FromString( + await self._ref.poll_nexus_task() + ) + async def complete_workflow_activation( self, comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, @@ -229,6 +238,12 @@ async def complete_activity_task( """Complete an activity task.""" await self._ref.complete_activity_task(comp.SerializeToString()) + async def complete_nexus_task( + self, comp: temporalio.bridge.proto.nexus.NexusTaskCompletion + ) -> None: + """Complete a nexus task.""" + await self._ref.complete_nexus_task(comp.SerializeToString()) + def record_activity_heartbeat( self, comp: temporalio.bridge.proto.ActivityHeartbeat ) -> None: diff --git a/temporalio/client.py b/temporalio/client.py index f46297eb9..a5cac9b18 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -464,9 +464,16 @@ async def start_workflow( rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, request_eager_start: bool = False, - stack_level: int = 2, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, + # The following options are deliberately not exposed in overloads + stack_level: int = 2, + nexus_completion_callbacks: Sequence[ + temporalio.common.NexusCompletionCallback + ] = [], + workflow_event_links: Sequence[ + temporalio.api.common.v1.Link.WorkflowEvent + ] = [], ) -> WorkflowHandle[Any, Any]: """Start a workflow and return its handle. @@ -557,6 +564,8 @@ async def start_workflow( rpc_timeout=rpc_timeout, request_eager_start=request_eager_start, priority=priority, + nexus_completion_callbacks=nexus_completion_callbacks, + workflow_event_links=workflow_event_links, ) ) @@ -5193,6 +5202,8 @@ class StartWorkflowInput: rpc_timeout: Optional[timedelta] request_eager_start: bool priority: temporalio.common.Priority + nexus_completion_callbacks: Sequence[temporalio.common.NexusCompletionCallback] + workflow_event_links: Sequence[temporalio.api.common.v1.Link.WorkflowEvent] versioning_override: Optional[temporalio.common.VersioningOverride] = None @@ -5809,6 +5820,16 @@ async def _build_start_workflow_execution_request( req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest() req.request_eager_execution = input.request_eager_start await self._populate_start_workflow_execution_request(req, input) + for callback in input.nexus_completion_callbacks: + c = temporalio.api.common.v1.Callback() + c.nexus.url = callback.url + c.nexus.header.update(callback.header) + req.completion_callbacks.append(c) + + req.links.extend( + temporalio.api.common.v1.Link(workflow_event=link) + for link in input.workflow_event_links + ) return req async def _build_signal_with_start_workflow_execution_request( diff --git a/temporalio/common.py b/temporalio/common.py index 3349f70e9..dbc04a3b1 100644 --- a/temporalio/common.py +++ b/temporalio/common.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta -from enum import Enum, IntEnum +from enum import IntEnum from typing import ( Any, Callable, @@ -197,6 +197,37 @@ def __setstate__(self, state: object) -> None: ) +@dataclass(frozen=True) +class NexusCompletionCallback: + """Nexus callback to attach to events such as workflow completion.""" + + url: str + """Callback URL.""" + + header: Mapping[str, str] + """Header to attach to callback request.""" + + +@dataclass(frozen=True) +class WorkflowEventLink: + """A link to a history event that can be attached to a different history event.""" + + namespace: str + """Namespace of the workflow to link to.""" + + workflow_id: str + """ID of the workflow to link to.""" + + run_id: str + """Run ID of the workflow to link to.""" + + event_type: temporalio.api.enums.v1.EventType + """Type of the event to link to.""" + + event_id: int + """ID of the event to link to.""" + + # We choose to make this a list instead of an sequence so we can catch if people # are not sending lists each time but maybe accidentally sending a string (which # is a sequence) diff --git a/temporalio/converter.py b/temporalio/converter.py index 6a6d0e12b..b976eca08 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -911,6 +911,12 @@ def _error_to_failure( failure.child_workflow_execution_failure_info.retry_state = ( temporalio.api.enums.v1.RetryState.ValueType(error.retry_state or 0) ) + # TODO(nexus-prerelease): test coverage for this + elif isinstance(error, temporalio.exceptions.NexusOperationError): + failure.nexus_operation_execution_failure_info.SetInParent() + failure.nexus_operation_execution_failure_info.operation_token = ( + error.operation_token + ) def from_failure( self, @@ -1006,6 +1012,26 @@ def from_failure( if child_info.retry_state else None, ) + elif failure.HasField("nexus_handler_failure_info"): + nexus_handler_failure_info = failure.nexus_handler_failure_info + err = temporalio.exceptions.NexusHandlerError( + failure.message or "Nexus handler error", + type=nexus_handler_failure_info.type, + retryable={ + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: True, + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: False, + }.get(nexus_handler_failure_info.retry_behavior), + ) + elif failure.HasField("nexus_operation_execution_failure_info"): + nexus_op_failure_info = failure.nexus_operation_execution_failure_info + err = temporalio.exceptions.NexusOperationError( + failure.message or "Nexus operation error", + scheduled_event_id=nexus_op_failure_info.scheduled_event_id, + endpoint=nexus_op_failure_info.endpoint, + service=nexus_op_failure_info.service, + operation=nexus_op_failure_info.operation, + operation_token=nexus_op_failure_info.operation_token, + ) else: err = temporalio.exceptions.FailureError(failure.message or "Failure error") err._failure = failure diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index f045b36a0..e687482f6 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -362,6 +362,69 @@ def retry_state(self) -> Optional[RetryState]: return self._retry_state +class NexusHandlerError(FailureError): + """Error raised on Nexus handler failure.""" + + def __init__( + self, + message: str, + *, + type: str, + retryable: Optional[bool] = None, + ): + """Initialize a Nexus handler error.""" + super().__init__(message) + self._type = type + self._retryable = retryable + + +class NexusOperationError(FailureError): + """Error raised on Nexus operation failure.""" + + def __init__( + self, + message: str, + *, + scheduled_event_id: int, + endpoint: str, + service: str, + operation: str, + operation_token: str, + ): + """Initialize a Nexus operation error.""" + super().__init__(message) + self._scheduled_event_id = scheduled_event_id + self._endpoint = endpoint + self._service = service + self._operation = operation + self._operation_token = operation_token + + @property + def scheduled_event_id(self) -> int: + """The NexusOperationScheduled event ID for the failed operation.""" + return self._scheduled_event_id + + @property + def endpoint(self) -> str: + """The endpoint name for the failed operation.""" + return self._endpoint + + @property + def service(self) -> str: + """The service name for the failed operation.""" + return self._service + + @property + def operation(self) -> str: + """The name of the failed operation.""" + return self._operation + + @property + def operation_token(self) -> str: + """The operation token returned by the failed operation.""" + return self._operation_token + + def is_cancelled_exception(exception: BaseException) -> bool: """Check whether the given exception is considered a cancellation exception according to Temporal. diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py new file mode 100644 index 000000000..9750cfb88 --- /dev/null +++ b/temporalio/nexus/__init__.py @@ -0,0 +1,28 @@ +import dataclasses +import logging +from collections.abc import Mapping +from typing import Any, MutableMapping, Optional + +from .handler import _current_context as _current_context +from .handler import workflow_run_operation_handler as workflow_run_operation_handler +from .token import WorkflowOperationToken as WorkflowOperationToken + + +class LoggerAdapter(logging.LoggerAdapter): + def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): + super().__init__(logger, extra or {}) + + def process( + self, msg: Any, kwargs: MutableMapping[str, Any] + ) -> tuple[Any, MutableMapping[str, Any]]: + extra = dict(self.extra or {}) + if context := _current_context.get(None): + extra.update( + {f.name: getattr(context, f.name) for f in dataclasses.fields(context)} + ) + kwargs["extra"] = extra | kwargs.get("extra", {}) + return msg, kwargs + + +logger = LoggerAdapter(logging.getLogger(__name__), None) +"""Logger that emits additional data describing the current Nexus operation.""" diff --git a/temporalio/nexus/handler.py b/temporalio/nexus/handler.py new file mode 100644 index 000000000..4e96bb33e --- /dev/null +++ b/temporalio/nexus/handler.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import logging +import re +import types +import typing +import urllib.parse +import warnings +from contextvars import ContextVar +from dataclasses import dataclass +from functools import wraps +from typing import ( + Any, + Awaitable, + Callable, + Generic, + Optional, + Sequence, + Type, + TypeVar, + Union, +) + +import nexusrpc.handler +from typing_extensions import Concatenate, Self, overload + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.common +from temporalio.client import ( + Client, + WorkflowHandle, +) +from temporalio.nexus.token import WorkflowOperationToken +from temporalio.types import ( + MethodAsyncNoParam, + MethodAsyncSingleParam, + MultiParamSpec, + ParamType, + ReturnType, + SelfType, +) + +I = TypeVar("I", contravariant=True) # operation input +O = TypeVar("O", covariant=True) # operation output +S = TypeVar("S") # a service + +logger = logging.getLogger(__name__) + + +# TODO(nexus-preview): demonstrate obtaining Temporal client in sync operation. + + +def _get_workflow_run_start_method_input_and_output_type_annotations( + start_method: Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + Awaitable[WorkflowHandle[Any, O]], + ], +) -> tuple[ + Optional[Type[I]], + Optional[Type[O]], +]: + """Return operation input and output types. + + `start_method` must be a type-annotated start method that returns a + :py:class:`WorkflowHandle`. + """ + input_type, output_type = ( + nexusrpc.handler.get_start_method_input_and_output_types_annotations( + start_method + ) + ) + origin_type = typing.get_origin(output_type) + if not origin_type or not issubclass(origin_type, WorkflowHandle): + warnings.warn( + f"Expected return type of {start_method.__name__} to be a subclass of WorkflowHandle, " + f"but is {output_type}" + ) + output_type = None + + args = typing.get_args(output_type) + if len(args) != 2: + warnings.warn( + f"Expected return type of {start_method.__name__} to have exactly two type parameters, " + f"but has {len(args)}: {args}" + ) + output_type = None + else: + _wf_type, output_type = args + return input_type, output_type + + +# No-param overload +@overload +async def start_workflow( + ctx: nexusrpc.handler.StartOperationContext, + workflow: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + client: Optional[Client] = None, + task_queue: Optional[str] = None, +) -> WorkflowHandle[SelfType, ReturnType]: ... + + +# Single-param overload +@overload +async def start_workflow( + ctx: nexusrpc.handler.StartOperationContext, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + client: Optional[Client] = None, + task_queue: Optional[str] = None, +) -> WorkflowHandle[SelfType, ReturnType]: ... + + +# Multiple-params overload +@overload +async def start_workflow( + ctx: nexusrpc.handler.StartOperationContext, + workflow: Callable[Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType]], + *, + args: Sequence[Any], + id: str, + client: Optional[Client] = None, + task_queue: Optional[str] = None, +) -> WorkflowHandle[SelfType, ReturnType]: ... + + +# TODO(nexus-prerelease): Overload for string-name workflow + + +async def start_workflow( + ctx: nexusrpc.handler.StartOperationContext, + workflow: Callable[..., Awaitable[Any]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + client: Optional[Client] = None, + task_queue: Optional[str] = None, +) -> WorkflowHandle[Any, Any]: + if client is None: + client = get_client() + if task_queue is None: + # TODO(nexus-prerelease): are we handling empty string well elsewhere? + task_queue = get_task_queue() + completion_callbacks = ( + [ + # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus + # request, it needs to copy the links to the callback in + # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links + # (for backwards compatibility). PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1945 + temporalio.common.NexusCompletionCallback( + url=ctx.callback_url, header=ctx.callback_headers + ) + ] + if ctx.callback_url + else [] + ) + # We need to pass options (completion_callbacks, links, on_conflict_options) which are + # deliberately not exposed in any overload, hence the type error. + wf_handle = await client.start_workflow( # type: ignore + workflow, + args=temporalio.common._arg_or_args(arg, args), + id=id, + task_queue=task_queue, + nexus_completion_callbacks=completion_callbacks, + workflow_event_links=[ + _nexus_link_to_workflow_event(l) for l in ctx.inbound_links + ], + ) + try: + link = _workflow_event_to_nexus_link( + _workflow_handle_to_workflow_execution_started_event_link(wf_handle) + ) + except Exception as e: + logger.warning( + f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}" + ) + else: + ctx.outbound_links.append( + # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference + # link to send back to the caller. Now, it checks if the server returned + # the link in the StartWorkflowExecutionResponse, and if so, send the link + # from the response to the caller. Fallback to generating the link for + # backwards compatibility. PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1934 + link + ) + return wf_handle + + +# TODO(nexus-prerelease): support request_id +# See e.g. TS +# packages/nexus/src/context.ts attachRequestId +# packages/test/src/test-nexus-handler.ts ctx.requestId + + +async def cancel_workflow( + ctx: nexusrpc.handler.CancelOperationContext, + token: str, + client: Optional[Client] = None, +) -> None: + _client = client or get_client() + handle = WorkflowOperationToken.decode(token).to_workflow_handle(_client) + await handle.cancel() + + +_current_context: ContextVar[_Context] = ContextVar("nexus-handler") + + +@dataclass +class _Context: + client: Optional[Client] + task_queue: Optional[str] + service: Optional[str] = None + operation: Optional[str] = None + + +def get_client() -> Client: + context = _current_context.get(None) + if context is None: + raise RuntimeError("Not in Nexus handler context") + if context.client is None: + raise RuntimeError("Nexus handler client not set") + return context.client + + +def get_task_queue() -> str: + context = _current_context.get(None) + if context is None: + raise RuntimeError("Not in Nexus handler context") + if context.task_queue is None: + raise RuntimeError("Nexus handler task queue not set") + return context.task_queue + + +class WorkflowRunOperation(nexusrpc.handler.OperationHandler[I, O], Generic[I, O, S]): + def __init__( + self, + service: S, + start_method: Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + Awaitable[WorkflowHandle[Any, O]], + ], + output_type: Optional[Type] = None, + ): + self.service = service + + @wraps(start_method) + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: I + ) -> WorkflowRunOperationResult: + wf_handle = await start_method(service, ctx, input) + # TODO(nexus-prerelease): Error message if user has accidentally used the normal client.start_workflow + return WorkflowRunOperationResult.from_workflow_handle(wf_handle) + + self.start = types.MethodType(start, self) + + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: I + ) -> nexusrpc.handler.StartOperationResultAsync: + raise NotImplementedError( + "The start method of a WorkflowRunOperation should be set " + "dynamically in the __init__ method. (Did you forget to call super()?)" + ) + + async def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> None: + await cancel_workflow(ctx, token) + + def fetch_info( + self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str + ) -> Union[ + nexusrpc.handler.OperationInfo, Awaitable[nexusrpc.handler.OperationInfo] + ]: + raise NotImplementedError( + "Temporal Nexus operation handlers do not support fetching operation info." + ) + + def fetch_result( + self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str + ) -> Union[O, Awaitable[O]]: + raise NotImplementedError( + "Temporal Nexus operation handlers do not support fetching operation results." + ) + + +class WorkflowRunOperationResult(nexusrpc.handler.StartOperationResultAsync): + """ + A value returned by the start method of a :class:`WorkflowRunOperation`. + + It indicates that the operation is responding asynchronously, and contains a token + that the handler can use to construct a :class:`~temporalio.client.WorkflowHandle` to + interact with the workflow. + """ + + @classmethod + def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: + token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() + return cls(token=token) + + +@overload +def workflow_run_operation_handler( + start_method: Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + Awaitable[WorkflowHandle[Any, O]], + ], +) -> Callable[[S], WorkflowRunOperation[I, O, S]]: ... + + +@overload +def workflow_run_operation_handler( + *, + name: Optional[str] = None, +) -> Callable[ + [ + Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + Awaitable[WorkflowHandle[Any, O]], + ] + ], + Callable[[S], WorkflowRunOperation[I, O, S]], +]: ... + + +def workflow_run_operation_handler( + start_method: Optional[ + Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + Awaitable[WorkflowHandle[Any, O]], + ] + ] = None, + *, + name: Optional[str] = None, +) -> Union[ + Callable[[S], WorkflowRunOperation[I, O, S]], + Callable[ + [ + Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + Awaitable[WorkflowHandle[Any, O]], + ] + ], + Callable[[S], WorkflowRunOperation[I, O, S]], + ], +]: + def decorator( + start_method: Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + Awaitable[WorkflowHandle[Any, O]], + ], + ) -> Callable[[S], WorkflowRunOperation[I, O, S]]: + input_type, output_type = ( + _get_workflow_run_start_method_input_and_output_type_annotations( + start_method + ) + ) + + def factory(service: S) -> WorkflowRunOperation[I, O, S]: + return WorkflowRunOperation(service, start_method, output_type=output_type) + + # TODO(nexus-prerelease): handle callable instances: __class__.__name__ as in sync_operation_handler + method_name = getattr(start_method, "__name__", None) + if not method_name and callable(start_method): + method_name = start_method.__class__.__name__ + if not method_name: + raise TypeError( + f"Could not determine operation method name: " + f"expected {start_method} to be a function or callable instance." + ) + + factory.__nexus_operation__ = nexusrpc.Operation._create( + name=name, + method_name=method_name, + input_type=input_type, + output_type=output_type, + ) + + return factory + + if start_method is None: + return decorator + + return decorator(start_method) + + +# TODO(nexus-prerelease): confirm that it is correct not to use event_id in the following functions. +# Should the proto say explicitly that it's optional or how it behaves when it's missing? +def _workflow_handle_to_workflow_execution_started_event_link( + handle: WorkflowHandle[Any, Any], +) -> temporalio.api.common.v1.Link.WorkflowEvent: + if handle.first_execution_run_id is None: + raise ValueError( + f"Workflow handle {handle} has no first execution run ID. " + "Cannot create WorkflowExecutionStarted event link." + ) + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=handle._client.namespace, + workflow_id=handle.id, + run_id=handle.first_execution_run_id, + event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ), + ) + + +def _workflow_event_to_nexus_link( + workflow_event: temporalio.api.common.v1.Link.WorkflowEvent, +) -> nexusrpc.handler.Link: + scheme = "temporal" + namespace = urllib.parse.quote(workflow_event.namespace) + workflow_id = urllib.parse.quote(workflow_event.workflow_id) + run_id = urllib.parse.quote(workflow_event.run_id) + path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history" + query_params = urllib.parse.urlencode( + { + "eventType": temporalio.api.enums.v1.EventType.Name( + workflow_event.event_ref.event_type + ), + "referenceType": "EventReference", + } + ) + return nexusrpc.handler.Link( + url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")), + type=workflow_event.DESCRIPTOR.full_name, + ) + + +def _nexus_link_to_workflow_event( + link: nexusrpc.handler.Link, +) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]: + path_regex = re.compile( + r"^/namespaces/(?P[^/]+)/workflows/(?P[^/]+)/(?P[^/]+)/history$" + ) + url = urllib.parse.urlparse(link.url) + match = path_regex.match(url.path) + if not match: + logger.warning( + f"Invalid Nexus link: {link}. Expected path to match {path_regex.pattern}" + ) + return None + try: + query_params = urllib.parse.parse_qs(url.query) + [reference_type] = query_params.get("referenceType", []) + if reference_type != "EventReference": + raise ValueError( + f"@@ Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}" + ) + [event_type_name] = query_params.get("eventType", []) + event_ref = temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + event_type=temporalio.api.enums.v1.EventType.Value(event_type_name) + ) + except ValueError as err: + logger.warning( + f"@@ Failed to parse event type from Nexus link URL query parameters: {link} ({err})" + ) + event_ref = None + + groups = match.groupdict() + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=urllib.parse.unquote(groups["namespace"]), + workflow_id=urllib.parse.unquote(groups["workflow_id"]), + run_id=urllib.parse.unquote(groups["run_id"]), + event_ref=event_ref, + ) diff --git a/temporalio/nexus/token.py b/temporalio/nexus/token.py new file mode 100644 index 000000000..d357ecb9c --- /dev/null +++ b/temporalio/nexus/token.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass +from typing import Any, Literal, Optional + +from temporalio.client import Client, WorkflowHandle + +OPERATION_TOKEN_TYPE_WORKFLOW = 1 +OperationTokenType = Literal[1] + + +@dataclass(frozen=True) +class WorkflowOperationToken: + """Represents the structured data of a Nexus workflow operation token.""" + + namespace: str + workflow_id: str + _type: OperationTokenType = OPERATION_TOKEN_TYPE_WORKFLOW + # Version of the token. Treated as v1 if missing. This field is not included in the + # serialized token; it's only used to reject newer token versions on load. + version: Optional[int] = None + + @classmethod + def from_workflow_handle( + cls, workflow_handle: WorkflowHandle[Any, Any] + ) -> WorkflowOperationToken: + """Creates a token from a workflow handle.""" + return cls( + namespace=workflow_handle._client.namespace, + workflow_id=workflow_handle.id, + ) + + def to_workflow_handle(self, client: Client) -> WorkflowHandle[Any, Any]: + """Creates a workflow handle from this token.""" + if client.namespace != self.namespace: + raise ValueError( + f"Client namespace {client.namespace} does not match token namespace {self.namespace}" + ) + return client.get_workflow_handle(self.workflow_id) + + def encode(self) -> str: + return _base64url_encode_no_padding( + json.dumps( + { + "t": self._type, + "ns": self.namespace, + "wid": self.workflow_id, + }, + separators=(",", ":"), + ).encode("utf-8") + ) + + @classmethod + def decode(cls, data: str) -> WorkflowOperationToken: + """Decodes and validates a token from its base64url-encoded string representation.""" + if not data: + raise TypeError("invalid workflow token: token is empty") + try: + decoded_bytes = _base64url_decode_no_padding(data) + except Exception as err: + raise TypeError("failed to decode token as base64url") from err + try: + token = json.loads(decoded_bytes.decode("utf-8")) + except Exception as err: + raise TypeError("failed to unmarshal workflow operation token") from err + + if not isinstance(token, dict): + raise TypeError(f"invalid workflow token: expected dict, got {type(token)}") + + _type = token.get("t") + if _type != OPERATION_TOKEN_TYPE_WORKFLOW: + raise TypeError( + f"invalid workflow token type: {_type}, expected: {OPERATION_TOKEN_TYPE_WORKFLOW}" + ) + + version = token.get("v") + if version is not None and version != 0: + raise TypeError( + "invalid workflow token: 'v' field, if present, must be 0 or null/absent" + ) + + workflow_id = token.get("wid") + if not workflow_id or not isinstance(workflow_id, str): + raise TypeError( + "invalid workflow token: missing, empty, or non-string workflow ID (wid)" + ) + + namespace = token.get("ns") + if namespace is None or not isinstance(namespace, str): + # Allow empty string for ns, but it must be present and a string + raise TypeError( + "invalid workflow token: missing or non-string namespace (ns)" + ) + + return cls( + _type=_type, + namespace=namespace, + workflow_id=workflow_id, + version=version, + ) + + +def _base64url_encode_no_padding(data: bytes) -> str: + return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") + + +def _base64url_decode_no_padding(s: str) -> bytes: + if not all( + c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-" + for c in s + ): + raise ValueError( + "invalid base64URL encoded string: contains invalid characters" + ) + padding = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode(s + padding) diff --git a/temporalio/types.py b/temporalio/types.py index 331c9596e..a756d328c 100644 --- a/temporalio/types.py +++ b/temporalio/types.py @@ -82,7 +82,7 @@ class MethodAsyncSingleParam( """Generic callable type.""" def __call__( - self, __self: ProtocolSelfType, __arg: ProtocolParamType, / + self, __self: ProtocolSelfType, __arg: ProtocolParamType ) -> Awaitable[ProtocolReturnType]: """Generic callable type callback.""" ... @@ -94,7 +94,7 @@ class MethodSyncSingleParam( """Generic callable type.""" def __call__( - self, __self: ProtocolSelfType, __arg: ProtocolParamType, / + self, __self: ProtocolSelfType, __arg: ProtocolParamType ) -> ProtocolReturnType: """Generic callable type callback.""" ... @@ -116,7 +116,7 @@ class MethodSyncOrAsyncSingleParam( """Generic callable type.""" def __call__( - self, __self: ProtocolSelfType, __param: ProtocolParamType, / + self, __self: ProtocolSelfType, __param: ProtocolParamType ) -> Union[ProtocolReturnType, Awaitable[ProtocolReturnType]]: """Generic callable type callback.""" ... diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index fe18d1f18..c9f71834c 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -201,7 +201,7 @@ async def drain_poll_queue(self) -> None: # Only call this after run()/drain_poll_queue() have returned. This will not # raise an exception. - # TODO(dan): based on the comment above it looks like the intention may have been to use + # TODO(nexus-prerelease): based on the comment above it looks like the intention may have been to use # return_exceptions=True async def wait_all_completed(self) -> None: running_tasks = [v.task for v in self._running_activities.values() if v.task] diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index a3146200e..7e0a1d35b 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -3,22 +3,24 @@ from __future__ import annotations import concurrent.futures -from dataclasses import dataclass +from collections.abc import Callable, Mapping, MutableMapping +from dataclasses import dataclass, field from datetime import timedelta from typing import ( Any, Awaitable, - Callable, + Generic, List, - Mapping, - MutableMapping, NoReturn, Optional, Sequence, Type, + TypeVar, Union, ) +import nexusrpc.handler + import temporalio.activity import temporalio.api.common.v1 import temporalio.common @@ -285,6 +287,60 @@ class StartChildWorkflowInput: ret_type: Optional[Type] +# TODO(nexus-prerelease): Put these in a better location. Type variance? +I = TypeVar("I") +O = TypeVar("O") + + +@dataclass +class StartNexusOperationInput(Generic[I, O]): + """Input for :py:meth:`WorkflowOutboundInterceptor.start_nexus_operation`.""" + + endpoint: str + service: str + operation: Union[ + nexusrpc.Operation[I, O], + Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + str, + ] + input: I + schedule_to_close_timeout: Optional[timedelta] + headers: Optional[Mapping[str, str]] + output_type: Optional[Type[O]] = None + + _operation_name: str = field(init=False, repr=False) + _input_type: Optional[Type[I]] = field(init=False, repr=False) + + def __post_init__(self) -> None: + if isinstance(self.operation, str): + self._operation_name = self.operation + self._input_type = None + elif isinstance(self.operation, nexusrpc.Operation): + self._operation_name = self.operation.name + self._input_type = self.operation.input_type + self.output_type = self.operation.output_type + elif isinstance(self.operation, Callable): + op = getattr(self.operation, "__nexus_operation__", None) + if isinstance(op, nexusrpc.Operation): + self._operation_name = op.name + self._input_type = op.input_type + self.output_type = op.output_type + else: + raise ValueError( + f"Operation callable is not a Nexus operation: {self.operation}" + ) + else: + raise ValueError(f"Operation is not a Nexus operation: {self.operation}") + + @property + def operation_name(self) -> str: + return self._operation_name + + @property + def input_type(self) -> Optional[Type[I]]: + return self._input_type + + @dataclass class StartLocalActivityInput: """Input for :py:meth:`WorkflowOutboundInterceptor.start_local_activity`.""" @@ -409,3 +465,9 @@ def start_local_activity( and :py:func:`temporalio.workflow.execute_local_activity` call. """ return self.next.start_local_activity(input) + + async def start_nexus_operation( + self, input: StartNexusOperationInput + ) -> temporalio.workflow.NexusOperationHandle[Any]: + """Called for every :py:func:`temporalio.workflow.start_nexus_operation` call.""" + return await self.next.start_nexus_operation(input) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py new file mode 100644 index 000000000..e8c57c4f8 --- /dev/null +++ b/temporalio/worker/_nexus.py @@ -0,0 +1,470 @@ +"""Nexus worker""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import json +import logging +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Optional, + Sequence, + Type, +) + +import google.protobuf.json_format +import nexusrpc.handler +from nexusrpc.handler._core import SyncExecutor + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.failure.v1 +import temporalio.api.nexus.v1 +import temporalio.bridge.proto.nexus +import temporalio.bridge.worker +import temporalio.client +import temporalio.common +import temporalio.converter +import temporalio.nexus +import temporalio.nexus.handler +from temporalio.exceptions import ApplicationError +from temporalio.service import RPCError, RPCStatusCode + +from ._interceptor import Interceptor + +logger = logging.getLogger(__name__) + + +class _NexusWorker: + def __init__( + self, + *, + bridge_worker: Callable[[], temporalio.bridge.worker.Worker], + client: temporalio.client.Client, + task_queue: str, + nexus_services: Sequence[Any], + data_converter: temporalio.converter.DataConverter, + interceptors: Sequence[Interceptor], + metric_meter: temporalio.common.MetricMeter, + executor: Optional[concurrent.futures.ThreadPoolExecutor], + ) -> None: + # TODO(nexus-prerelease): make it possible to query task queue of bridge worker + # instead of passing unused task_queue into _NexusWorker, + # _ActivityWorker, etc? + self._bridge_worker = bridge_worker + self._client = client + self._task_queue = task_queue + + for service in nexus_services: + if isinstance(service, type): + raise TypeError( + f"Expected a service instance, but got a class: {service}. " + "Nexus services must be passed as instances, not classes." + ) + self._handler = nexusrpc.handler.Handler( + nexus_services, + SyncExecutor(executor) if executor is not None else None, + ) + self._data_converter = data_converter + # TODO(nexus-prerelease): interceptors + self._interceptors = interceptors + # TODO(nexus-prerelease): metric_meter + self._metric_meter = metric_meter + self._running_operations: dict[bytes, asyncio.Task[Any]] = {} + + async def run(self) -> None: + while True: + try: + poll_task = asyncio.create_task(self._bridge_worker().poll_nexus_task()) + except Exception as err: + raise RuntimeError("Nexus worker failed") from err + + task = await poll_task + + if task.HasField("task"): + task = task.task + if task.request.HasField("start_operation"): + self._running_operations[task.task_token] = asyncio.create_task( + self._run_nexus_operation( + task.task_token, + task.request.start_operation, + dict(task.request.header), + ) + ) + elif task.request.HasField("cancel_operation"): + # TODO(nexus-prerelease): report errors occurring during execution of user + # cancellation method + asyncio.create_task( + self._handle_cancel_operation( + task.request.cancel_operation, task.task_token + ) + ) + else: + raise NotImplementedError( + f"Invalid Nexus task request: {task.request}" + ) + elif task.HasField("cancel_task"): + task = task.cancel_task + if _task := self._running_operations.get(task.task_token): + # TODO(nexus-prerelease): when do we remove the entry from _running_operations? + _task.cancel() + else: + temporalio.nexus.logger.warning( + f"Received cancel_task but no running operation exists for " + f"task token: {task.task_token}" + ) + else: + raise NotImplementedError(f"Invalid Nexus task: {task}") + + # Only call this if run() raised an error + async def drain_poll_queue(self) -> None: + while True: + try: + # Take all tasks and say we can't handle them + task = await self._bridge_worker().poll_nexus_task() + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task.task.task_token + ) + completion.error.failure.message = "Worker shutting down" + await self._bridge_worker().complete_nexus_task(completion) + except temporalio.bridge.worker.PollShutdownError: + return + + async def wait_all_completed(self) -> None: + await asyncio.gather( + *self._running_operations.values(), return_exceptions=False + ) + + # TODO(nexus-prerelease): stack trace pruning. See sdk-typescript NexusHandler.execute + # "Any call up to this function and including this one will be trimmed out of stack traces."" + + async def _run_nexus_operation( + self, + task_token: bytes, + start_request: temporalio.api.nexus.v1.StartOperationRequest, + header: dict[str, str], + ) -> None: + async def run() -> temporalio.bridge.proto.nexus.NexusTaskCompletion: + temporalio.nexus.handler._current_context.set( + temporalio.nexus.handler._Context( + client=self._client, + task_queue=self._task_queue, + service=start_request.service, + operation=start_request.operation, + ) + ) + try: + ctx = nexusrpc.handler.StartOperationContext( + service=start_request.service, + operation=start_request.operation, + headers=header, + request_id=start_request.request_id, + callback_url=start_request.callback, + inbound_links=[ + nexusrpc.handler.Link(url=l.url, type=l.type) + for l in start_request.links + ], + callback_headers=dict(start_request.callback_header), + ) + input = nexusrpc.handler.LazyValue( + serializer=_DummyPayloadSerializer( + data_converter=self._data_converter, + payload=start_request.payload, + ), + headers={}, + stream=None, + ) + try: + result = await self._handler.start_operation(ctx, input) + except ( + nexusrpc.handler.UnknownServiceError, + nexusrpc.handler.UnknownOperationError, + ) as err: + # TODO(nexus-prerelease): error message + raise nexusrpc.handler.HandlerError( + "No matching operation handler", + type=nexusrpc.handler.HandlerErrorType.NOT_FOUND, + cause=err, + retryable=False, + ) from err + + except nexusrpc.handler.OperationError as err: + return temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + completed=temporalio.api.nexus.v1.Response( + start_operation=temporalio.api.nexus.v1.StartOperationResponse( + operation_error=await self._operation_error_to_proto(err), + ), + ), + ) + except BaseException as err: + handler_err = _exception_to_handler_error(err) + return temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + error=temporalio.api.nexus.v1.HandlerError( + error_type=handler_err.type.value, + failure=await self._exception_to_failure_proto( + handler_err.__cause__ + ), + retry_behavior=( + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE + if handler_err.retryable + else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE + ), + ), + ) + else: + if isinstance(result, nexusrpc.handler.StartOperationResultAsync): + op_resp = temporalio.api.nexus.v1.StartOperationResponse( + async_success=temporalio.api.nexus.v1.StartOperationResponse.Async( + operation_token=result.token, + links=[ + temporalio.api.nexus.v1.Link(url=l.url, type=l.type) + for l in ctx.outbound_links + ], + ) + ) + elif isinstance(result, nexusrpc.handler.StartOperationResultSync): + # TODO(nexus-prerelease): error handling here; what error type should it be? + [payload] = await self._data_converter.encode([result.value]) + op_resp = temporalio.api.nexus.v1.StartOperationResponse( + sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync( + payload=payload + ) + ) + else: + # TODO(nexus-prerelease): what should the error response be when the user has failed to wrap their return type? + # TODO(nexus-prerelease): unify this failure completion with the path above + err = TypeError( + "Operation start method must return either nexusrpc.handler.StartOperationResultSync " + "or nexusrpc.handler.StartOperationResultAsync" + ) + handler_err = _exception_to_handler_error(err) + return temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + error=temporalio.api.nexus.v1.HandlerError( + error_type=handler_err.type.value, + failure=await self._exception_to_failure_proto( + handler_err.__cause__ + ), + retry_behavior=( + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE + if handler_err.retryable + else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE + ), + ), + ) + + return temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + completed=temporalio.api.nexus.v1.Response(start_operation=op_resp), + ) + + try: + completion = await run() + await self._bridge_worker().complete_nexus_task(completion) + except Exception: + temporalio.nexus.logger.exception("Failed completing Nexus operation") + finally: + try: + del self._running_operations[task_token] + except KeyError: + temporalio.nexus.logger.exception( + "Failed to remove completed Nexus operation" + ) + + async def _handle_cancel_operation( + self, request: temporalio.api.nexus.v1.CancelOperationRequest, task_token: bytes + ) -> None: + temporalio.nexus.handler._current_context.set( + temporalio.nexus.handler._Context( + client=self._client, + task_queue=self._task_queue, + service=request.service, + operation=request.operation, + ) + ) + ctx = nexusrpc.handler.CancelOperationContext( + service=request.service, + operation=request.operation, + ) + # TODO(nexus-prerelease): header + try: + await self._handler.cancel_operation(ctx, request.operation_token) + except Exception as err: + temporalio.nexus.logger.exception( + "Failed to execute Nexus operation cancel method", err + ) + # TODO(nexus-prerelease): when do we use ack_cancel? + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + completed=temporalio.api.nexus.v1.Response( + cancel_operation=temporalio.api.nexus.v1.CancelOperationResponse() + ), + ) + try: + await self._bridge_worker().complete_nexus_task(completion) + except Exception as err: + temporalio.nexus.logger.exception( + "Failed to send Nexus task completion", err + ) + + async def _exception_to_failure_proto( + self, + err: BaseException, + ) -> temporalio.api.nexus.v1.Failure: + api_failure = temporalio.api.failure.v1.Failure() + await self._data_converter.encode_failure(err, api_failure) + api_failure = google.protobuf.json_format.MessageToDict(api_failure) + # TODO(nexus-prerelease): is metadata correct and playing intended role here? + return temporalio.api.nexus.v1.Failure( + message=api_failure.pop("message", ""), + metadata={"type": "temporal.api.failure.v1.Failure"}, + details=json.dumps(api_failure).encode("utf-8"), + ) + + async def _operation_error_to_proto( + self, + err: nexusrpc.handler.OperationError, + ) -> temporalio.api.nexus.v1.UnsuccessfulOperationError: + cause = err.__cause__ + if cause is None: + cause = Exception(*err.args).with_traceback(err.__traceback__) + return temporalio.api.nexus.v1.UnsuccessfulOperationError( + operation_state=err.state.value, + failure=await self._exception_to_failure_proto(cause), + ) + + async def _handler_error_to_proto( + self, err: nexusrpc.handler.HandlerError + ) -> temporalio.api.nexus.v1.HandlerError: + return temporalio.api.nexus.v1.HandlerError( + error_type=err.type.value, + failure=await self._exception_to_failure_proto(err), + # TODO(nexus-prerelease): is there a reason to support retryable=None? + retry_behavior=( + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE + if err.retryable + else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE + ), + ) + + +@dataclass +class _DummyPayloadSerializer: + data_converter: temporalio.converter.DataConverter + payload: temporalio.api.common.v1.Payload + + async def serialize(self, value: Any) -> nexusrpc.handler.Content: + raise NotImplementedError( + "The serialize method of the Serializer is not used by handlers" + ) + + async def deserialize( + self, + content: nexusrpc.handler.Content, + as_type: Optional[Type[Any]] = None, + ) -> Any: + try: + [input] = await self.data_converter.decode( + [self.payload], + type_hints=[as_type] if as_type else None, + ) + except Exception as err: + raise nexusrpc.handler.HandlerError( + "Data converter failed to decode Nexus operation input", + type=nexusrpc.handler.HandlerErrorType.BAD_REQUEST, + cause=err, + retryable=False, + ) from err + return input + + +# TODO(nexus-prerelease): tests for this function +def _exception_to_handler_error(err: BaseException) -> nexusrpc.handler.HandlerError: + # Based on sdk-typescript's convertKnownErrors: + # https://github.com/temporalio/sdk-typescript/blob/nexus/packages/worker/src/nexus.ts + if isinstance(err, nexusrpc.handler.HandlerError): + return err + elif isinstance(err, ApplicationError): + return nexusrpc.handler.HandlerError( + # TODO(nexus-prerelease): what should message be? + err.message, + type=nexusrpc.handler.HandlerErrorType.INTERNAL, + cause=err, + # TODO(nexus-prerelease): is there a reason to support retryable=None? + retryable=not err.non_retryable, + ) + elif isinstance(err, RPCError): + if err.status == RPCStatusCode.INVALID_ARGUMENT: + return nexusrpc.handler.HandlerError( + err.message, + type=nexusrpc.handler.HandlerErrorType.BAD_REQUEST, + cause=err, + ) + elif err.status in [ + RPCStatusCode.ALREADY_EXISTS, + RPCStatusCode.FAILED_PRECONDITION, + RPCStatusCode.OUT_OF_RANGE, + ]: + return nexusrpc.handler.HandlerError( + err.message, + type=nexusrpc.handler.HandlerErrorType.INTERNAL, + cause=err, + retryable=False, + ) + elif err.status in [RPCStatusCode.ABORTED, RPCStatusCode.UNAVAILABLE]: + return nexusrpc.handler.HandlerError( + err.message, + type=nexusrpc.handler.HandlerErrorType.UNAVAILABLE, + cause=err, + ) + elif err.status in [ + RPCStatusCode.CANCELLED, + RPCStatusCode.DATA_LOSS, + RPCStatusCode.INTERNAL, + RPCStatusCode.UNKNOWN, + RPCStatusCode.UNAUTHENTICATED, + RPCStatusCode.PERMISSION_DENIED, + ]: + # Note that UNAUTHENTICATED and PERMISSION_DENIED have Nexus error types but + # we convert to internal because this is not a client auth error and happens + # when the handler fails to auth with Temporal and should be considered + # retryable. + return nexusrpc.handler.HandlerError( + err.message, type=nexusrpc.handler.HandlerErrorType.INTERNAL, cause=err + ) + elif err.status == RPCStatusCode.NOT_FOUND: + return nexusrpc.handler.HandlerError( + err.message, type=nexusrpc.handler.HandlerErrorType.NOT_FOUND, cause=err + ) + elif err.status == RPCStatusCode.RESOURCE_EXHAUSTED: + return nexusrpc.handler.HandlerError( + err.message, + type=nexusrpc.handler.HandlerErrorType.RESOURCE_EXHAUSTED, + cause=err, + ) + elif err.status == RPCStatusCode.UNIMPLEMENTED: + return nexusrpc.handler.HandlerError( + err.message, + type=nexusrpc.handler.HandlerErrorType.NOT_IMPLEMENTED, + cause=err, + ) + elif err.status == RPCStatusCode.DEADLINE_EXCEEDED: + return nexusrpc.handler.HandlerError( + err.message, + type=nexusrpc.handler.HandlerErrorType.UPSTREAM_TIMEOUT, + cause=err, + ) + else: + return nexusrpc.handler.HandlerError( + f"Unhandled RPC error status: {err.status}", + type=nexusrpc.handler.HandlerErrorType.INTERNAL, + cause=err, + ) + return nexusrpc.handler.HandlerError( + str(err), type=nexusrpc.handler.HandlerErrorType.INTERNAL, cause=err + ) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 4793f675e..66e1060f4 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -41,6 +41,7 @@ from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor +from ._nexus import _NexusWorker from ._tuning import WorkerTuner from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -106,9 +107,11 @@ def __init__( *, task_queue: str, activities: Sequence[Callable] = [], + nexus_services: Sequence[Any] = [], workflows: Sequence[Type] = [], activity_executor: Optional[concurrent.futures.Executor] = None, workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, + nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(), interceptors: Sequence[Interceptor] = [], @@ -153,10 +156,12 @@ def __init__( client's underlying service client. This client cannot be "lazy". task_queue: Required task queue for this worker. - activities: Set of activity callables decorated with + activities: Activity callables decorated with :py:func:`@activity.defn`. Activities may be async functions or non-async functions. - workflows: Set of workflow classes decorated with + nexus_services: Nexus service instances decorated with + :py:func:`@nexusrpc.handler.service_handler`. + workflows: Workflow classes decorated with :py:func:`@workflow.defn`. activity_executor: Concurrent executor to use for non-async activities. This is required if any activities are non-async. @@ -195,9 +200,11 @@ def __init__( tasks that will ever be given to this worker at one time. Mutually exclusive with ``tuner``. Must be set to at least two if ``max_cached_workflows`` is nonzero. max_concurrent_activities: Maximum number of activity tasks that - will ever be given to this worker concurrently. Mutually exclusive with ``tuner``. + will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``. max_concurrent_local_activities: Maximum number of local activity - tasks that will ever be given to this worker concurrently. Mutually exclusive with ``tuner``. + tasks that will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``. + max_concurrent_workflow_tasks: Maximum allowed number of + tasks that will ever be given to the workflow worker at one time. Mutually exclusive with ``tuner``. tuner: Provide a custom :py:class:`WorkerTuner`. Mutually exclusive with the ``max_concurrent_workflow_tasks``, ``max_concurrent_activities``, and ``max_concurrent_local_activities`` arguments. @@ -297,8 +304,22 @@ def __init__( activity_task_poller_behavior: Specify the behavior of activity task polling. Defaults to a 5-poller maximum. """ - if not activities and not workflows: - raise ValueError("At least one activity or workflow must be specified") + # TODO(nexus-prerelease): non-async (executor-based) Nexus worker; honor + # max_concurrent_nexus_operations and nexus_operation_executor. + # nexus_operation_executor: Concurrent executor to use for non-async + # Nexus operations. This is required if any operation start methods + # are non-async. :py:class:`concurrent.futures.ThreadPoolExecutor` + # is recommended. If this is a + # :py:class:`concurrent.futures.ProcessPoolExecutor`, all non-async + # start methods must be picklable. ``max_workers`` on the executor + # should at least be ``max_concurrent_nexus_operations`` or a warning + # is issued. + # max_concurrent_nexus_operations: Maximum number of Nexus operations that + # will ever be given to the Nexus worker concurrently. Mutually exclusive with ``tuner``. + if not (activities or nexus_services or workflows): + raise ValueError( + "At least one activity, Nexus service, or workflow must be specified" + ) if use_worker_versioning and not build_id: raise ValueError( "build_id must be specified when use_worker_versioning is True" @@ -327,6 +348,7 @@ def __init__( workflows=workflows, activity_executor=activity_executor, workflow_task_executor=workflow_task_executor, + nexus_task_executor=nexus_task_executor, workflow_runner=workflow_runner, unsandboxed_workflow_runner=unsandboxed_workflow_runner, interceptors=interceptors, @@ -361,7 +383,6 @@ def __init__( self._async_context_run_task: Optional[asyncio.Task] = None self._async_context_run_exception: Optional[BaseException] = None - # Create activity and workflow worker self._activity_worker: Optional[_ActivityWorker] = None self._runtime = ( bridge_client.config.runtime or temporalio.runtime.Runtime.default() @@ -393,6 +414,20 @@ def __init__( interceptors=interceptors, metric_meter=self._runtime.metric_meter, ) + self._nexus_worker: Optional[_NexusWorker] = None + if nexus_services: + # TODO(nexus-prerelease): consider not allowing / warning on max_workers < + # max_concurrent_nexus_operations? See warning above for activity worker. + self._nexus_worker = _NexusWorker( + bridge_worker=lambda: self._bridge_worker, + client=client, + task_queue=task_queue, + nexus_services=nexus_services, + data_converter=client_config["data_converter"], + interceptors=interceptors, + metric_meter=self._runtime.metric_meter, + executor=nexus_task_executor, + ) self._workflow_worker: Optional[_WorkflowWorker] = None if workflows: should_enforce_versioning_behavior = ( @@ -432,6 +467,7 @@ def check_activity(activity): ) if tuner is not None: + # TODO(nexus-prerelease): Nexus tuner support if ( max_concurrent_workflow_tasks or max_concurrent_activities @@ -617,21 +653,30 @@ async def raise_on_shutdown(): except asyncio.CancelledError: pass - tasks: List[asyncio.Task] = [asyncio.create_task(raise_on_shutdown())] + tasks: dict[ + Union[None, _ActivityWorker, _WorkflowWorker, _NexusWorker], asyncio.Task + ] = {None: asyncio.create_task(raise_on_shutdown())} # Create tasks for workers if self._activity_worker: - tasks.append(asyncio.create_task(self._activity_worker.run())) + tasks[self._activity_worker] = asyncio.create_task( + self._activity_worker.run() + ) if self._workflow_worker: - tasks.append(asyncio.create_task(self._workflow_worker.run())) + tasks[self._workflow_worker] = asyncio.create_task( + self._workflow_worker.run() + ) + if self._nexus_worker: + tasks[self._nexus_worker] = asyncio.create_task(self._nexus_worker.run()) # Wait for either worker or shutdown requested - wait_task = asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + wait_task = asyncio.wait(tasks.values(), return_when=asyncio.FIRST_EXCEPTION) try: await asyncio.shield(wait_task) - # If any of the last two tasks failed, we want to re-raise that as - # the exception - exception = next((t.exception() for t in tasks[1:] if t.done()), None) + # If any of the worker tasks failed, re-raise that as the exception + exception = next( + (t.exception() for w, t in tasks.items() if w and t.done()), None + ) if exception: logger.error("Worker failed, shutting down", exc_info=exception) if self._config["on_fatal_error"]: @@ -646,7 +691,7 @@ async def raise_on_shutdown(): exception = user_cancel_err # Cancel the shutdown task (safe if already done) - tasks[0].cancel() + tasks[None].cancel() graceful_timeout = self._config["graceful_shutdown_timeout"] logger.info( f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling activities" @@ -655,18 +700,10 @@ async def raise_on_shutdown(): # Initiate core worker shutdown self._bridge_worker.initiate_shutdown() - # If any worker task had an exception, replace that task with a queue - # drain (task at index 1 can be activity or workflow worker, task at - # index 2 must be workflow worker if present) - if tasks[1].done() and tasks[1].exception(): - if self._activity_worker: - tasks[1] = asyncio.create_task(self._activity_worker.drain_poll_queue()) - else: - assert self._workflow_worker - tasks[1] = asyncio.create_task(self._workflow_worker.drain_poll_queue()) - if len(tasks) > 2 and tasks[2].done() and tasks[2].exception(): - assert self._workflow_worker - tasks[2] = asyncio.create_task(self._workflow_worker.drain_poll_queue()) + # If any worker task had an exception, replace that task with a queue drain + for worker, task in tasks.items(): + if worker and task.done() and task.exception(): + tasks[worker] = asyncio.create_task(worker.drain_poll_queue()) # Notify shutdown occurring if self._activity_worker: @@ -675,20 +712,23 @@ async def raise_on_shutdown(): self._workflow_worker.notify_shutdown() # Wait for all tasks to complete (i.e. for poller loops to stop) - await asyncio.wait(tasks) + await asyncio.wait(tasks.values()) # Sometimes both workers throw an exception and since we only take the # first, Python may complain with "Task exception was never retrieved" # if we don't get the others. Therefore we call cancel on each task # which suppresses this. - for task in tasks: + for task in tasks.values(): task.cancel() - # If there's an activity worker, we have to let all activity completions - # finish. We cannot guarantee that because poll shutdown completed - # (which means activities completed) that they got flushed to the - # server. + # Let all activity / nexus operations completions finish. We cannot guarantee that + # because poll shutdown completed (which means activities/operations completed) + # that they got flushed to the server. if self._activity_worker: await self._activity_worker.wait_all_completed() + if self._nexus_worker: + await self._nexus_worker.wait_all_completed() + + # TODO(nexus-prerelease): check that we do all appropriate things for nexus worker that we do for activity worker # Do final shutdown try: @@ -770,6 +810,7 @@ class WorkerConfig(TypedDict, total=False): workflows: Sequence[Type] activity_executor: Optional[concurrent.futures.Executor] workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] + nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] workflow_runner: WorkflowRunner unsandboxed_workflow_runner: WorkflowRunner interceptors: Sequence[Interceptor] diff --git a/tests/conftest.py b/tests/conftest.py index 37b1fe89c..f3baa1b72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,6 +99,7 @@ def env_type(request: pytest.FixtureRequest) -> str: @pytest_asyncio.fixture(scope="session") async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: if env_type == "local": + http_port = 7243 env = await WorkflowEnvironment.start_local( dev_server_extra_args=[ "--dynamic-config-value", @@ -117,6 +118,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "system.enableDeploymentVersions=true", "--dynamic-config-value", "frontend.activityAPIsEnabled=true", + "--http-port", + str(http_port), ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) @@ -124,6 +127,9 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: env = await WorkflowEnvironment.start_time_skipping() else: env = WorkflowEnvironment.from_client(await Client.connect(env_type)) + + # TODO(nexus-prerelease): expose this in a principled way + env._http_port = http_port # type: ignore yield env await env.shutdown() diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py new file mode 100644 index 000000000..878111438 --- /dev/null +++ b/tests/helpers/nexus.py @@ -0,0 +1,37 @@ +import temporalio.api +import temporalio.api.common +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.nexus +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice +import temporalio.api.operatorservice.v1 +import temporalio.nexus +import temporalio.nexus.handler +from temporalio.client import Client + + +def make_nexus_endpoint_name(task_queue: str) -> str: + # Create endpoints for different task queues without name collisions. + return f"nexus-endpoint-{task_queue}" + + +# TODO(nexus-prerelease): How do we recommend that users create endpoints in their own tests? +# See https://github.com/temporalio/sdk-typescript/pull/1708/files?show-viewed-files=true&file-filters%5B%5D=&w=0#r2082549085 +async def create_nexus_endpoint( + task_queue: str, client: Client +) -> temporalio.api.operatorservice.v1.CreateNexusEndpointResponse: + name = make_nexus_endpoint_name(task_queue) + return await client.operator_service.create_nexus_endpoint( + temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( + spec=temporalio.api.nexus.v1.EndpointSpec( + name=name, + target=temporalio.api.nexus.v1.EndpointTarget( + worker=temporalio.api.nexus.v1.EndpointTarget.Worker( + namespace=client.namespace, + task_queue=task_queue, + ) + ), + ) + ) + ) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py new file mode 100644 index 000000000..da3925e80 --- /dev/null +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -0,0 +1,83 @@ +import uuid + +import httpx +import nexusrpc.handler +import pytest + +from temporalio.client import Client +from temporalio.worker import Worker +from tests.helpers.nexus import create_nexus_endpoint + +HTTP_PORT = 7243 + + +# TODO(nexus-prerelease): test programmatic creation from ServiceHandler +def make_incrementer_service_from_service_handler( + op_names: list[str], +) -> tuple[str, type]: + pass + + +def make_incrementer_user_service_definition_and_service_handler_classes( + op_names: list[str], +) -> tuple[type, type]: + # + # service contract + # + + ops = {name: nexusrpc.Operation[int, int] for name in op_names} + service_cls = nexusrpc.service(type("ServiceContract", (), ops)) + + # + # service handler + # + async def _increment_op( + self, + ctx: nexusrpc.handler.StartOperationContext, + input: int, + ) -> int: + return input + 1 + + op_handler_factories = { + # TODO(nexus-prerelease): check that name=name should be required here. Should the op factory + # name not default to the name of the method attribute (i.e. key), as opposed to + # the name of the method object (i.e. value.__name__)? + name: nexusrpc.handler.sync_operation_handler(_increment_op, name=name) + for name in op_names + } + + handler_cls = nexusrpc.handler.service_handler(service=service_cls)( + type("ServiceImpl", (), op_handler_factories) + ) + + return service_cls, handler_cls + + +@pytest.mark.skip( + reason="Dynamic creation of service contract using type() is not supported" +) +async def test_dynamic_creation_of_user_handler_classes(client: Client): + task_queue = str(uuid.uuid4()) + + service_cls, handler_cls = ( + make_incrementer_user_service_definition_and_service_handler_classes( + ["increment"] + ) + ) + + service_name = service_cls.__nexus_service__.name + + endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id + async with Worker( + client, + task_queue=task_queue, + nexus_services=[handler_cls()], + ): + async with httpx.AsyncClient() as http_client: + response = await http_client.post( + f"http://127.0.0.1:{HTTP_PORT}/nexus/endpoints/{endpoint}/services/{service_name}/increment", + json=1, + headers={}, + ) + assert response.status_code == 200 + assert response.json() == 2 diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py new file mode 100644 index 000000000..981df21e0 --- /dev/null +++ b/tests/nexus/test_handler.py @@ -0,0 +1,904 @@ +""" +See https://github.com/nexus-rpc/api/blob/main/SPEC.md + +This file contains test coverage for Nexus StartOperation and CancelOperation +operations issued by a caller directly via HTTP. + +The response to StartOperation may indicate a protocol-level failure (400 +BAD_REQUEST, 520 UPSTREAM_TIMEOUT, etc). In this case the body is a valid +Failure object. + + +(https://github.com/nexus-rpc/api/blob/main/SPEC.md#predefined-handler-errors) + +""" + +import asyncio +import concurrent.futures +import dataclasses +import json +import logging +import uuid +from concurrent.futures.thread import ThreadPoolExecutor +from dataclasses import dataclass +from types import MappingProxyType +from typing import Any, Callable, Mapping, Optional, Type, Union + +import httpx +import nexusrpc +import nexusrpc.handler +import pytest +from google.protobuf import json_format +from nexusrpc.testing.client import ServiceClient + +import temporalio.api.failure.v1 +import temporalio.nexus +from temporalio import workflow +from temporalio.client import Client, WorkflowHandle +from temporalio.converter import FailureConverter, PayloadConverter +from temporalio.exceptions import ApplicationError +from temporalio.nexus import logger +from temporalio.nexus.handler import start_workflow +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import create_nexus_endpoint + +HTTP_PORT = 7243 + + +@dataclass +class Input: + value: str + + +@dataclass +class Output: + value: str + + +# TODO: type check nexus implementation under mypy + +# TODO(nexus-prerelease): test dynamic creation of a service from unsugared definition +# TODO(nexus-prerelease): test malformed inbound_links and outbound_links + +# TODO(nexus-prerelease): test good error message on forgetting to add decorators etc + + +@nexusrpc.service +class MyService: + echo: nexusrpc.Operation[Input, Output] + # TODO(nexus-prerelease): support renamed operations! + # echo_renamed: nexusrpc.Operation[Input, Output] = ( + # nexusrpc.Operation(name="echo-renamed") + # ) + hang: nexusrpc.Operation[Input, Output] + log: nexusrpc.Operation[Input, Output] + async_operation: nexusrpc.Operation[Input, Output] + async_operation_without_type_annotations: nexusrpc.Operation[Input, Output] + sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] + sync_operation_with_non_async_def: nexusrpc.Operation[Input, Output] + sync_operation_with_non_async_callable_instance: nexusrpc.Operation[Input, Output] + operation_returning_unwrapped_result_at_runtime_error: nexusrpc.Operation[ + Input, Output + ] + non_retryable_application_error: nexusrpc.Operation[Input, Output] + retryable_application_error: nexusrpc.Operation[Input, Output] + check_operation_timeout_header: nexusrpc.Operation[Input, Output] + workflow_run_op_link_test: nexusrpc.Operation[Input, Output] + handler_error_internal: nexusrpc.Operation[Input, Output] + operation_error_failed: nexusrpc.Operation[Input, Output] + + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input: Input) -> Output: + return Output(value=f"from workflow: {input.value}") + + +@workflow.defn +class WorkflowWithoutTypeAnnotations: + @workflow.run + async def run(self, input): # type: ignore + return Output(value=f"from workflow without type annotations: {input}") + + +@workflow.defn +class MyLinkTestWorkflow: + @workflow.run + async def run(self, input: Input) -> Output: + return Output(value=f"from link test workflow: {input.value}") + + +# TODO: implement some of these ops as explicit OperationHandler classes to provide coverage for that? + + +# The service_handler decorator is applied by the test +class MyServiceHandler: + @nexusrpc.handler.sync_operation_handler + async def echo( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + @nexusrpc.handler.sync_operation_handler + async def hang( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + await asyncio.Future() + return Output(value="won't reach here") + + @nexusrpc.handler.sync_operation_handler + async def non_retryable_application_error( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + raise ApplicationError( + "non-retryable application error", + "details arg", + # TODO(nexus-prerelease): what values of `type` should be tested? + type="TestFailureType", + non_retryable=True, + ) + + @nexusrpc.handler.sync_operation_handler + async def retryable_application_error( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + raise ApplicationError( + "retryable application error", + "details arg", + type="TestFailureType", + non_retryable=False, + ) + + @nexusrpc.handler.sync_operation_handler + async def handler_error_internal( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + raise nexusrpc.handler.HandlerError( + message="deliberate internal handler error", + type=nexusrpc.handler.HandlerErrorType.INTERNAL, + retryable=False, + cause=RuntimeError("cause message"), + ) + + @nexusrpc.handler.sync_operation_handler + async def operation_error_failed( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + raise nexusrpc.handler.OperationError( + message="deliberate operation error", + state=nexusrpc.handler.OperationErrorState.FAILED, + ) + + @nexusrpc.handler.sync_operation_handler + async def check_operation_timeout_header( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + assert "operation-timeout" in ctx.headers + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + @nexusrpc.handler.sync_operation_handler + async def log( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + logger.info("Logging from start method", extra={"input_value": input.value}) + return Output(value=f"logged: {input.value}") + + @temporalio.nexus.handler.workflow_run_operation_handler + async def async_operation( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowHandle[Any, Output]: + assert "operation-timeout" in ctx.headers + return await start_workflow( + ctx, + MyWorkflow.run, + input, + id=str(uuid.uuid4()), + ) + + @nexusrpc.handler.sync_operation_handler + def sync_operation_with_non_async_def( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + class sync_operation_with_non_async_callable_instance: + def __call__( + self, + _handler: "MyServiceHandler", + ctx: nexusrpc.handler.StartOperationContext, + input: Input, + ) -> Output: + return Output( + value=f"from start method on {_handler.__class__.__name__}: {input.value}" + ) + + _sync_operation_with_non_async_callable_instance = ( + nexusrpc.handler.sync_operation_handler( + name="sync_operation_with_non_async_callable_instance", + )( + sync_operation_with_non_async_callable_instance(), + ) + ) + + @nexusrpc.handler.sync_operation_handler + async def sync_operation_without_type_annotations(self, ctx, input): + # The input type from the op definition in the service definition is used to deserialize the input. + return Output( + value=f"from start method on {self.__class__.__name__} without type annotations: {input}" + ) + + @temporalio.nexus.handler.workflow_run_operation_handler + async def async_operation_without_type_annotations(self, ctx, input): + return await start_workflow( + ctx, + WorkflowWithoutTypeAnnotations.run, + input, + id=str(uuid.uuid4()), + ) + + @temporalio.nexus.handler.workflow_run_operation_handler + async def workflow_run_op_link_test( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowHandle[Any, Output]: + assert any( + link.url == "http://inbound-link/" for link in ctx.inbound_links + ), "Inbound link not found" + assert ctx.request_id == "test-request-id-123", "Request ID mismatch" + ctx.outbound_links.extend(ctx.inbound_links) + return await start_workflow( + ctx, + MyLinkTestWorkflow.run, + input, + id=f"link-test-{uuid.uuid4()}", + ) + + class OperationHandlerReturningUnwrappedResult( + nexusrpc.handler.SyncOperationHandler[Input, Output] + ): + async def start( + self, + ctx: nexusrpc.handler.StartOperationContext, + input: Input, + # This return type is a type error, but VSCode doesn't flag it unless + # "python.analysis.typeCheckingMode" is set to "strict" + ) -> Output: + # Invalid: start method must wrap result as StartOperationResultSync + # or StartOperationResultAsync + return Output(value="unwrapped result error") # type: ignore + + @nexusrpc.handler.operation_handler + def operation_returning_unwrapped_result_at_runtime_error( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + return MyServiceHandler.OperationHandlerReturningUnwrappedResult() + + +@dataclass +class Failure: + message: str = "" + metadata: Optional[dict[str, str]] = None + details: Optional[dict[str, Any]] = None + + exception: Optional[BaseException] = dataclasses.field(init=False, default=None) + + def __post_init__(self) -> None: + if self.metadata and (error_type := self.metadata.get("type")): + self.exception = self._instantiate_exception(error_type, self.details) + + def _instantiate_exception( + self, error_type: str, details: Optional[dict[str, Any]] + ) -> BaseException: + proto = { + "temporal.api.failure.v1.Failure": temporalio.api.failure.v1.Failure, + }[error_type]() + json_format.ParseDict(self.details, proto, ignore_unknown_fields=True) + return FailureConverter.default.from_failure(proto, PayloadConverter.default) + + +# Immutable dicts that can be used as dataclass field defaults + +SUCCESSFUL_RESPONSE_HEADERS = MappingProxyType( + { + "content-type": "application/json", + } +) + +UNSUCCESSFUL_RESPONSE_HEADERS = MappingProxyType( + { + "content-type": "application/json", + "temporal-nexus-failure-source": "worker", + } +) + + +@dataclass +class SuccessfulResponse: + status_code: int + body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None + headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS + + +@dataclass +class UnsuccessfulResponse: + status_code: int + # Expected value of Nexus-Request-Retryable header + retryable_header: Optional[bool] + failure_message: Union[str, Callable[[str], bool]] + # Expected value of inverse of non_retryable attribute of exception. + retryable_exception: bool = True + # TODO(nexus-prerelease): the body of a successful response need not be JSON; test non-JSON-parseable string + body_json: Optional[Callable[[dict[str, Any]], bool]] = None + headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS + + +class _TestCase: + operation: str + service_defn: str = "MyService" + input: Input = Input("") + headers: dict[str, str] = {} + expected: SuccessfulResponse + expected_without_service_definition: Optional[SuccessfulResponse] = None + skip = "" + + @classmethod + def check_response( + cls, + response: httpx.Response, + with_service_definition: bool, + ) -> None: + assert response.status_code == cls.expected.status_code, ( + f"expected status code {cls.expected.status_code} " + f"but got {response.status_code} for response content {response.content.decode()}" + ) + if not with_service_definition and cls.expected_without_service_definition: + expected = cls.expected_without_service_definition + else: + expected = cls.expected + if expected.body_json is not None: + body = response.json() + assert isinstance(body, dict) + if isinstance(expected.body_json, dict): + assert body == expected.body_json + else: + assert expected.body_json(body) + assert response.headers.items() >= cls.expected.headers.items() + + +class _FailureTestCase(_TestCase): + expected: UnsuccessfulResponse + + @classmethod + def check_response( + cls, response: httpx.Response, with_service_definition: bool + ) -> None: + super().check_response(response, with_service_definition) + failure = Failure(**response.json()) + + if isinstance(cls.expected.failure_message, str): + assert failure.message == cls.expected.failure_message + else: + assert cls.expected.failure_message(failure.message) + + # retryability assertions + if ( + retryable_header := response.headers.get("nexus-request-retryable") + ) is not None: + assert json.loads(retryable_header) == cls.expected.retryable_header + else: + assert cls.expected.retryable_header is None + + if failure.exception: + assert isinstance(failure.exception, ApplicationError) + assert failure.exception.non_retryable == ( + not cls.expected.retryable_exception + ) + else: + print(f"TODO(dan): {cls} did not yield a Failure with exception details") + + +class SyncHandlerHappyPath(_TestCase): + operation = "echo" + input = Input("hello") + # TODO(nexus-prerelease): why is application/json randomly scattered around these tests? + headers = { + "Content-Type": "application/json", + "Test-Header-Key": "test-header-value", + "Nexus-Link": '; type="test"', + } + expected = SuccessfulResponse( + status_code=200, + body_json={"value": "from start method on MyServiceHandler: hello"}, + ) + # TODO(nexus-prerelease): headers should be lower-cased + assert ( + headers.get("Nexus-Link") == '; type="test"' + ), "Nexus-Link header not echoed correctly." + + +class SyncHandlerHappyPathRenamed(_TestCase): + operation = "echo-renamed" + input = Input("hello") + expected = SuccessfulResponse( + status_code=200, + body_json={"value": "from start method on MyServiceHandler: hello"}, + ) + + +class SyncHandlerHappyPathNonAsyncDef(_TestCase): + operation = "sync_operation_with_non_async_def" + input = Input("hello") + expected = SuccessfulResponse( + status_code=200, + body_json={"value": "from start method on MyServiceHandler: hello"}, + ) + + +class SyncHandlerHappyPathWithNonAsyncCallableInstance(_TestCase): + operation = "sync_operation_with_non_async_callable_instance" + input = Input("hello") + expected = SuccessfulResponse( + status_code=200, + body_json={"value": "from start method on MyServiceHandler: hello"}, + ) + + +class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): + operation = "sync_operation_without_type_annotations" + input = Input("hello") + expected = SuccessfulResponse( + status_code=200, + body_json={ + "value": "from start method on MyServiceHandler without type annotations: Input(value='hello')" + }, + ) + expected_without_service_definition = SuccessfulResponse( + status_code=200, + body_json={ + "value": "from start method on MyServiceHandler without type annotations: {'value': 'hello'}" + }, + ) + + +class AsyncHandlerHappyPath(_TestCase): + operation = "async_operation" + input = Input("hello") + headers = {"Operation-Timeout": "777s"} + expected = SuccessfulResponse( + status_code=201, + ) + + +class AsyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): + operation = "async_operation_without_type_annotations" + input = Input("hello") + expected = SuccessfulResponse( + status_code=201, + ) + + +class WorkflowRunOpLinkTestHappyPath(_TestCase): + # TODO(nexus-prerelease): fix this test + skip = "Yields invalid link" + operation = "workflow_run_op_link_test" + input = Input("link-test-input") + headers = { + "Nexus-Link": '; type="test"', + "Nexus-Request-Id": "test-request-id-123", + } + expected = SuccessfulResponse( + status_code=201, + ) + + @classmethod + def check_response( + cls, response: httpx.Response, with_service_definition: bool + ) -> None: + super().check_response(response, with_service_definition) + nexus_link = response.headers.get("nexus-link") + assert nexus_link is not None, "nexus-link header not found in response" + assert nexus_link.startswith( + " None: + super().check_response(response, with_service_definition) + failure = Failure(**response.json()) + err = failure.exception + assert isinstance(err, ApplicationError) + assert err.non_retryable + assert err.type == "TestFailureType" + assert err.details == ("details arg",) + + +class RetryableApplicationError(_FailureTestCase): + operation = "retryable_application_error" + expected = UnsuccessfulResponse( + status_code=500, + retryable_header=True, + failure_message="retryable application error", + ) + + +class HandlerErrorInternal(_FailureTestCase): + operation = "handler_error_internal" + expected = UnsuccessfulResponse( + status_code=500, + # TODO(nexus-prerelease): check this assertion + retryable_header=False, + failure_message="cause message", + ) + + +class OperationError(_FailureTestCase): + operation = "operation_error_failed" + expected = UnsuccessfulResponse( + status_code=424, + # TODO(nexus-prerelease): check that OperationError should not set retryable header + retryable_header=None, + failure_message="deliberate operation error", + headers=UNSUCCESSFUL_RESPONSE_HEADERS | {"nexus-operation-state": "failed"}, + ) + + +class UnknownService(_FailureTestCase): + service_defn = "NonExistentService" + operation = "" + expected = UnsuccessfulResponse( + status_code=404, + retryable_header=False, + failure_message="No handler for service 'NonExistentService'.", + ) + + +class UnknownOperation(_FailureTestCase): + operation = "NonExistentOperation" + expected = UnsuccessfulResponse( + status_code=404, + retryable_header=False, + failure_message=lambda s: s.startswith( + "Nexus service definition 'MyService' has no operation 'NonExistentOperation'." + ), + ) + + +@pytest.mark.parametrize( + "test_case", + [ + SyncHandlerHappyPath, + # TODO(nexus-prerelease): support renamed operations! + # SyncHandlerHappyPathRenamed, + SyncHandlerHappyPathNonAsyncDef, + # TODO(nexus-prerelease): make callable instance work + # SyncHandlerHappyPathWithNonAsyncCallableInstance, + SyncHandlerHappyPathWithoutTypeAnnotations, + AsyncHandlerHappyPath, + AsyncHandlerHappyPathWithoutTypeAnnotations, + WorkflowRunOpLinkTestHappyPath, + ], +) +@pytest.mark.parametrize("with_service_definition", [True, False]) +async def test_start_operation_happy_path( + test_case: Type[_TestCase], + with_service_definition: bool, + env: WorkflowEnvironment, +): + await _test_start_operation(test_case, with_service_definition, env) + + +@pytest.mark.parametrize( + "test_case", + [ + OperationHandlerReturningUnwrappedResultError, + UpstreamTimeoutViaRequestTimeout, + OperationTimeoutHeader, + BadRequest, + HandlerErrorInternal, + UnknownService, + UnknownOperation, + ], +) +async def test_start_operation_protocol_level_failures( + test_case: Type[_TestCase], env: WorkflowEnvironment +): + await _test_start_operation(test_case, True, env) + + +@pytest.mark.parametrize( + "test_case", + [ + NonRetryableApplicationError, + RetryableApplicationError, + OperationError, + ], +) +async def test_start_operation_operation_failures( + test_case: Type[_TestCase], env: WorkflowEnvironment +): + await _test_start_operation(test_case, True, env) + + +async def _test_start_operation( + test_case: Type[_TestCase], + with_service_definition: bool, + env: WorkflowEnvironment, +): + if test_case.skip: + pytest.skip(test_case.skip) + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=f"http://127.0.0.1:{env._http_port}", # type: ignore + endpoint=endpoint, + service=( + test_case.service_defn + if with_service_definition + else MyServiceHandler.__name__ + ), + ) + + decorator = ( + nexusrpc.handler.service_handler(service=MyService) + if with_service_definition + else nexusrpc.handler.service_handler + ) + service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_services=[service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + test_case.operation, + dataclass_as_dict(test_case.input), + test_case.headers, + ) + test_case.check_response(response, with_service_definition) + + +async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: Any): + task_queue = str(uuid.uuid4()) + service_name = MyService.__name__ + operation_name = "log" + resp = await create_nexus_endpoint(task_queue, env.client) + endpoint = resp.endpoint.id + service_client = ServiceClient( + server_address=f"http://127.0.0.1:{env._http_port}", # type: ignore + endpoint=endpoint, + service=service_name, + ) + caplog.set_level(logging.INFO) + + async with Worker( + env.client, + task_queue=task_queue, + nexus_services=[MyServiceHandler()], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + operation_name, + dataclass_as_dict(Input("test_log")), + { + "Content-Type": "application/json", + "Test-Log-Header": "test-log-header-value", + }, + ) + assert response.is_success + response.raise_for_status() + output_json = response.json() + assert output_json == {"value": "logged: test_log"} + + record = next( + ( + record + for record in caplog.records + if record.name == "temporalio.nexus" + and record.getMessage() == "Logging from start method" + ), + None, + ) + assert record is not None, "Expected log message not found" + assert record.levelname == "INFO" + assert getattr(record, "input_value", None) == "test_log" + assert getattr(record, "service", None) == service_name + assert getattr(record, "operation", None) == operation_name + + +def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: + """ + Return a shallow dict of the dataclass's fields. + + dataclasses.as_dict goes too far (attempts to pickle values) + """ + return { + field.name: getattr(dataclass, field.name) + for field in dataclasses.fields(dataclass) + } + + +class _InstantiationCase: + executor: bool + handler: Callable[[], Any] + exception: Optional[Type[Exception]] + match: Optional[str] + + +@nexusrpc.service +class EchoService: + echo: nexusrpc.Operation[Input, Output] + + +@nexusrpc.handler.service_handler(service=EchoService) +class SyncStartHandler: + @nexusrpc.handler.sync_operation_handler + def echo(self, ctx: nexusrpc.handler.StartOperationContext, input: Input) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + +@nexusrpc.handler.service_handler(service=EchoService) +class DefaultCancelHandler: + @nexusrpc.handler.sync_operation_handler + async def echo( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + +@nexusrpc.handler.service_handler(service=EchoService) +class SyncCancelHandler: + class SyncCancel(nexusrpc.handler.SyncOperationHandler[Input, Output]): + async def start( + self, + ctx: nexusrpc.handler.StartOperationContext, + input: Input, + # This return type is a type error, but VSCode doesn't flag it unless + # "python.analysis.typeCheckingMode" is set to "strict" + ) -> Output: + # Invalid: start method must wrap result as StartOperationResultSync + # or StartOperationResultAsync + return Output(value="Hello") # type: ignore + + def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> Output: + return Output(value="Hello") # type: ignore + + @nexusrpc.handler.operation_handler + def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + return SyncCancelHandler.SyncCancel() + + +class SyncHandlerNoExecutor(_InstantiationCase): + handler = SyncStartHandler + executor = False + exception = RuntimeError + match = "start must be an `async def`" + + +class DefaultCancel(_InstantiationCase): + handler = DefaultCancelHandler + executor = False + exception = None + + +class SyncCancel(_InstantiationCase): + handler = SyncCancelHandler + executor = False + exception = RuntimeError + match = "cancel must be an `async def`" + + +@pytest.mark.parametrize( + "test_case", + [SyncHandlerNoExecutor, DefaultCancel, SyncCancel], +) +async def test_handler_instantiation( + test_case: Type[_InstantiationCase], client: Client +): + task_queue = str(uuid.uuid4()) + + if test_case.exception is not None: + with pytest.raises(test_case.exception, match=test_case.match): + Worker( + client, + task_queue=task_queue, + nexus_services=[test_case.handler()], + nexus_task_executor=ThreadPoolExecutor() + if test_case.executor + else None, + ) + else: + Worker( + client, + task_queue=task_queue, + nexus_services=[test_case.handler()], + nexus_task_executor=ThreadPoolExecutor() if test_case.executor else None, + ) diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py new file mode 100644 index 000000000..dc7fc0dec --- /dev/null +++ b/tests/nexus/test_handler_async_operation.py @@ -0,0 +1,260 @@ +""" +Test that the Nexus SDK can be used to define an operation that responds asynchronously. +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import dataclasses +import uuid +from collections.abc import Coroutine +from dataclasses import dataclass, field +from typing import Any, Type, Union + +import nexusrpc +import nexusrpc.handler +import pytest +from nexusrpc.handler import ( + CancelOperationContext, + FetchOperationInfoContext, + FetchOperationResultContext, + OperationHandler, + OperationInfo, + StartOperationContext, + StartOperationResultAsync, +) +from nexusrpc.testing.client import ServiceClient + +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import create_nexus_endpoint + + +@dataclass +class Input: + value: str + + +@dataclass +class Output: + value: str + + +@dataclass +class AsyncOperationWithAsyncDefs(OperationHandler[Input, Output]): + executor: TaskExecutor + + async def start( + self, ctx: StartOperationContext, input: Input + ) -> StartOperationResultAsync: + async def task() -> Output: + await asyncio.sleep(0.1) + return Output("Hello from async operation!") + + task_id = str(uuid.uuid4()) + await self.executor.add_task(task_id, task()) + return StartOperationResultAsync(token=task_id) + + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: + # status = self.executor.get_task_status(task_id=token) + # return OperationInfo(token=token, status=status) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" + ) + + async def fetch_result( + self, ctx: FetchOperationResultContext, token: str + ) -> Output: + # return await self.executor.get_task_result(task_id=token) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" + ) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + self.executor.request_cancel_task(task_id=token) + + +@dataclass +class AsyncOperationWithNonAsyncDefs(OperationHandler[Input, Output]): + executor: TaskExecutor + + def start( + self, ctx: StartOperationContext, input: Input + ) -> StartOperationResultAsync: + async def task() -> Output: + await asyncio.sleep(0.1) + return Output("Hello from async operation!") + + task_id = str(uuid.uuid4()) + self.executor.add_task_sync(task_id, task()) + return StartOperationResultAsync(token=task_id) + + def fetch_info(self, ctx: FetchOperationInfoContext, token: str) -> OperationInfo: + # status = self.executor.get_task_status(task_id=token) + # return OperationInfo(token=token, status=status) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" + ) + + def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: + # return self.executor.get_task_result_sync(task_id=token) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" + ) + + def cancel(self, ctx: CancelOperationContext, token: str) -> None: + self.executor.request_cancel_task(task_id=token) + + +@dataclass +@nexusrpc.handler.service_handler +class MyServiceHandlerWithAsyncDefs: + executor: TaskExecutor + + @nexusrpc.handler.operation_handler + def async_operation(self) -> OperationHandler[Input, Output]: + return AsyncOperationWithAsyncDefs(self.executor) + + +@dataclass +@nexusrpc.handler.service_handler +class MyServiceHandlerWithNonAsyncDefs: + executor: TaskExecutor + + @nexusrpc.handler.operation_handler + def async_operation(self) -> OperationHandler[Input, Output]: + return AsyncOperationWithNonAsyncDefs(self.executor) + + +@pytest.mark.parametrize( + "service_handler_cls", + [ + MyServiceHandlerWithAsyncDefs, + MyServiceHandlerWithNonAsyncDefs, + ], +) +async def test_async_operation_lifecycle( + env: WorkflowEnvironment, + service_handler_cls: Union[ + Type[MyServiceHandlerWithAsyncDefs], + Type[MyServiceHandlerWithNonAsyncDefs], + ], +): + task_executor = await TaskExecutor.connect() + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + f"http://127.0.0.1:{env._http_port}", # type: ignore + endpoint, + service_handler_cls.__name__, + ) + + async with Worker( + env.client, + task_queue=task_queue, + nexus_services=[service_handler_cls(task_executor)], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + start_response = await service_client.start_operation( + "async_operation", + body=dataclass_as_dict(Input(value="Hello from test")), + ) + assert start_response.status_code == 201 + assert start_response.json()["token"] + assert start_response.json()["state"] == "running" + + # Cancel it + cancel_response = await service_client.cancel_operation( + "async_operation", + token=start_response.json()["token"], + ) + assert cancel_response.status_code == 202 + + # get_info and get_result not implemented by server + + +@dataclass +class TaskExecutor: + """ + This class represents the task execution platform being used by the team operating the + Nexus operation. + """ + + event_loop: asyncio.AbstractEventLoop + tasks: dict[str, asyncio.Task[Any]] = field(default_factory=dict) + + @classmethod + async def connect(cls) -> TaskExecutor: + return cls(event_loop=asyncio.get_running_loop()) + + async def add_task(self, task_id: str, coro: Coroutine[Any, Any, Any]) -> None: + """ + Add a task to the task execution platform. + """ + if task_id in self.tasks: + raise RuntimeError(f"Task with id {task_id} already exists") + + # This function is async def because in reality this step will often write to + # durable storage. + self.tasks[task_id] = asyncio.create_task(coro) + + def add_task_sync(self, task_id: str, coro: Coroutine[Any, Any, Any]) -> None: + """ + Add a task to the task execution platform from a sync context. + """ + asyncio.run_coroutine_threadsafe( + self.add_task(task_id, coro), self.event_loop + ).result() + + def get_task_status(self, task_id: str) -> nexusrpc.handler.OperationState: + task = self.tasks[task_id] + if not task.done(): + return nexusrpc.handler.OperationState.RUNNING + elif task.cancelled(): + return nexusrpc.handler.OperationState.CANCELED + elif task.exception(): + return nexusrpc.handler.OperationState.FAILED + else: + return nexusrpc.handler.OperationState.SUCCEEDED + + async def get_task_result(self, task_id: str) -> Any: + """ + Get the result of a task from the task execution platform. + """ + task = self.tasks.get(task_id) + if not task: + raise RuntimeError(f"Task not found with id {task_id}") + return await task + + def get_task_result_sync(self, task_id: str) -> Any: + """ + Get the result of a task from the task execution platform from a sync context. + """ + return asyncio.run_coroutine_threadsafe( + self.get_task_result(task_id), self.event_loop + ).result() + + def request_cancel_task(self, task_id: str) -> None: + """ + Request cancellation of a task on the task execution platform. + """ + task = self.tasks.get(task_id) + if not task: + raise RuntimeError(f"Task not found with id {task_id}") + task.cancel() + # Not implemented: cancellation confirmation, deletion on cancellation + + +def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: + """ + Return a shallow dict of the dataclass's fields. + + dataclasses.as_dict goes too far (attempts to pickle values) + """ + return { + field.name: getattr(dataclass, field.name) + for field in dataclasses.fields(dataclass) + } diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py new file mode 100644 index 000000000..f688ca791 --- /dev/null +++ b/tests/nexus/test_handler_interface_implementation.py @@ -0,0 +1,64 @@ +from typing import Any, Optional, Type + +import nexusrpc +import nexusrpc.handler +import pytest + +import temporalio.api.failure.v1 +import temporalio.nexus +from temporalio.client import WorkflowHandle + +HTTP_PORT = 7243 + + +class _InterfaceImplementationTestCase: + Interface: Type[Any] + Impl: Type[Any] + error_message: Optional[str] + + +class ValidImpl(_InterfaceImplementationTestCase): + @nexusrpc.service + class Interface: + op: nexusrpc.Operation[None, None] + + class Impl: + @nexusrpc.handler.sync_operation_handler + async def op( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> None: ... + + error_message = None + + +class ValidWorkflowRunImpl(_InterfaceImplementationTestCase): + @nexusrpc.service + class Interface: + op: nexusrpc.Operation[str, int] + + class Impl: + @temporalio.nexus.handler.workflow_run_operation_handler + async def op( + self, ctx: nexusrpc.handler.StartOperationContext, input: str + ) -> WorkflowHandle[Any, int]: ... + + error_message = None + + +@pytest.mark.parametrize( + "test_case", + [ + ValidImpl, + ValidWorkflowRunImpl, + ], +) +def test_service_decorator_enforces_interface_conformance( + test_case: Type[_InterfaceImplementationTestCase], +): + if test_case.error_message: + with pytest.raises(Exception) as ei: + nexusrpc.handler.service_handler(test_case.Interface)(test_case.Impl) + err = ei.value + assert test_case.error_message in str(err) + else: + nexusrpc.handler.service_handler(service=test_case.Interface)(test_case.Impl) diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py new file mode 100644 index 000000000..fce864f20 --- /dev/null +++ b/tests/nexus/test_handler_operation_definitions.py @@ -0,0 +1,100 @@ +""" +Test that workflow_run_operation_handler decorator results in operation definitions with the correct name +and input/output types. +""" + +from dataclasses import dataclass +from typing import Any, Type + +import nexusrpc.handler +import pytest + +import temporalio.nexus.handler +from temporalio.client import WorkflowHandle + + +@dataclass +class Input: + pass + + +@dataclass +class Output: + pass + + +@dataclass +class _TestCase: + Service: Type[Any] + expected_operations: dict[str, nexusrpc.Operation] + + +class NotCalled(_TestCase): + @nexusrpc.handler.service_handler + class Service: + @temporalio.nexus.handler.workflow_run_operation_handler + async def workflow_run_operation_handler( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowHandle[Any, Output]: ... + + expected_operations = { + "workflow_run_operation_handler": nexusrpc.Operation._create( + method_name="workflow_run_operation_handler", + input_type=Input, + output_type=Output, + ), + } + + +class CalledWithoutArgs(_TestCase): + @nexusrpc.handler.service_handler + class Service: + @temporalio.nexus.handler.workflow_run_operation_handler() + async def workflow_run_operation_handler( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowHandle[Any, Output]: ... + + expected_operations = NotCalled.expected_operations + + +class CalledWithNameOverride(_TestCase): + @nexusrpc.handler.service_handler + class Service: + @temporalio.nexus.handler.workflow_run_operation_handler(name="operation-name") + async def workflow_run_operation_with_name_override( + self, ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowHandle[Any, Output]: ... + + expected_operations = { + "workflow_run_operation_with_name_override": nexusrpc.Operation._create( + name="operation-name", + method_name="workflow_run_operation_with_name_override", + input_type=Input, + output_type=Output, + ), + } + + +@pytest.mark.parametrize( + "test_case", + [ + NotCalled, + CalledWithoutArgs, + CalledWithNameOverride, + ], +) +@pytest.mark.asyncio +async def test_collected_operation_names( + test_case: Type[_TestCase], +): + service: nexusrpc.ServiceDefinition = getattr( + test_case.Service, "__nexus_service__" + ) + assert isinstance(service, nexusrpc.ServiceDefinition) + assert service.name == "Service" + for method_name, expected_op in test_case.expected_operations.items(): + actual_op = getattr(test_case.Service, method_name).__nexus_operation__ + assert isinstance(actual_op, nexusrpc.Operation) + assert actual_op.name == expected_op.name + assert actual_op.input_type == expected_op.input_type + assert actual_op.output_type == expected_op.output_type From e97acf123f3dae703d97691c0c5e70438d39c565 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 10 Jun 2025 17:00:56 -0400 Subject: [PATCH 002/183] Nexus workflow caller --- temporalio/worker/_workflow_instance.py | 250 +++++- temporalio/workflow.py | 174 +++- tests/nexus/test_workflow_caller.py | 1047 +++++++++++++++++++++++ 3 files changed, 1467 insertions(+), 4 deletions(-) create mode 100644 tests/nexus/test_workflow_caller.py diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 528b42197..cc398cb14 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -43,6 +43,7 @@ cast, ) +import nexusrpc.handler from typing_extensions import Self, TypeAlias, TypedDict import temporalio.activity @@ -72,6 +73,7 @@ StartActivityInput, StartChildWorkflowInput, StartLocalActivityInput, + StartNexusOperationInput, WorkflowInboundInterceptor, WorkflowOutboundInterceptor, ) @@ -228,6 +230,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._pending_timers: Dict[int, _TimerHandle] = {} self._pending_activities: Dict[int, _ActivityHandle] = {} self._pending_child_workflows: Dict[int, _ChildWorkflowHandle] = {} + self._pending_nexus_operations: Dict[int, _NexusOperationHandle] = {} self._pending_external_signals: Dict[int, asyncio.Future] = {} self._pending_external_cancels: Dict[int, asyncio.Future] = {} # Keyed by type @@ -507,6 +510,10 @@ def _apply( self._apply_resolve_child_workflow_execution_start( job.resolve_child_workflow_execution_start ) + elif job.HasField("resolve_nexus_operation_start"): + self._apply_resolve_nexus_operation_start(job.resolve_nexus_operation_start) + elif job.HasField("resolve_nexus_operation"): + self._apply_resolve_nexus_operation(job.resolve_nexus_operation) elif job.HasField("resolve_request_cancel_external_workflow"): self._apply_resolve_request_cancel_external_workflow( job.resolve_request_cancel_external_workflow @@ -770,7 +777,6 @@ def _apply_resolve_child_workflow_execution( self, job: temporalio.bridge.proto.workflow_activation.ResolveChildWorkflowExecution, ) -> None: - # No matter the result, we know we want to pop handle = self._pending_child_workflows.pop(job.seq, None) if not handle: raise RuntimeError( @@ -839,6 +845,74 @@ def _apply_resolve_child_workflow_execution_start( else: raise RuntimeError("Child workflow start did not have a known status") + def _apply_resolve_nexus_operation_start( + self, + job: temporalio.bridge.proto.workflow_activation.ResolveNexusOperationStart, + ) -> None: + handle = self._pending_nexus_operations.get(job.seq) + if not handle: + raise RuntimeError( + f"Failed to find nexus operation handle for job sequence number {job.seq}" + ) + # TODO(dan): change core protos to use operation_token instead of operation_id + if job.HasField("operation_id"): + # The Nexus operation started asynchronously. A `ResolveNexusOperation` job + # will follow in a future activation. + handle._resolve_start_success(job.operation_id) + elif job.HasField("started_sync"): + # The Nexus operation 'started' in the sense that it's already resolved. A + # `ResolveNexusOperation` job will be in the same activation. + handle._resolve_start_success(None) + elif job.HasField("cancelled_before_start"): + # The operation was cancelled before it was ever sent to server (same WFT). + # Note that core will still send a `ResolveNexusOperation` job in the same + # activation, so there does not need to be an exceptional case for this in + # lang. + # TODO(dan): confirm appropriate to take no action here + pass + else: + raise ValueError(f"Unknown Nexus operation start status: {job}") + + def _apply_resolve_nexus_operation( + self, + job: temporalio.bridge.proto.workflow_activation.ResolveNexusOperation, + ) -> None: + handle = self._pending_nexus_operations.get(job.seq) + if not handle: + raise RuntimeError( + f"Failed to find nexus operation handle for job sequence number {job.seq}" + ) + + # Handle the four oneof variants of NexusOperationResult + result = job.result + if result.HasField("completed"): + [output] = self._convert_payloads( + [result.completed], + [handle._input.output_type] if handle._input.output_type else None, + ) + handle._resolve_success(output) + elif result.HasField("failed"): + # TODO(dan): test failure converter + handle._resolve_failure( + self._failure_converter.from_failure( + result.failed, self._payload_converter + ) + ) + elif result.HasField("cancelled"): + handle._resolve_failure( + self._failure_converter.from_failure( + result.cancelled, self._payload_converter + ) + ) + elif result.HasField("timed_out"): + handle._resolve_failure( + self._failure_converter.from_failure( + result.timed_out, self._payload_converter + ) + ) + else: + raise RuntimeError("Nexus operation did not have a result") + def _apply_resolve_request_cancel_external_workflow( self, job: temporalio.bridge.proto.workflow_activation.ResolveRequestCancelExternalWorkflow, @@ -1299,6 +1373,7 @@ def workflow_start_activity( ) ) + # workflow_start_child_workflow ret_type async def workflow_start_child_workflow( self, workflow: Any, @@ -1333,7 +1408,7 @@ async def workflow_start_child_workflow( if isinstance(workflow, str): name = workflow elif callable(workflow): - defn = temporalio.workflow._Definition.must_from_run_fn(workflow) + defn = temporalio.workflow._Definition.must_from_run_fn(workflow) # pyright: ignore if not defn.name: raise TypeError("Cannot invoke dynamic workflow explicitly") name = defn.name @@ -1418,6 +1493,33 @@ def workflow_start_local_activity( ) ) + async def workflow_start_nexus_operation( + self, + endpoint: str, + service: str, + operation: Union[ + nexusrpc.Operation[I, O], + Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + str, + ], + input: Any, + output_type: Optional[Type[O]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> temporalio.workflow.NexusOperationHandle[Any]: + # start_nexus_operation + return await self._outbound.start_nexus_operation( + StartNexusOperationInput( + endpoint=endpoint, + service=service, + operation=operation, + input=input, + output_type=output_type, + schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) + ) + def workflow_time_ns(self) -> int: return self._time_ns @@ -1722,6 +1824,47 @@ async def run_child() -> Any: except asyncio.CancelledError: apply_child_cancel_error() + async def _outbound_start_nexus_operation( + self, input: StartNexusOperationInput + ) -> _NexusOperationHandle[Any]: + # A Nexus operation handle contains two futures: self._start_fut is resolved as a + # result of the Nexus operation starting (activation job: + # resolve_nexus_operation_start), and self._result_fut is resolved as a result of + # the Nexus operation completing (activation job: resolve_nexus_operation). The + # handle itself corresponds to an asyncio.Task which waits on self.result_fut, + # handling CancelledError by emitting a RequestCancelNexusOperation command. We do + # not return the handle until we receive resolve_nexus_operation_start, like + # ChildWorkflowHandle and unlike ActivityHandle. Note that a Nexus operation may + # complete synchronously (in which case both jobs will be sent in the same + # activation, and start will be resolved without an operation token), or + # asynchronously (in which case start they may be sent in separate activations, + # and start will be resolved with an operation token). See comments in + # tests/worker/test_nexus.py for worked examples of the evolution of the resulting + # handle state machine in the sync and async Nexus response cases. + handle: _NexusOperationHandle + + async def operation_handle_fn() -> Any: + while True: + try: + return await asyncio.shield(handle._result_fut) + except asyncio.CancelledError: + cancel_command = self._add_command() + handle._apply_cancel_command(cancel_command) + + handle = _NexusOperationHandle( + self, self._next_seq("nexus_operation"), input, operation_handle_fn() + ) + handle._apply_schedule_command() + self._pending_nexus_operations[handle._seq] = handle + + while True: + try: + await asyncio.shield(handle._start_fut) + return handle + except asyncio.CancelledError: + cancel_command = self._add_command() + handle._apply_cancel_command(cancel_command) + #### Miscellaneous helpers #### # These are in alphabetical order. @@ -2458,6 +2601,11 @@ async def start_child_workflow( ) -> temporalio.workflow.ChildWorkflowHandle[Any, Any]: return await self._instance._outbound_start_child_workflow(input) + async def start_nexus_operation( + self, input: StartNexusOperationInput + ) -> temporalio.workflow.NexusOperationHandle[Any]: + return await self._instance._outbound_start_nexus_operation(input) + def start_local_activity( self, input: StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle[Any]: @@ -2844,6 +2992,104 @@ async def cancel(self) -> None: await self._instance._cancel_external_workflow(command) +I = TypeVar("I") +O = TypeVar("O") + + +# TODO(dan): are we sure we don't want to inherit from asyncio.Task as ActivityHandle and +# ChildWorkflowHandle do? I worry that we should provide .done(), .result(), .exception() +# etc for consistency. +class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[O]): + def __init__( + self, + instance: _WorkflowInstanceImpl, + seq: int, + input: StartNexusOperationInput, + fn: Coroutine[Any, Any, O], + ): + self._instance = instance + self._seq = seq + self._input = input + self._task = asyncio.Task(fn) + self._start_fut: asyncio.Future[Optional[str]] = instance.create_future() + self._result_fut: asyncio.Future[Optional[O]] = instance.create_future() + + @property + def operation_token(self) -> Optional[str]: + # TODO(dan): How should this behave? + # Java has a separate class that only exists if the operation token exists: + # https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26 + # And Go similar: + # https://github.com/temporalio/sdk-go/blob/master/internal/workflow.go#L2770-L2771 + try: + return self._start_fut.result() + except BaseException: + return None + + async def result(self) -> O: + return await self._task + + def __await__(self) -> Generator[Any, Any, O]: + return self._task.__await__() + + def __repr__(self) -> str: + return ( + f"{self._start_fut} " + f"{self._result_fut} " + f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" # type: ignore + ) + + def cancel(self) -> bool: + return self._task.cancel() + + def _resolve_start_success(self, operation_token: Optional[str]) -> None: + print(f"🟢 _resolve_start_success: operation_id: {operation_token}") + # We intentionally let this error if already done + self._start_fut.set_result(operation_token) + + def _resolve_success(self, result: Any) -> None: + print( + f"🟢 _resolve_success: operation_id: {self.operation_token} result: {result}" + ) + # We intentionally let this error if already done + self._result_fut.set_result(result) + + def _resolve_failure(self, err: BaseException) -> None: + print(f"🔴 _resolve_failure: operation_id: {self.operation_token} err: {err}") + if self._start_fut.done(): + # We intentionally let this error if already done + self._result_fut.set_exception(err) + else: + self._start_fut.set_exception(err) + # Set null result to avoid warning about unhandled future + self._result_fut.set_result(None) + + def _apply_schedule_command(self) -> None: + command = self._instance._add_command() + v = command.schedule_nexus_operation + v.seq = self._seq + v.endpoint = self._input.endpoint + v.service = self._input.service + v.operation = self._input.operation_name + v.input.CopyFrom( + self._instance._payload_converter.to_payload(self._input.input) + ) + if self._input.schedule_to_close_timeout is not None: + v.schedule_to_close_timeout.FromTimedelta( + self._input.schedule_to_close_timeout + ) + if self._input.headers: + for key, val in self._input.headers.items(): + print(f"🌈 adding nexus header: {key} = {val}") + v.nexus_header[key] = val + + def _apply_cancel_command( + self, + command: temporalio.bridge.proto.workflow_commands.WorkflowCommand, + ) -> None: + command.request_cancel_nexus_operation.seq = self._seq + + class _ContinueAsNewError(temporalio.workflow.ContinueAsNewError): def __init__( self, instance: _WorkflowInstanceImpl, input: ContinueAsNewInput diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 409c8d690..0e7af635b 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4,7 +4,6 @@ import asyncio import contextvars -import dataclasses import inspect import logging import threading @@ -23,6 +22,7 @@ Awaitable, Callable, Dict, + Generator, Generic, Iterable, Iterator, @@ -40,6 +40,8 @@ overload, ) +import nexusrpc +import nexusrpc.handler from typing_extensions import ( Concatenate, Literal, @@ -846,6 +848,22 @@ def workflow_start_local_activity( activity_id: Optional[str], ) -> ActivityHandle[Any]: ... + @abstractmethod + async def workflow_start_nexus_operation( + self, + endpoint: str, + service: str, + operation: Union[ + nexusrpc.Operation[I, O], + Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + str, + ], + input: Any, + output_type: Optional[Type[O]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> NexusOperationHandle[Any]: ... + @abstractmethod def workflow_time_ns(self) -> int: ... @@ -1967,7 +1985,7 @@ class _AsyncioTask(asyncio.Task[AnyType]): pass else: - + # TODO(dan): inherited classes should be other way around? class _AsyncioTask(Generic[AnyType], asyncio.Task): pass @@ -4368,6 +4386,77 @@ async def execute_child_workflow( return await handle +I = TypeVar("I") +O = TypeVar("O") +S = TypeVar("S") + + +# TODO(dan): ABC? +class NexusOperationHandle(Generic[O]): + def cancel(self) -> bool: + # TODO(dan): docstring + """ + Call task.cancel() on the asyncio task that is backing this handle. + + From asyncio docs: + + Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, change the future's state to cancelled, schedule the callbacks and return True. + """ + raise NotImplementedError + + def __await__(self) -> Generator[Any, Any, O]: + raise NotImplementedError + + # TODO(dan): check SDK-wide philosophy on @property vs nullary accessor methods. + @property + def operation_token(self) -> Optional[str]: + raise NotImplementedError + + +async def start_nexus_operation( + endpoint: str, + service: str, + operation: Union[ + nexusrpc.Operation[I, O], + Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + str, + ], + input: Any, + *, + output_type: Optional[Type[O]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, +) -> NexusOperationHandle[Any]: + """Start a Nexus operation and return its handle. + + Args: + endpoint: The Nexus endpoint. + service: The Nexus service. + operation: The Nexus operation. + input: The Nexus operation input. + output_type: The Nexus operation output type. + schedule_to_close_timeout: Timeout for the entire operation attempt. + headers: Headers to send with the Nexus HTTP request. + + Returns: + A handle to the Nexus operation. The result can be obtained as + ```python + await handle.result() + ``` + """ + return await _Runtime.current().workflow_start_nexus_operation( + endpoint=endpoint, + service=service, + operation=operation, + input=input, + output_type=output_type, + schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) + + class ExternalWorkflowHandle(Generic[SelfType]): """Handle for interacting with an external workflow. @@ -5074,3 +5163,84 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType elif self == VersioningIntent.DEFAULT: return temporalio.bridge.proto.common.VersioningIntent.DEFAULT return temporalio.bridge.proto.common.VersioningIntent.UNSPECIFIED + + +# Nexus + + +class NexusClient(Generic[S]): + def __init__( + self, + service: Union[ + # TODO(dan): Type[S] is modeling the interface case as well the impl case, but + # the typevar S is used below only in the impl case. I think this is OK, but + # think about it again before deleting this TODO. + Type[S], + str, + ], + *, + endpoint: str, + ) -> None: + # If service is not a str, then it must be a service interface or implementation + # class. + if isinstance(service, str): + self._service_name = service + elif service_defn := getattr(service, "__nexus_service__", None): + self._service_name = service_defn.name + elif service_metadata := getattr(service, "__nexus_service_metadata__", None): + self._service_name = service_metadata.name + else: + raise ValueError( + f"`service` may be a name (str), or a class decorated with either " + f"@nexusrpc.handler.service_handler or @nexusrpc.service. " + f"Invalid service type: {type(service)}" + ) + self._endpoint = endpoint + + # TODO(dan): overloads: no-input, operation name, ret type + # TODO(dan): should it be an error to use a reference to a method on a class other than that supplied? + async def start_operation( + self, + operation: Union[ + nexusrpc.Operation[I, O], + Callable[[S], nexusrpc.handler.OperationHandler[I, O]], + str, + ], + input: I, + *, + output_type: Optional[Type[O]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> NexusOperationHandle[O]: + return await temporalio.workflow.start_nexus_operation( + endpoint=self._endpoint, + service=self._service_name, + operation=operation, + input=input, + output_type=output_type, + schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) + + # TODO(dan): overloads: no-input, operation name, ret type + async def execute_operation( + self, + operation: Union[ + nexusrpc.Operation[I, O], + Callable[[S], nexusrpc.handler.OperationHandler[I, O]], + str, + ], + input: I, + *, + output_type: Optional[Type[O]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> O: + handle: NexusOperationHandle[O] = await self.start_operation( + operation, + input, + output_type=output_type, + schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) + return await handle diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py new file mode 100644 index 000000000..cee54d4b7 --- /dev/null +++ b/tests/nexus/test_workflow_caller.py @@ -0,0 +1,1047 @@ +import asyncio +import uuid +from dataclasses import dataclass +from enum import IntEnum +from typing import Any, Callable, Optional, Union + +import nexusrpc +import nexusrpc.handler +import pytest + +import temporalio.api +import temporalio.api.common +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.nexus +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice +import temporalio.api.operatorservice.v1 +import temporalio.nexus +import temporalio.nexus.handler +from temporalio import workflow +from temporalio.client import ( + Client, + WithStartWorkflowOperation, + WorkflowExecutionStatus, + WorkflowFailureError, +) +from temporalio.common import WorkflowIDConflictPolicy +from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError +from temporalio.nexus.handler import WorkflowHandle +from temporalio.nexus.token import WorkflowOperationToken +from temporalio.service import RPCError, RPCStatusCode +from temporalio.worker import UnsandboxedWorkflowRunner, Worker +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name + +# TODO(dan): test availability of Temporal client etc in async context set by worker +# TODO(dan): test worker shutdown, wait_all_completed, drain etc +# TODO(dan): test worker op handling failure + +# ----------------------------------------------------------------------------- +# Test definition +# + + +class CallerReference(IntEnum): + IMPL_WITHOUT_INTERFACE = 0 + IMPL_WITH_INTERFACE = 1 + INTERFACE = 2 + + +class OpDefinitionType(IntEnum): + SHORTHAND = 0 + LONGHAND = 1 + + +@dataclass +class SyncResponse: + op_definition_type: OpDefinitionType + use_async_def: bool + exception_in_operation_start: bool + + +@dataclass +class AsyncResponse: + operation_workflow_id: str + block_forever_waiting_for_cancellation: bool + op_definition_type: OpDefinitionType + exception_in_operation_start: bool + + +# The order of the two types in this union is critical since the data converter matches +# eagerly, ignoring unknown fields, and so would identify an AsyncResponse as a +# SyncResponse if SyncResponse came first. +ResponseType = Union[AsyncResponse, SyncResponse] + +# ----------------------------------------------------------------------------- +# Service interface +# + + +@dataclass +class OpInput: + response_type: ResponseType + headers: dict[str, str] + caller_reference: CallerReference + + +@dataclass +class OpOutput: + value: str + start_options_received_by_handler: Optional[nexusrpc.handler.StartOperationContext] + + +@dataclass +class HandlerWfInput: + op_input: OpInput + + +@dataclass +class HandlerWfOutput: + value: str + start_options_received_by_handler: Optional[nexusrpc.handler.StartOperationContext] + + +@nexusrpc.service +class ServiceInterface: + sync_or_async_operation: nexusrpc.Operation[OpInput, OpOutput] + sync_operation: nexusrpc.Operation[OpInput, OpOutput] + async_operation: nexusrpc.Operation[OpInput, HandlerWfOutput] + + +# ----------------------------------------------------------------------------- +# Service implementation +# + + +@workflow.defn +class HandlerWorkflow: + @workflow.run + async def run( + self, + input: HandlerWfInput, + start_options_received_by_handler: nexusrpc.handler.StartOperationContext, + ) -> HandlerWfOutput: + assert isinstance(input.op_input.response_type, AsyncResponse) + if input.op_input.response_type.block_forever_waiting_for_cancellation: + await asyncio.Future() + return HandlerWfOutput( + value="workflow result", + start_options_received_by_handler=start_options_received_by_handler, + ) + + +# TODO: make types pass pyright strict mode + + +class SyncOrAsyncOperation(nexusrpc.handler.OperationHandler[OpInput, OpOutput]): + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: OpInput + ) -> Union[ + nexusrpc.handler.StartOperationResultSync[OpOutput], + nexusrpc.handler.StartOperationResultAsync, + ]: + if input.response_type.exception_in_operation_start: + # TODO(dan): don't think RPCError should be used here + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + if isinstance(input.response_type, SyncResponse): + return nexusrpc.handler.StartOperationResultSync( + value=OpOutput( + value="sync response", + start_options_received_by_handler=ctx, + ) + ) + elif isinstance(input.response_type, AsyncResponse): + wf_handle = await temporalio.nexus.handler.start_workflow( + ctx, + HandlerWorkflow.run, + args=[HandlerWfInput(op_input=input), ctx], + id=input.response_type.operation_workflow_id, + ) + return nexusrpc.handler.StartOperationResultAsync( + WorkflowOperationToken.from_workflow_handle(wf_handle).encode() + ) + else: + raise TypeError + + async def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> None: + return await temporalio.nexus.handler.cancel_workflow(ctx, token) + + async def fetch_info( + self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str + ) -> nexusrpc.handler.OperationInfo: + raise NotImplementedError + + async def fetch_result( + self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str + ) -> OpOutput: + raise NotImplementedError + + +@nexusrpc.handler.service_handler(service=ServiceInterface) +class ServiceImpl: + @nexusrpc.handler.operation_handler + def sync_or_async_operation( + self, + ) -> nexusrpc.handler.OperationHandler[OpInput, OpOutput]: + return SyncOrAsyncOperation() + + @nexusrpc.handler.sync_operation_handler + async def sync_operation( + self, ctx: nexusrpc.handler.StartOperationContext, input: OpInput + ) -> OpOutput: + assert isinstance(input.response_type, SyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + return OpOutput( + value="sync response", + start_options_received_by_handler=ctx, + ) + + @temporalio.nexus.handler.workflow_run_operation_handler + async def async_operation( + self, ctx: nexusrpc.handler.StartOperationContext, input: OpInput + ) -> WorkflowHandle[HandlerWorkflow, HandlerWfOutput]: + assert isinstance(input.response_type, AsyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + return await temporalio.nexus.handler.start_workflow( + ctx, + HandlerWorkflow.run, + args=[HandlerWfInput(op_input=input), ctx], + id=input.response_type.operation_workflow_id, + ) + + +# ----------------------------------------------------------------------------- +# Caller workflow +# + + +@dataclass +class CallerWfInput: + op_input: OpInput + + +@dataclass +class CallerWfOutput: + op_output: OpOutput + + +@workflow.defn +class CallerWorkflow: + """ + A workflow that executes a Nexus operation, specifying whether it should return + synchronously or asynchronously. + """ + + @workflow.init + def __init__( + self, + input: CallerWfInput, + request_cancel: bool, + task_queue: str, + ) -> None: + self.nexus_client = workflow.NexusClient( + service={ + CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, + CallerReference.INTERFACE: ServiceInterface, + }[input.op_input.caller_reference], + endpoint=make_nexus_endpoint_name(task_queue), + ) + self._nexus_operation_started = False + self._proceed = False + + @workflow.run + async def run( + self, + input: CallerWfInput, + request_cancel: bool, + task_queue: str, + ) -> CallerWfOutput: + op_input = input.op_input + op_handle = await self.nexus_client.start_operation( + self._get_operation(op_input), + op_input, + headers=op_input.headers, + ) + self._nexus_operation_started = True + if not input.op_input.response_type.exception_in_operation_start: + if isinstance(input.op_input.response_type, SyncResponse): + assert ( + op_handle.operation_token is None + ), "operation_token should be absent after a sync response" + else: + assert ( + op_handle.operation_token + ), "operation_token should be present after an async response" + + if request_cancel: + # Even for SyncResponse, the op_handle future is not done at this point; that + # transition doesn't happen until the handle is awaited. + assert op_handle.cancel() + op_output = await op_handle + return CallerWfOutput( + op_output=OpOutput( + value=op_output.value, + start_options_received_by_handler=op_output.start_options_received_by_handler, + ) + ) + + @workflow.update + async def wait_nexus_operation_started(self) -> None: + await workflow.wait_condition(lambda: self._nexus_operation_started) + + @staticmethod + def _get_operation( + op_input: OpInput, + ) -> Union[ + nexusrpc.Operation[OpInput, OpOutput], + Callable[[Any], nexusrpc.handler.OperationHandler[OpInput, OpOutput]], + ]: + return { + ( + SyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.sync_operation, + ( + SyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.sync_operation, + ( + SyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.sync_or_async_operation, + ( + SyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.sync_or_async_operation, + ( + AsyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.async_operation, + ( + AsyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.async_operation, + ( + AsyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.sync_or_async_operation, + ( + AsyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.sync_or_async_operation, + }[ + {True: SyncResponse, False: AsyncResponse}[ + isinstance(op_input.response_type, SyncResponse) + ], + op_input.response_type.op_definition_type, + op_input.caller_reference, + ( + op_input.response_type.use_async_def + if isinstance(op_input.response_type, SyncResponse) + else True + ), + ] + + +@workflow.defn +class UntypedCallerWorkflow: + @workflow.init + def __init__( + self, input: CallerWfInput, request_cancel: bool, task_queue: str + ) -> None: + # TODO(dan): untyped caller cannot reference name of implementation. I think this is as it should be. + service_name = "ServiceInterface" + self.nexus_client = workflow.NexusClient( + service=service_name, + endpoint=make_nexus_endpoint_name(task_queue), + ) + + @workflow.run + async def run( + self, input: CallerWfInput, request_cancel: bool, task_queue: str + ) -> CallerWfOutput: + op_input = input.op_input + if op_input.response_type.op_definition_type == OpDefinitionType.LONGHAND: + op_name = "sync_or_async_operation" + elif isinstance(op_input.response_type, AsyncResponse): + op_name = "async_operation" + elif isinstance(op_input.response_type, SyncResponse): + op_name = "sync_operation" + else: + raise TypeError + + arbitrary_condition = isinstance(op_input.response_type, SyncResponse) + + if arbitrary_condition: + op_handle = await self.nexus_client.start_operation( + op_name, + op_input, + headers=op_input.headers, + output_type=OpOutput, + ) + op_output = await op_handle + else: + op_output = await self.nexus_client.execute_operation( + op_name, + op_input, + headers=op_input.headers, + output_type=OpOutput, + ) + return CallerWfOutput( + op_output=OpOutput( + value=op_output.value, + start_options_received_by_handler=op_output.start_options_received_by_handler, + ) + ) + + +# ----------------------------------------------------------------------------- +# Tests +# + + +# TODO(dan): cross-namespace tests +# TODO(dan): nexus endpoint pytest fixture? +# TODO(dan): test headers +@pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize("request_cancel", [False, True]) +@pytest.mark.parametrize( + "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +) +@pytest.mark.parametrize( + "caller_reference", + [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], +) +async def test_sync_response( + client: Client, + exception_in_operation_start: bool, + request_cancel: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_services=[ServiceImpl()], + workflows=[CallerWorkflow, HandlerWorkflow], + task_queue=task_queue, + # TODO(dan): enable sandbox + workflow_runner=UnsandboxedWorkflowRunner(), + workflow_failure_exception_types=[Exception], + ): + await create_nexus_endpoint(task_queue, client) + caller_wf_handle = await client.start_workflow( + CallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=SyncResponse( + op_definition_type=op_definition_type, + use_async_def=True, + exception_in_operation_start=exception_in_operation_start, + ), + headers={"header-key": "header-value"}, + caller_reference=caller_reference, + ), + ), + request_cancel, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + + # TODO(dan): check bidi links for sync operation + + # The operation result is returned even when request_cancel=True, because the + # response was synchronous and it could not be cancelled. See explanation below. + if exception_in_operation_start: + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, NexusHandlerError) + # ID of first command + await print_history(caller_wf_handle) + assert e.__cause__.scheduled_event_id == 5 + assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) + assert e.__cause__.service == "ServiceInterface" + assert ( + e.__cause__.operation == "sync_operation" + if op_definition_type == OpDefinitionType.SHORTHAND + else "sync_or_async_operation" + ) + else: + result = await caller_wf_handle.result() + assert result.op_output.value == "sync response" + assert result.op_output.start_options_received_by_handler + + +@pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize("request_cancel", [False, True]) +@pytest.mark.parametrize( + "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +) +@pytest.mark.parametrize( + "caller_reference", + [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], +) +async def test_async_response( + client: Client, + exception_in_operation_start: bool, + request_cancel: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, +): + print(f"🌈 {'test_async_response':<24}: {request_cancel=} {op_definition_type=}") + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_services=[ServiceImpl()], + workflows=[CallerWorkflow, HandlerWorkflow], + task_queue=task_queue, + workflow_runner=UnsandboxedWorkflowRunner(), + workflow_failure_exception_types=[Exception], + ): + caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op( + client, + task_queue, + exception_in_operation_start, + request_cancel, + op_definition_type, + caller_reference, + ) + if exception_in_operation_start: + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, NexusHandlerError) + # ID of first command after update accepted + assert e.__cause__.scheduled_event_id == 6 + assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) + assert e.__cause__.service == "ServiceInterface" + assert ( + e.__cause__.operation == "async_operation" + if op_definition_type == OpDefinitionType.SHORTHAND + else "sync_or_async_operation" + ) + return + + # TODO(dan): race here? How do we know it hasn't been canceled already? + handler_wf_info = await handler_wf_handle.describe() + assert handler_wf_info.status in [ + WorkflowExecutionStatus.RUNNING, + WorkflowExecutionStatus.COMPLETED, + ] + await assert_caller_workflow_has_link_to_handler_workflow( + caller_wf_handle, handler_wf_handle, handler_wf_info.run_id + ) + await assert_handler_workflow_has_link_to_caller_workflow( + caller_wf_handle, handler_wf_handle + ) + + if request_cancel: + # The operation response was asynchronous and so request_cancel is honored. See + # explanation below. + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, CancelledError) + # ID of first command after update accepted + assert e.__cause__.scheduled_event_id == 6 + assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) + assert e.__cause__.service == "ServiceInterface" + assert ( + e.__cause__.operation == "async_operation" + if op_definition_type == OpDefinitionType.SHORTHAND + else "sync_or_async_operation" + ) + assert WorkflowOperationToken.decode( + e.__cause__.operation_token + ) == WorkflowOperationToken( + namespace=handler_wf_handle._client.namespace, + workflow_id=handler_wf_handle.id, + ) + # Check that the handler workflow was canceled + handler_wf_info = await handler_wf_handle.describe() + assert handler_wf_info.status == WorkflowExecutionStatus.CANCELED + else: + handler_wf_info = await handler_wf_handle.describe() + assert handler_wf_info.status == WorkflowExecutionStatus.COMPLETED + result = await caller_wf_handle.result() + assert result.op_output.value == "workflow result" + assert result.op_output.start_options_received_by_handler + + +async def _start_wf_and_nexus_op( + client: Client, + task_queue: str, + exception_in_operation_start: bool, + request_cancel: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, +) -> tuple[ + WorkflowHandle[CallerWorkflow, CallerWfOutput], + WorkflowHandle[HandlerWorkflow, HandlerWfOutput], +]: + """ + Start the caller workflow and wait until the Nexus operation has started. + """ + await create_nexus_endpoint(task_queue, client) + operation_workflow_id = str(uuid.uuid4()) + + # Start the caller workflow and wait until it confirms the Nexus operation has started. + block_forever_waiting_for_cancellation = request_cancel + start_op = WithStartWorkflowOperation( + CallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=AsyncResponse( + operation_workflow_id, + block_forever_waiting_for_cancellation, + op_definition_type, + exception_in_operation_start=exception_in_operation_start, + ), + headers={"header-key": "header-value"}, + caller_reference=caller_reference, + ), + ), + request_cancel, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + id_conflict_policy=WorkflowIDConflictPolicy.FAIL, + ) + + await client.execute_update_with_start_workflow( + CallerWorkflow.wait_nexus_operation_started, + start_workflow_operation=start_op, + ) + caller_wf_handle = await start_op.workflow_handle() + + # check that the operation-backing workflow now exists, and that (a) the handler + # workflow accepted the link to the calling Nexus event, and that (b) the caller + # workflow NexusOperationStarted event received in return a link to the + # operation-backing workflow. + handler_wf_handle: WorkflowHandle[HandlerWorkflow, HandlerWfOutput] = ( + client.get_workflow_handle(operation_workflow_id) + ) + return caller_wf_handle, handler_wf_handle + + +@pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize( + "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +) +@pytest.mark.parametrize( + "caller_reference", + [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], +) +@pytest.mark.parametrize("response_type", [SyncResponse, AsyncResponse]) +async def test_untyped_caller( + client: Client, + exception_in_operation_start: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, + response_type: ResponseType, +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + workflows=[UntypedCallerWorkflow, HandlerWorkflow], + nexus_services=[ServiceImpl()], + task_queue=task_queue, + workflow_runner=UnsandboxedWorkflowRunner(), + workflow_failure_exception_types=[Exception], + ): + if response_type == SyncResponse: + response_type = SyncResponse( + op_definition_type=op_definition_type, + use_async_def=True, + exception_in_operation_start=exception_in_operation_start, + ) + else: + response_type = AsyncResponse( + operation_workflow_id=str(uuid.uuid4()), + block_forever_waiting_for_cancellation=False, + op_definition_type=op_definition_type, + exception_in_operation_start=exception_in_operation_start, + ) + await create_nexus_endpoint(task_queue, client) + caller_wf_handle = await client.start_workflow( + UntypedCallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=response_type, + headers={}, + caller_reference=caller_reference, + ), + ), + False, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + if exception_in_operation_start: + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, NexusHandlerError) + else: + result = await caller_wf_handle.result() + assert result.op_output.value == ( + "sync response" + if isinstance(response_type, SyncResponse) + else "workflow result" + ) + assert result.op_output.start_options_received_by_handler + + +# +# Test routing of workflow calls +# + + +@dataclass +class ServiceClassNameOutput: + name: str + + +# TODO(dan): test interface op types not matching +# TODO(dan): async and non-async cancel methods + + +@nexusrpc.service +class ServiceInterfaceWithoutNameOverride: + op: nexusrpc.Operation[None, ServiceClassNameOutput] + + +@nexusrpc.service(name="service-interface-🌈") +class ServiceInterfaceWithNameOverride: + op: nexusrpc.Operation[None, ServiceClassNameOutput] + + +@nexusrpc.handler.service_handler +class ServiceImplInterfaceWithNeitherInterfaceNorNameOverride: + @nexusrpc.handler.sync_operation_handler + async def op( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +@nexusrpc.handler.service_handler(service=ServiceInterfaceWithoutNameOverride) +class ServiceImplInterfaceWithoutNameOverride: + @nexusrpc.handler.sync_operation_handler + async def op( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +@nexusrpc.handler.service_handler(service=ServiceInterfaceWithNameOverride) +class ServiceImplInterfaceWithNameOverride: + @nexusrpc.handler.sync_operation_handler + async def op( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +@nexusrpc.handler.service_handler(name="service-impl-🌈") +class ServiceImplWithNameOverride: + @nexusrpc.handler.sync_operation_handler + async def op( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +class NameOverride(IntEnum): + NO = 0 + YES = 1 + + +@workflow.defn +class ServiceInterfaceAndImplCallerWorkflow: + @workflow.run + async def run( + self, + caller_reference: CallerReference, + name_override: NameOverride, + task_queue: str, + ) -> ServiceClassNameOutput: + C, N = CallerReference, NameOverride + if (caller_reference, name_override) == (C.INTERFACE, N.YES): + service_cls = ServiceInterfaceWithNameOverride + elif (caller_reference, name_override) == (C.INTERFACE, N.NO): + service_cls = ServiceInterfaceWithoutNameOverride + elif (caller_reference, name_override) == (C.IMPL_WITH_INTERFACE, N.YES): + service_cls = ServiceImplWithNameOverride + elif (caller_reference, name_override) == (C.IMPL_WITH_INTERFACE, N.NO): + service_cls = ServiceImplInterfaceWithoutNameOverride + elif (caller_reference, name_override) == (C.IMPL_WITHOUT_INTERFACE, N.NO): + service_cls = ServiceImplInterfaceWithNeitherInterfaceNorNameOverride + else: + raise ValueError( + f"Invalid combination of caller_reference ({caller_reference}) and name_override ({name_override})" + ) + + nexus_client = workflow.NexusClient( + service=service_cls, + endpoint=make_nexus_endpoint_name(task_queue), + ) + + # TODO(dan): maybe not surprising that this doesn't type check given complexity of + # the union? + return await nexus_client.execute_operation(service_cls.op, None) # type: ignore + + +# TODO(dan): check missing decorator behavior + + +async def test_service_interface_and_implementation_names(client: Client): + # Note that: + # - The caller can specify the service & operation via a reference to either the + # interface or implementation class. + # - An interface class may optionally override its name. + # - An implementation class may either override its name or specify an interface that + # it is implementing, but not both. + # - On registering a service implementation with a worker, the name by which the + # service is addressed in requests is the interface name if the implementation + # supplies one, or else the name override made by the impl class, or else the impl + # class name. + # + # This test checks that the request is routed to the expected service under a variety + # of scenarios related to the above considerations. + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_services=[ + ServiceImplWithNameOverride(), + ServiceImplInterfaceWithNameOverride(), + ServiceImplInterfaceWithoutNameOverride(), + ServiceImplInterfaceWithNeitherInterfaceNorNameOverride(), + ], + workflows=[ServiceInterfaceAndImplCallerWorkflow], + task_queue=task_queue, + workflow_runner=UnsandboxedWorkflowRunner(), + workflow_failure_exception_types=[Exception], + ): + await create_nexus_endpoint(task_queue, client) + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=(CallerReference.INTERFACE, NameOverride.YES, task_queue), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplInterfaceWithNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=(CallerReference.INTERFACE, NameOverride.NO, task_queue), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplInterfaceWithoutNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=( + CallerReference.IMPL_WITH_INTERFACE, + NameOverride.YES, + task_queue, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplWithNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=( + CallerReference.IMPL_WITH_INTERFACE, + NameOverride.NO, + task_queue, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplInterfaceWithoutNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=( + CallerReference.IMPL_WITHOUT_INTERFACE, + NameOverride.NO, + task_queue, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput( + "ServiceImplInterfaceWithNeitherInterfaceNorNameOverride" + ) + + +# TODO(dan): test invalid service interface implementations +# TODO(dan): test caller passing output_type + + +async def assert_caller_workflow_has_link_to_handler_workflow( + caller_wf_handle: WorkflowHandle, + handler_wf_handle: WorkflowHandle, + handler_wf_run_id: str, +): + caller_history = await caller_wf_handle.fetch_history() + op_started_event = next( + e + for e in caller_history.events + if ( + e.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_STARTED + ) + ) + if not len(op_started_event.links) == 1: + pytest.fail( + f"Expected 1 link on NexusOperationStarted event, got {len(op_started_event.links)}" + ) + [link] = op_started_event.links + assert link.workflow_event.namespace == handler_wf_handle._client.namespace + assert link.workflow_event.workflow_id == handler_wf_handle.id + assert link.workflow_event.run_id + assert link.workflow_event.run_id == handler_wf_run_id + assert ( + link.workflow_event.event_ref.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ) + + +async def assert_handler_workflow_has_link_to_caller_workflow( + caller_wf_handle: WorkflowHandle, + handler_wf_handle: WorkflowHandle, +): + handler_history = await handler_wf_handle.fetch_history() + wf_started_event = next( + e + for e in handler_history.events + if ( + e.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ) + ) + if not len(wf_started_event.links) == 1: + pytest.fail( + f"Expected 1 link on WorkflowExecutionStarted event, got {len(wf_started_event.links)}" + ) + [link] = wf_started_event.links + assert link.workflow_event.namespace == caller_wf_handle._client.namespace + assert link.workflow_event.workflow_id == caller_wf_handle.id + assert link.workflow_event.run_id + assert link.workflow_event.run_id == caller_wf_handle.first_execution_run_id + assert ( + link.workflow_event.event_ref.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED + ) + + +async def print_history(handle: WorkflowHandle): + print("\n\n") + history = await handle.fetch_history() + for event in history.events: + try: + event_type_name = temporalio.api.enums.v1.EventType.Name( + event.event_type + ).replace("EVENT_TYPE_", "") + except ValueError: + # Handle unknown event types + event_type_name = f"Unknown({event.event_type})" + print(f"{event.event_id}. {event_type_name}") + print("\n\n") + + +# When request_cancel is True, the NexusOperationHandle in the workflow evolves +# through the following states: +# start_fut result_fut handle_task w/ fut_waiter (task._must_cancel) +# +# Case 1: Sync Nexus operation response w/ cancellation of NexusOperationHandle +# ----------------------------------------------------------------------------- +# >>>>>>>>>>>> WFT 1 +# after await start : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False) +# before op_handle.cancel : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False) +# Future_8240[FINISHED].cancel() -> False # no state transition; fut_waiter is already finished +# cancel returned : True +# before await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (True) +# --> Despite cancel having been requested, this await on the nexus op handle does not +# raise CancelledError, because the task's underlying fut_waiter is already finished. +# after await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[FINISHED] fut_waiter = None) (False) +# +# +# Case 2: Async Nexus operation response w/ cancellation of NexusOperationHandle +# ------------------------------------------------------------------------------ +# >>>>>>>>>>>> WFT 1 +# after await start : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False) +# >>>>>>>>>>>> WFT 2 +# >>>>>>>>>>>> WFT 3 +# after await proceed : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False) +# before op_handle.cancel : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False) +# Future_7952[PENDING].cancel() -> True # transition to cancelled state; fut_waiter was not finished +# cancel returned : True +# before await op_handle : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[CANCELLED]) (False) +# --> This await on the nexus op handle raises CancelledError, because the task's underlying fut_waiter is cancelled. +# +# Thus in the sync case, although the caller workflow attempted to cancel the +# NexusOperationHandle, this did not result in a CancelledError when the handle was +# awaited, because both resolve_nexus_operation_start and resolve_nexus_operation jobs +# were sent in the same activation and hence the task's fut_waiter was already finished. +# +# But in the async case, at the time that we cancel the NexusOperationHandle, only the +# resolve_nexus_operation_start job had been sent; the result_fut was unresolved. Thus +# when the handle was awaited, CancelledError was raised. +# +# To create output like that above, set the following __repr__s: +# asyncio.Future: +# def __repr__(self): +# return f"{self.__class__.__name__}_{str(id(self))[-4:]}[{self._state}]" +# _NexusOperationHandle: +# def __repr__(self) -> str: +# return ( +# f"{self._start_fut} " +# f"{self._result_fut} " +# f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" +# ) From bba27cdb2df0ed96b572a85f99eaa3e979cb1376 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 11 Jun 2025 21:48:15 -0400 Subject: [PATCH 003/183] Nexus: squashed commit Import types from nexusrpc Rename: WorkflowRunOperation Use Client.start_workflow with Temporal-specific context classes Use temporal {Start,Cancel}OperationContext Handle errors during cancellation Failing test for requestID-based idempotency Propagate Nexus request ID when starting a workflow Move nexus http client into tests/helpers Testing: Fix env http port Test that Nexus request ID becomes Temporal StartWorkflow request ID Test non-serializable operation output Don't require HandlerError cause and fix test assertions Respond to upstream: Unknown{Operation,Service}Error are just HandlerError Rename argument in public Worker API: nexus_service_handlers Respond to upstream: Executor rename Sync/Async suffix names Separate _TemporalNexusOperationContext from Nexus op ctx _operation_handlers module Make TemporalNexusOperationContext public Add WorkflowRunOperationResult.to_workflow_handle --- pyproject.toml | 3 - temporalio/client.py | 43 +- temporalio/nexus/__init__.py | 29 +- temporalio/nexus/handler.py | 471 ------------------ temporalio/nexus/handler/__init__.py | 74 +++ .../nexus/handler/_operation_context.py | 233 +++++++++ .../nexus/handler/_operation_handlers.py | 292 +++++++++++ .../nexus/{token.py => handler/_token.py} | 27 +- temporalio/worker/_interceptor.py | 25 +- temporalio/worker/_nexus.py | 417 ++++++++-------- temporalio/worker/_worker.py | 15 +- tests/conftest.py | 2 +- tests/helpers/nexus.py | 71 ++- ...ynamic_creation_of_user_handler_classes.py | 6 +- tests/nexus/test_handler.py | 339 ++++++++++--- tests/nexus/test_handler_async_operation.py | 15 +- .../test_handler_operation_definitions.py | 5 +- tests/nexus/test_workflow_caller.py | 78 ++- 18 files changed, 1260 insertions(+), 885 deletions(-) delete mode 100644 temporalio/nexus/handler.py create mode 100644 temporalio/nexus/handler/__init__.py create mode 100644 temporalio/nexus/handler/_operation_context.py create mode 100644 temporalio/nexus/handler/_operation_handlers.py rename temporalio/nexus/{token.py => handler/_token.py} (82%) diff --git a/pyproject.toml b/pyproject.toml index 4a38ba6a3..1faed3fd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,9 +200,6 @@ exclude = [ [tool.ruff] target-version = "py39" -[tool.ruff.lint] -extend-ignore = ["E741"] # Allow single-letter variable names like I, O - [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" diff --git a/temporalio/client.py b/temporalio/client.py index a5cac9b18..5ab8b7c0b 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -53,10 +53,14 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.handler import temporalio.runtime import temporalio.service import temporalio.workflow from temporalio.activity import ActivityCancellationDetails +from temporalio.nexus.handler import ( + TemporalNexusOperationContext, +) from temporalio.service import ( HttpConnectProxyConfig, KeepAliveConfig, @@ -468,12 +472,6 @@ async def start_workflow( versioning_override: Optional[temporalio.common.VersioningOverride] = None, # The following options are deliberately not exposed in overloads stack_level: int = 2, - nexus_completion_callbacks: Sequence[ - temporalio.common.NexusCompletionCallback - ] = [], - workflow_event_links: Sequence[ - temporalio.api.common.v1.Link.WorkflowEvent - ] = [], ) -> WorkflowHandle[Any, Any]: """Start a workflow and return its handle. @@ -536,8 +534,21 @@ async def start_workflow( name, result_type_from_type_hint = ( temporalio.workflow._Definition.get_name_and_result_type(workflow) ) + nexus_start_ctx = None + if nexus_ctx := TemporalNexusOperationContext.try_current(): + # TODO(prerelease): I think this is too magical: what if a user implements a + # nexus handler by running one workflow to completion, and then starting a + # second workflow to act as the async operation itself? + # TODO(prerelease): What do we do if the Temporal Nexus context client + # (namespace) is not the same as the one being used to start this workflow? + if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: + nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() + workflow_event_links = nexus_start_ctx.get_workflow_event_links() + else: + nexus_completion_callbacks = [] + workflow_event_links = [] - return await self._impl.start_workflow( + wf_handle = await self._impl.start_workflow( StartWorkflowInput( workflow=name, args=temporalio.common._arg_or_args(arg, args), @@ -569,6 +580,11 @@ async def start_workflow( ) ) + if nexus_start_ctx: + nexus_start_ctx.add_outbound_links(wf_handle) + + return wf_handle + # Overload for no-param workflow @overload async def execute_workflow( @@ -5876,7 +5892,18 @@ async def _populate_start_workflow_execution_request( if input.task_timeout is not None: req.workflow_task_timeout.FromTimedelta(input.task_timeout) req.identity = self._client.identity - req.request_id = str(uuid.uuid4()) + # Use Nexus request ID if we're handling a Nexus Start operation + # TODO(prerelease): confirm that we should do this for every workflow started + # TODO(prerelease): add test coverage for multiple workflows started by a Nexus operation + if nexus_ctx := TemporalNexusOperationContext.try_current(): + if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: + if ( + nexus_request_id + := nexus_start_ctx.nexus_operation_context.request_id + ): + req.request_id = nexus_request_id + if not req.request_id: + req.request_id = str(uuid.uuid4()) req.workflow_id_reuse_policy = cast( "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", int(input.id_reuse_policy), diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 9750cfb88..571965eb9 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -1,28 +1 @@ -import dataclasses -import logging -from collections.abc import Mapping -from typing import Any, MutableMapping, Optional - -from .handler import _current_context as _current_context -from .handler import workflow_run_operation_handler as workflow_run_operation_handler -from .token import WorkflowOperationToken as WorkflowOperationToken - - -class LoggerAdapter(logging.LoggerAdapter): - def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): - super().__init__(logger, extra or {}) - - def process( - self, msg: Any, kwargs: MutableMapping[str, Any] - ) -> tuple[Any, MutableMapping[str, Any]]: - extra = dict(self.extra or {}) - if context := _current_context.get(None): - extra.update( - {f.name: getattr(context, f.name) for f in dataclasses.fields(context)} - ) - kwargs["extra"] = extra | kwargs.get("extra", {}) - return msg, kwargs - - -logger = LoggerAdapter(logging.getLogger(__name__), None) -"""Logger that emits additional data describing the current Nexus operation.""" +from . import handler as handler diff --git a/temporalio/nexus/handler.py b/temporalio/nexus/handler.py deleted file mode 100644 index 4e96bb33e..000000000 --- a/temporalio/nexus/handler.py +++ /dev/null @@ -1,471 +0,0 @@ -from __future__ import annotations - -import logging -import re -import types -import typing -import urllib.parse -import warnings -from contextvars import ContextVar -from dataclasses import dataclass -from functools import wraps -from typing import ( - Any, - Awaitable, - Callable, - Generic, - Optional, - Sequence, - Type, - TypeVar, - Union, -) - -import nexusrpc.handler -from typing_extensions import Concatenate, Self, overload - -import temporalio.api.common.v1 -import temporalio.api.enums.v1 -import temporalio.common -from temporalio.client import ( - Client, - WorkflowHandle, -) -from temporalio.nexus.token import WorkflowOperationToken -from temporalio.types import ( - MethodAsyncNoParam, - MethodAsyncSingleParam, - MultiParamSpec, - ParamType, - ReturnType, - SelfType, -) - -I = TypeVar("I", contravariant=True) # operation input -O = TypeVar("O", covariant=True) # operation output -S = TypeVar("S") # a service - -logger = logging.getLogger(__name__) - - -# TODO(nexus-preview): demonstrate obtaining Temporal client in sync operation. - - -def _get_workflow_run_start_method_input_and_output_type_annotations( - start_method: Callable[ - [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[WorkflowHandle[Any, O]], - ], -) -> tuple[ - Optional[Type[I]], - Optional[Type[O]], -]: - """Return operation input and output types. - - `start_method` must be a type-annotated start method that returns a - :py:class:`WorkflowHandle`. - """ - input_type, output_type = ( - nexusrpc.handler.get_start_method_input_and_output_types_annotations( - start_method - ) - ) - origin_type = typing.get_origin(output_type) - if not origin_type or not issubclass(origin_type, WorkflowHandle): - warnings.warn( - f"Expected return type of {start_method.__name__} to be a subclass of WorkflowHandle, " - f"but is {output_type}" - ) - output_type = None - - args = typing.get_args(output_type) - if len(args) != 2: - warnings.warn( - f"Expected return type of {start_method.__name__} to have exactly two type parameters, " - f"but has {len(args)}: {args}" - ) - output_type = None - else: - _wf_type, output_type = args - return input_type, output_type - - -# No-param overload -@overload -async def start_workflow( - ctx: nexusrpc.handler.StartOperationContext, - workflow: MethodAsyncNoParam[SelfType, ReturnType], - *, - id: str, - client: Optional[Client] = None, - task_queue: Optional[str] = None, -) -> WorkflowHandle[SelfType, ReturnType]: ... - - -# Single-param overload -@overload -async def start_workflow( - ctx: nexusrpc.handler.StartOperationContext, - workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], - arg: ParamType, - *, - id: str, - client: Optional[Client] = None, - task_queue: Optional[str] = None, -) -> WorkflowHandle[SelfType, ReturnType]: ... - - -# Multiple-params overload -@overload -async def start_workflow( - ctx: nexusrpc.handler.StartOperationContext, - workflow: Callable[Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType]], - *, - args: Sequence[Any], - id: str, - client: Optional[Client] = None, - task_queue: Optional[str] = None, -) -> WorkflowHandle[SelfType, ReturnType]: ... - - -# TODO(nexus-prerelease): Overload for string-name workflow - - -async def start_workflow( - ctx: nexusrpc.handler.StartOperationContext, - workflow: Callable[..., Awaitable[Any]], - arg: Any = temporalio.common._arg_unset, - *, - args: Sequence[Any] = [], - id: str, - client: Optional[Client] = None, - task_queue: Optional[str] = None, -) -> WorkflowHandle[Any, Any]: - if client is None: - client = get_client() - if task_queue is None: - # TODO(nexus-prerelease): are we handling empty string well elsewhere? - task_queue = get_task_queue() - completion_callbacks = ( - [ - # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus - # request, it needs to copy the links to the callback in - # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links - # (for backwards compatibility). PR reference in Go SDK: - # https://github.com/temporalio/sdk-go/pull/1945 - temporalio.common.NexusCompletionCallback( - url=ctx.callback_url, header=ctx.callback_headers - ) - ] - if ctx.callback_url - else [] - ) - # We need to pass options (completion_callbacks, links, on_conflict_options) which are - # deliberately not exposed in any overload, hence the type error. - wf_handle = await client.start_workflow( # type: ignore - workflow, - args=temporalio.common._arg_or_args(arg, args), - id=id, - task_queue=task_queue, - nexus_completion_callbacks=completion_callbacks, - workflow_event_links=[ - _nexus_link_to_workflow_event(l) for l in ctx.inbound_links - ], - ) - try: - link = _workflow_event_to_nexus_link( - _workflow_handle_to_workflow_execution_started_event_link(wf_handle) - ) - except Exception as e: - logger.warning( - f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}" - ) - else: - ctx.outbound_links.append( - # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference - # link to send back to the caller. Now, it checks if the server returned - # the link in the StartWorkflowExecutionResponse, and if so, send the link - # from the response to the caller. Fallback to generating the link for - # backwards compatibility. PR reference in Go SDK: - # https://github.com/temporalio/sdk-go/pull/1934 - link - ) - return wf_handle - - -# TODO(nexus-prerelease): support request_id -# See e.g. TS -# packages/nexus/src/context.ts attachRequestId -# packages/test/src/test-nexus-handler.ts ctx.requestId - - -async def cancel_workflow( - ctx: nexusrpc.handler.CancelOperationContext, - token: str, - client: Optional[Client] = None, -) -> None: - _client = client or get_client() - handle = WorkflowOperationToken.decode(token).to_workflow_handle(_client) - await handle.cancel() - - -_current_context: ContextVar[_Context] = ContextVar("nexus-handler") - - -@dataclass -class _Context: - client: Optional[Client] - task_queue: Optional[str] - service: Optional[str] = None - operation: Optional[str] = None - - -def get_client() -> Client: - context = _current_context.get(None) - if context is None: - raise RuntimeError("Not in Nexus handler context") - if context.client is None: - raise RuntimeError("Nexus handler client not set") - return context.client - - -def get_task_queue() -> str: - context = _current_context.get(None) - if context is None: - raise RuntimeError("Not in Nexus handler context") - if context.task_queue is None: - raise RuntimeError("Nexus handler task queue not set") - return context.task_queue - - -class WorkflowRunOperation(nexusrpc.handler.OperationHandler[I, O], Generic[I, O, S]): - def __init__( - self, - service: S, - start_method: Callable[ - [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[WorkflowHandle[Any, O]], - ], - output_type: Optional[Type] = None, - ): - self.service = service - - @wraps(start_method) - async def start( - self, ctx: nexusrpc.handler.StartOperationContext, input: I - ) -> WorkflowRunOperationResult: - wf_handle = await start_method(service, ctx, input) - # TODO(nexus-prerelease): Error message if user has accidentally used the normal client.start_workflow - return WorkflowRunOperationResult.from_workflow_handle(wf_handle) - - self.start = types.MethodType(start, self) - - async def start( - self, ctx: nexusrpc.handler.StartOperationContext, input: I - ) -> nexusrpc.handler.StartOperationResultAsync: - raise NotImplementedError( - "The start method of a WorkflowRunOperation should be set " - "dynamically in the __init__ method. (Did you forget to call super()?)" - ) - - async def cancel( - self, ctx: nexusrpc.handler.CancelOperationContext, token: str - ) -> None: - await cancel_workflow(ctx, token) - - def fetch_info( - self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str - ) -> Union[ - nexusrpc.handler.OperationInfo, Awaitable[nexusrpc.handler.OperationInfo] - ]: - raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching operation info." - ) - - def fetch_result( - self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str - ) -> Union[O, Awaitable[O]]: - raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching operation results." - ) - - -class WorkflowRunOperationResult(nexusrpc.handler.StartOperationResultAsync): - """ - A value returned by the start method of a :class:`WorkflowRunOperation`. - - It indicates that the operation is responding asynchronously, and contains a token - that the handler can use to construct a :class:`~temporalio.client.WorkflowHandle` to - interact with the workflow. - """ - - @classmethod - def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: - token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() - return cls(token=token) - - -@overload -def workflow_run_operation_handler( - start_method: Callable[ - [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[WorkflowHandle[Any, O]], - ], -) -> Callable[[S], WorkflowRunOperation[I, O, S]]: ... - - -@overload -def workflow_run_operation_handler( - *, - name: Optional[str] = None, -) -> Callable[ - [ - Callable[ - [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[WorkflowHandle[Any, O]], - ] - ], - Callable[[S], WorkflowRunOperation[I, O, S]], -]: ... - - -def workflow_run_operation_handler( - start_method: Optional[ - Callable[ - [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[WorkflowHandle[Any, O]], - ] - ] = None, - *, - name: Optional[str] = None, -) -> Union[ - Callable[[S], WorkflowRunOperation[I, O, S]], - Callable[ - [ - Callable[ - [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[WorkflowHandle[Any, O]], - ] - ], - Callable[[S], WorkflowRunOperation[I, O, S]], - ], -]: - def decorator( - start_method: Callable[ - [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[WorkflowHandle[Any, O]], - ], - ) -> Callable[[S], WorkflowRunOperation[I, O, S]]: - input_type, output_type = ( - _get_workflow_run_start_method_input_and_output_type_annotations( - start_method - ) - ) - - def factory(service: S) -> WorkflowRunOperation[I, O, S]: - return WorkflowRunOperation(service, start_method, output_type=output_type) - - # TODO(nexus-prerelease): handle callable instances: __class__.__name__ as in sync_operation_handler - method_name = getattr(start_method, "__name__", None) - if not method_name and callable(start_method): - method_name = start_method.__class__.__name__ - if not method_name: - raise TypeError( - f"Could not determine operation method name: " - f"expected {start_method} to be a function or callable instance." - ) - - factory.__nexus_operation__ = nexusrpc.Operation._create( - name=name, - method_name=method_name, - input_type=input_type, - output_type=output_type, - ) - - return factory - - if start_method is None: - return decorator - - return decorator(start_method) - - -# TODO(nexus-prerelease): confirm that it is correct not to use event_id in the following functions. -# Should the proto say explicitly that it's optional or how it behaves when it's missing? -def _workflow_handle_to_workflow_execution_started_event_link( - handle: WorkflowHandle[Any, Any], -) -> temporalio.api.common.v1.Link.WorkflowEvent: - if handle.first_execution_run_id is None: - raise ValueError( - f"Workflow handle {handle} has no first execution run ID. " - "Cannot create WorkflowExecutionStarted event link." - ) - return temporalio.api.common.v1.Link.WorkflowEvent( - namespace=handle._client.namespace, - workflow_id=handle.id, - run_id=handle.first_execution_run_id, - event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( - event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED - ), - ) - - -def _workflow_event_to_nexus_link( - workflow_event: temporalio.api.common.v1.Link.WorkflowEvent, -) -> nexusrpc.handler.Link: - scheme = "temporal" - namespace = urllib.parse.quote(workflow_event.namespace) - workflow_id = urllib.parse.quote(workflow_event.workflow_id) - run_id = urllib.parse.quote(workflow_event.run_id) - path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history" - query_params = urllib.parse.urlencode( - { - "eventType": temporalio.api.enums.v1.EventType.Name( - workflow_event.event_ref.event_type - ), - "referenceType": "EventReference", - } - ) - return nexusrpc.handler.Link( - url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")), - type=workflow_event.DESCRIPTOR.full_name, - ) - - -def _nexus_link_to_workflow_event( - link: nexusrpc.handler.Link, -) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]: - path_regex = re.compile( - r"^/namespaces/(?P[^/]+)/workflows/(?P[^/]+)/(?P[^/]+)/history$" - ) - url = urllib.parse.urlparse(link.url) - match = path_regex.match(url.path) - if not match: - logger.warning( - f"Invalid Nexus link: {link}. Expected path to match {path_regex.pattern}" - ) - return None - try: - query_params = urllib.parse.parse_qs(url.query) - [reference_type] = query_params.get("referenceType", []) - if reference_type != "EventReference": - raise ValueError( - f"@@ Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}" - ) - [event_type_name] = query_params.get("eventType", []) - event_ref = temporalio.api.common.v1.Link.WorkflowEvent.EventReference( - event_type=temporalio.api.enums.v1.EventType.Value(event_type_name) - ) - except ValueError as err: - logger.warning( - f"@@ Failed to parse event type from Nexus link URL query parameters: {link} ({err})" - ) - event_ref = None - - groups = match.groupdict() - return temporalio.api.common.v1.Link.WorkflowEvent( - namespace=urllib.parse.unquote(groups["namespace"]), - workflow_id=urllib.parse.unquote(groups["workflow_id"]), - run_id=urllib.parse.unquote(groups["run_id"]), - event_ref=event_ref, - ) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py new file mode 100644 index 000000000..9750b876a --- /dev/null +++ b/temporalio/nexus/handler/__init__.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import logging +from collections.abc import Mapping +from typing import ( + TYPE_CHECKING, + Any, + MutableMapping, + Optional, +) + +from nexusrpc.handler import ( + CancelOperationContext as CancelOperationContext, +) +from nexusrpc.handler import ( + HandlerError as HandlerError, +) +from nexusrpc.handler import ( + HandlerErrorType as HandlerErrorType, +) + +from ._operation_context import ( + TemporalNexusOperationContext as TemporalNexusOperationContext, +) +from ._operation_handlers import ( + WorkflowRunOperationHandler as WorkflowRunOperationHandler, +) +from ._operation_handlers import ( + WorkflowRunOperationResult as WorkflowRunOperationResult, +) +from ._operation_handlers import cancel_workflow as cancel_workflow +from ._operation_handlers import ( + workflow_run_operation_handler as workflow_run_operation_handler, +) +from ._token import ( + WorkflowOperationToken as WorkflowOperationToken, +) + +if TYPE_CHECKING: + from temporalio.client import ( + Client as Client, + ) + from temporalio.client import ( + WorkflowHandle as WorkflowHandle, + ) + + +class LoggerAdapter(logging.LoggerAdapter): + def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): + super().__init__(logger, extra or {}) + + def process( + self, msg: Any, kwargs: MutableMapping[str, Any] + ) -> tuple[Any, MutableMapping[str, Any]]: + extra = dict(self.extra or {}) + if tctx := TemporalNexusOperationContext.current(): + extra["service"] = tctx.nexus_operation_context.service + extra["operation"] = tctx.nexus_operation_context.operation + extra["task_queue"] = tctx.task_queue + kwargs["extra"] = extra | kwargs.get("extra", {}) + return msg, kwargs + + +logger = LoggerAdapter(logging.getLogger(__name__), None) +"""Logger that emits additional data describing the current Nexus operation.""" + + +# TODO(nexus-preview): demonstrate obtaining Temporal client in sync operation. + + +# TODO(nexus-prerelease): support request_id +# See e.g. TS +# packages/nexus/src/context.ts attachRequestId +# packages/test/src/test-nexus-handler.ts ctx.requestId diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py new file mode 100644 index 000000000..2b6dbd9cc --- /dev/null +++ b/temporalio/nexus/handler/_operation_context.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import contextvars +import logging +import re +import urllib.parse +from abc import ABC +from contextvars import ContextVar +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Union, +) + +import nexusrpc.handler +from nexusrpc.handler import CancelOperationContext, StartOperationContext + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.common + +if TYPE_CHECKING: + from temporalio.client import ( + Client, + WorkflowHandle, + ) + + +logger = logging.getLogger(__name__) + + +_current_context: ContextVar[TemporalNexusOperationContext] = ContextVar( + "temporal-nexus-operation-context" +) + + +@dataclass +class TemporalNexusOperationContext(ABC): + """ + Context for a Nexus operation being handled by a Temporal Nexus Worker. + """ + + nexus_operation_context: Union[StartOperationContext, CancelOperationContext] + + client: Client + """The Temporal client in use by the worker handling this Nexus operation.""" + + task_queue: str + """The task queue of the worker handling this Nexus operation.""" + + @staticmethod + def try_current() -> Optional[TemporalNexusOperationContext]: + return _current_context.get(None) + + @staticmethod + def current() -> TemporalNexusOperationContext: + context = TemporalNexusOperationContext.try_current() + if not context: + raise RuntimeError("Not in Nexus operation context") + return context + + @staticmethod + def set(context: TemporalNexusOperationContext) -> contextvars.Token: + return _current_context.set(context) + + @staticmethod + def reset(token: contextvars.Token) -> None: + _current_context.reset(token) + + @property + def temporal_nexus_start_operation_context( + self, + ) -> Optional[_TemporalNexusStartOperationContext]: + ctx = self.nexus_operation_context + if not isinstance(ctx, StartOperationContext): + return None + return _TemporalNexusStartOperationContext(ctx) + + @property + def temporal_nexus_cancel_operation_context( + self, + ) -> Optional[_TemporalNexusCancelOperationContext]: + ctx = self.nexus_operation_context + if not isinstance(ctx, CancelOperationContext): + return None + return _TemporalNexusCancelOperationContext(ctx) + + +@dataclass +class _TemporalNexusStartOperationContext: + nexus_operation_context: StartOperationContext + + def get_completion_callbacks( + self, + ) -> list[temporalio.common.NexusCompletionCallback]: + ctx = self.nexus_operation_context + return ( + [ + # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus + # request, it needs to copy the links to the callback in + # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links + # (for backwards compatibility). PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1945 + temporalio.common.NexusCompletionCallback( + url=ctx.callback_url, + header=ctx.callback_headers, + ) + ] + if ctx.callback_url + else [] + ) + + def get_workflow_event_links( + self, + ) -> list[temporalio.api.common.v1.Link.WorkflowEvent]: + event_links = [] + for inbound_link in self.nexus_operation_context.inbound_links: + if link := _nexus_link_to_workflow_event(inbound_link): + event_links.append(link) + return event_links + + def add_outbound_links(self, workflow_handle: WorkflowHandle[Any, Any]): + try: + link = _workflow_event_to_nexus_link( + _workflow_handle_to_workflow_execution_started_event_link( + workflow_handle + ) + ) + except Exception as e: + logger.warning( + f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}" + ) + else: + self.nexus_operation_context.outbound_links.append( + # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference + # link to send back to the caller. Now, it checks if the server returned + # the link in the StartWorkflowExecutionResponse, and if so, send the link + # from the response to the caller. Fallback to generating the link for + # backwards compatibility. PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1934 + link + ) + return workflow_handle + + +@dataclass +class _TemporalNexusCancelOperationContext: + nexus_operation_context: CancelOperationContext + + +# TODO(nexus-prerelease): confirm that it is correct not to use event_id in the following functions. +# Should the proto say explicitly that it's optional or how it behaves when it's missing? +def _workflow_handle_to_workflow_execution_started_event_link( + handle: WorkflowHandle[Any, Any], +) -> temporalio.api.common.v1.Link.WorkflowEvent: + if handle.first_execution_run_id is None: + raise ValueError( + f"Workflow handle {handle} has no first execution run ID. " + "Cannot create WorkflowExecutionStarted event link." + ) + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=handle._client.namespace, + workflow_id=handle.id, + run_id=handle.first_execution_run_id, + event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ), + ) + + +def _workflow_event_to_nexus_link( + workflow_event: temporalio.api.common.v1.Link.WorkflowEvent, +) -> nexusrpc.Link: + scheme = "temporal" + namespace = urllib.parse.quote(workflow_event.namespace) + workflow_id = urllib.parse.quote(workflow_event.workflow_id) + run_id = urllib.parse.quote(workflow_event.run_id) + path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history" + query_params = urllib.parse.urlencode( + { + "eventType": temporalio.api.enums.v1.EventType.Name( + workflow_event.event_ref.event_type + ), + "referenceType": "EventReference", + } + ) + return nexusrpc.Link( + url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")), + type=workflow_event.DESCRIPTOR.full_name, + ) + + +_LINK_URL_PATH_REGEX = re.compile( + r"^/namespaces/(?P[^/]+)/workflows/(?P[^/]+)/(?P[^/]+)/history$" +) + + +def _nexus_link_to_workflow_event( + link: nexusrpc.Link, +) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]: + url = urllib.parse.urlparse(link.url) + match = _LINK_URL_PATH_REGEX.match(url.path) + if not match: + logger.warning( + f"Invalid Nexus link: {link}. Expected path to match {_LINK_URL_PATH_REGEX.pattern}" + ) + return None + try: + query_params = urllib.parse.parse_qs(url.query) + [reference_type] = query_params.get("referenceType", []) + if reference_type != "EventReference": + raise ValueError( + f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}" + ) + [event_type_name] = query_params.get("eventType", []) + event_ref = temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + event_type=temporalio.api.enums.v1.EventType.Value(event_type_name) + ) + except ValueError as err: + logger.warning( + f"Failed to parse event type from Nexus link URL query parameters: {link} ({err})" + ) + event_ref = None + + groups = match.groupdict() + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=urllib.parse.unquote(groups["namespace"]), + workflow_id=urllib.parse.unquote(groups["workflow_id"]), + run_id=urllib.parse.unquote(groups["run_id"]), + event_ref=event_ref, + ) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py new file mode 100644 index 000000000..56b2ccb51 --- /dev/null +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import types +import typing +import warnings +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Optional, + Type, + Union, +) + +import nexusrpc.handler +from nexusrpc.handler import ( + CancelOperationContext, + HandlerError, + HandlerErrorType, + StartOperationContext, +) +from nexusrpc.types import ( + InputT, + OutputT, + ServiceHandlerT, +) +from typing_extensions import Self, overload + +from ._operation_context import TemporalNexusOperationContext +from ._token import ( + WorkflowOperationToken as WorkflowOperationToken, +) + +if TYPE_CHECKING: + from temporalio.client import ( + Client, + WorkflowHandle, + ) + + +async def cancel_workflow( + ctx: CancelOperationContext, + token: str, + client: Optional[Client] = None, # noqa + **kwargs: Any, +) -> None: + client = client or TemporalNexusOperationContext.current().client + try: + decoded = WorkflowOperationToken.decode(token) + except Exception as err: + raise HandlerError( + "Failed to decode workflow operation token", + type=HandlerErrorType.NOT_FOUND, + cause=err, + ) + try: + handle = decoded.to_workflow_handle(client) + except Exception as err: + raise HandlerError( + "Failed to construct workflow handle from workflow operation token", + type=HandlerErrorType.NOT_FOUND, + cause=err, + ) + await handle.cancel(**kwargs) + + +class WorkflowRunOperationHandler( + nexusrpc.handler.OperationHandler[InputT, OutputT], + Generic[InputT, OutputT, ServiceHandlerT], +): + def __init__( + self, + service: ServiceHandlerT, + start_method: Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowHandle[Any, OutputT]], + ], + output_type: Optional[Type] = None, + ): + self.service = service + + @wraps(start_method) + async def start( + self, ctx: StartOperationContext, input: InputT + ) -> WorkflowRunOperationResult: + wf_handle = await start_method(service, ctx, input) + return WorkflowRunOperationResult.from_workflow_handle(wf_handle) + + self.start = types.MethodType(start, self) + + async def start( + self, ctx: StartOperationContext, input: InputT + ) -> nexusrpc.handler.StartOperationResultAsync: + raise NotImplementedError( + "The start method of a WorkflowRunOperation should be set " + "dynamically in the __init__ method. (Did you forget to call super()?)" + ) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + await cancel_workflow(ctx, token) + + def fetch_info( + self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str + ) -> Union[ + nexusrpc.handler.OperationInfo, Awaitable[nexusrpc.handler.OperationInfo] + ]: + raise NotImplementedError( + "Temporal Nexus operation handlers do not support fetching operation info." + ) + + def fetch_result( + self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str + ) -> Union[OutputT, Awaitable[OutputT]]: + raise NotImplementedError( + "Temporal Nexus operation handlers do not support fetching operation results." + ) + + +class WorkflowRunOperationResult(nexusrpc.handler.StartOperationResultAsync): + """ + A value returned by the start method of a :class:`WorkflowRunOperation`. + + It indicates that the operation is responding asynchronously, and contains a token + that the handler can use to construct a :class:`~temporalio.client.WorkflowHandle` to + interact with the workflow. + """ + + @classmethod + def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: + """ + Create a :class:`WorkflowRunOperationResult` from a :py:class:`~temporalio.client.WorkflowHandle`. + """ + token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() + return cls(token=token) + + def to_workflow_handle(self, client: Client) -> WorkflowHandle: + """ + Create a :py:class:`~temporalio.client.WorkflowHandle` from a :class:`WorkflowRunOperationResult`. + """ + workflow_operation_token = WorkflowOperationToken.decode(self.token) + if workflow_operation_token.namespace != client.namespace: + raise ValueError( + "Cannot create a workflow handle from a workflow operation result " + "with a client whose namespace is not the same as the namespace of the " + "workflow operation token." + ) + return WorkflowOperationToken.decode(self.token).to_workflow_handle(client) + + +@overload +def workflow_run_operation_handler( + start_method: Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowHandle[Any, OutputT]], + ], +) -> Callable[ + [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] +]: ... + + +@overload +def workflow_run_operation_handler( + *, + name: Optional[str] = None, +) -> Callable[ + [ + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowHandle[Any, OutputT]], + ] + ], + Callable[ + [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] + ], +]: ... + + +def workflow_run_operation_handler( + start_method: Optional[ + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowHandle[Any, OutputT]], + ] + ] = None, + *, + name: Optional[str] = None, +) -> Union[ + Callable[ + [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] + ], + Callable[ + [ + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowHandle[Any, OutputT]], + ] + ], + Callable[ + [ServiceHandlerT], + WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT], + ], + ], +]: + def decorator( + start_method: Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowHandle[Any, OutputT]], + ], + ) -> Callable[ + [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] + ]: + input_type, output_type = ( + _get_workflow_run_start_method_input_and_output_type_annotations( + start_method + ) + ) + + def factory( + service: ServiceHandlerT, + ) -> WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT]: + return WorkflowRunOperationHandler( + service, start_method, output_type=output_type + ) + + # TODO(nexus-prerelease): handle callable instances: __class__.__name__ as in sync_operation_handler + method_name = getattr(start_method, "__name__", None) + if not method_name and callable(start_method): + method_name = start_method.__class__.__name__ + if not method_name: + raise TypeError( + f"Could not determine operation method name: " + f"expected {start_method} to be a function or callable instance." + ) + + factory.__nexus_operation__ = nexusrpc.Operation( + name=name or method_name, + method_name=method_name, + input_type=input_type, + output_type=output_type, + ) + + return factory + + if start_method is None: + return decorator + + return decorator(start_method) + + +def _get_workflow_run_start_method_input_and_output_type_annotations( + start_method: Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowHandle[Any, OutputT]], + ], +) -> tuple[ + Optional[Type[InputT]], + Optional[Type[OutputT]], +]: + """Return operation input and output types. + + `start_method` must be a type-annotated start method that returns a + :py:class:`WorkflowHandle`. + """ + # TODO(nexus-preview) circular import + from temporalio.client import WorkflowHandle + + input_type, output_type = ( + nexusrpc.handler.get_start_method_input_and_output_types_annotations( + start_method + ) + ) + origin_type = typing.get_origin(output_type) + if not origin_type or not issubclass(origin_type, WorkflowHandle): + warnings.warn( + f"Expected return type of {start_method.__name__} to be a subclass of WorkflowHandle, " + f"but is {output_type}" + ) + output_type = None + + args = typing.get_args(output_type) + if len(args) != 2: + warnings.warn( + f"Expected return type of {start_method.__name__} to have exactly two type parameters, " + f"but has {len(args)}: {args}" + ) + output_type = None + else: + _wf_type, output_type = args + return input_type, output_type diff --git a/temporalio/nexus/token.py b/temporalio/nexus/handler/_token.py similarity index 82% rename from temporalio/nexus/token.py rename to temporalio/nexus/handler/_token.py index d357ecb9c..bf08198e4 100644 --- a/temporalio/nexus/token.py +++ b/temporalio/nexus/handler/_token.py @@ -3,9 +3,10 @@ import base64 import json from dataclasses import dataclass -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional -from temporalio.client import Client, WorkflowHandle +if TYPE_CHECKING: + from temporalio.client import Client, WorkflowHandle OPERATION_TOKEN_TYPE_WORKFLOW = 1 OperationTokenType = Literal[1] @@ -53,41 +54,43 @@ def encode(self) -> str: ) @classmethod - def decode(cls, data: str) -> WorkflowOperationToken: + def decode(cls, token: str) -> WorkflowOperationToken: """Decodes and validates a token from its base64url-encoded string representation.""" - if not data: + if not token: raise TypeError("invalid workflow token: token is empty") try: - decoded_bytes = _base64url_decode_no_padding(data) + decoded_bytes = _base64url_decode_no_padding(token) except Exception as err: raise TypeError("failed to decode token as base64url") from err try: - token = json.loads(decoded_bytes.decode("utf-8")) + workflow_operation_token = json.loads(decoded_bytes.decode("utf-8")) except Exception as err: raise TypeError("failed to unmarshal workflow operation token") from err - if not isinstance(token, dict): - raise TypeError(f"invalid workflow token: expected dict, got {type(token)}") + if not isinstance(workflow_operation_token, dict): + raise TypeError( + f"invalid workflow token: expected dict, got {type(workflow_operation_token)}" + ) - _type = token.get("t") + _type = workflow_operation_token.get("t") if _type != OPERATION_TOKEN_TYPE_WORKFLOW: raise TypeError( f"invalid workflow token type: {_type}, expected: {OPERATION_TOKEN_TYPE_WORKFLOW}" ) - version = token.get("v") + version = workflow_operation_token.get("v") if version is not None and version != 0: raise TypeError( "invalid workflow token: 'v' field, if present, must be 0 or null/absent" ) - workflow_id = token.get("wid") + workflow_id = workflow_operation_token.get("wid") if not workflow_id or not isinstance(workflow_id, str): raise TypeError( "invalid workflow token: missing, empty, or non-string workflow ID (wid)" ) - namespace = token.get("ns") + namespace = workflow_operation_token.get("ns") if namespace is None or not isinstance(namespace, str): # Allow empty string for ns, but it must be present and a string raise TypeError( diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 7e0a1d35b..6f6965093 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -15,11 +15,14 @@ Optional, Sequence, Type, - TypeVar, Union, ) import nexusrpc.handler +from nexusrpc.types import ( + InputT, + OutputT, +) import temporalio.activity import temporalio.api.common.v1 @@ -287,29 +290,24 @@ class StartChildWorkflowInput: ret_type: Optional[Type] -# TODO(nexus-prerelease): Put these in a better location. Type variance? -I = TypeVar("I") -O = TypeVar("O") - - @dataclass -class StartNexusOperationInput(Generic[I, O]): +class StartNexusOperationInput(Generic[InputT, OutputT]): """Input for :py:meth:`WorkflowOutboundInterceptor.start_nexus_operation`.""" endpoint: str service: str operation: Union[ - nexusrpc.Operation[I, O], - Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + nexusrpc.Operation[InputT, OutputT], + Callable[[Any], nexusrpc.handler.OperationHandler[InputT, OutputT]], str, ] - input: I + input: InputT schedule_to_close_timeout: Optional[timedelta] headers: Optional[Mapping[str, str]] - output_type: Optional[Type[O]] = None + output_type: Optional[Type[OutputT]] = None _operation_name: str = field(init=False, repr=False) - _input_type: Optional[Type[I]] = field(init=False, repr=False) + _input_type: Optional[Type[InputT]] = field(init=False, repr=False) def __post_init__(self) -> None: if isinstance(self.operation, str): @@ -336,8 +334,9 @@ def __post_init__(self) -> None: def operation_name(self) -> str: return self._operation_name + # TODO(nexus-prerelease) contravariant type in output @property - def input_type(self) -> Optional[Type[I]]: + def input_type(self) -> Optional[Type[InputT]]: return self._input_type diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index e8c57c4f8..fdb41c762 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -10,6 +10,7 @@ from typing import ( Any, Callable, + NoReturn, Optional, Sequence, Type, @@ -17,7 +18,12 @@ import google.protobuf.json_format import nexusrpc.handler -from nexusrpc.handler._core import SyncExecutor +from nexusrpc import LazyValueAsync as LazyValue +from nexusrpc.handler import ( + CancelOperationContext, + StartOperationContext, +) +from nexusrpc.handler import HandlerAsync as Handler import temporalio.api.common.v1 import temporalio.api.enums.v1 @@ -31,6 +37,9 @@ import temporalio.nexus import temporalio.nexus.handler from temporalio.exceptions import ApplicationError +from temporalio.nexus.handler import ( + TemporalNexusOperationContext, +) from temporalio.service import RPCError, RPCStatusCode from ._interceptor import Interceptor @@ -45,79 +54,94 @@ def __init__( bridge_worker: Callable[[], temporalio.bridge.worker.Worker], client: temporalio.client.Client, task_queue: str, - nexus_services: Sequence[Any], + service_handlers: Sequence[Any], data_converter: temporalio.converter.DataConverter, interceptors: Sequence[Interceptor], metric_meter: temporalio.common.MetricMeter, - executor: Optional[concurrent.futures.ThreadPoolExecutor], + executor: Optional[concurrent.futures.Executor], ) -> None: - # TODO(nexus-prerelease): make it possible to query task queue of bridge worker - # instead of passing unused task_queue into _NexusWorker, - # _ActivityWorker, etc? + # TODO: make it possible to query task queue of bridge worker instead of passing + # unused task_queue into _NexusWorker, _ActivityWorker, etc? self._bridge_worker = bridge_worker self._client = client self._task_queue = task_queue - for service in nexus_services: + for service in service_handlers: if isinstance(service, type): raise TypeError( f"Expected a service instance, but got a class: {service}. " "Nexus services must be passed as instances, not classes." ) - self._handler = nexusrpc.handler.Handler( - nexus_services, - SyncExecutor(executor) if executor is not None else None, - ) + self._handler = Handler(service_handlers, executor) self._data_converter = data_converter # TODO(nexus-prerelease): interceptors self._interceptors = interceptors # TODO(nexus-prerelease): metric_meter self._metric_meter = metric_meter - self._running_operations: dict[bytes, asyncio.Task[Any]] = {} + self._running_tasks: dict[bytes, asyncio.Task[Any]] = {} + self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue() async def run(self) -> None: + """ + Continually poll for Nexus tasks and dispatch to handlers. + """ + + async def raise_from_exception_queue() -> NoReturn: + raise await self._fail_worker_exception_queue.get() + + exception_task = asyncio.create_task(raise_from_exception_queue()) + while True: try: poll_task = asyncio.create_task(self._bridge_worker().poll_nexus_task()) - except Exception as err: - raise RuntimeError("Nexus worker failed") from err - - task = await poll_task + await asyncio.wait( + [poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED + ) + if exception_task.done(): + poll_task.cancel() + await exception_task + task = await poll_task - if task.HasField("task"): - task = task.task - if task.request.HasField("start_operation"): - self._running_operations[task.task_token] = asyncio.create_task( - self._run_nexus_operation( - task.task_token, - task.request.start_operation, - dict(task.request.header), + if task.HasField("task"): + task = task.task + if task.request.HasField("start_operation"): + self._running_tasks[task.task_token] = asyncio.create_task( + self._handle_start_operation_task( + task.task_token, + task.request.start_operation, + dict(task.request.header), + ) ) - ) - elif task.request.HasField("cancel_operation"): - # TODO(nexus-prerelease): report errors occurring during execution of user - # cancellation method - asyncio.create_task( - self._handle_cancel_operation( - task.request.cancel_operation, task.task_token + elif task.request.HasField("cancel_operation"): + # TODO(nexus-prerelease): do we need to track cancel operation + # tasks as we do start operation tasks? + asyncio.create_task( + self._handle_cancel_operation_task( + task.request.cancel_operation, task.task_token + ) + ) + else: + raise NotImplementedError( + f"Invalid Nexus task request: {task.request}" + ) + elif task.HasField("cancel_task"): + task = task.cancel_task + if _task := self._running_tasks.get(task.task_token): + # TODO(nexus-prerelease): when do we remove the entry from _running_operations? + _task.cancel() + else: + temporalio.nexus.handler.logger.warning( + f"Received cancel_task but no running operation exists for " + f"task token: {task.task_token}" ) - ) - else: - raise NotImplementedError( - f"Invalid Nexus task request: {task.request}" - ) - elif task.HasField("cancel_task"): - task = task.cancel_task - if _task := self._running_operations.get(task.task_token): - # TODO(nexus-prerelease): when do we remove the entry from _running_operations? - _task.cancel() else: - temporalio.nexus.logger.warning( - f"Received cancel_task but no running operation exists for " - f"task token: {task.task_token}" - ) - else: - raise NotImplementedError(f"Invalid Nexus task: {task}") + raise NotImplementedError(f"Invalid Nexus task: {task}") + + # TODO(nexus-prerelease): handle poller shutdown + # except temporalio.bridge.worker.PollShutdownError + + except Exception as err: + raise RuntimeError("Nexus worker failed") from err # Only call this if run() raised an error async def drain_poll_queue(self) -> None: @@ -133,183 +157,175 @@ async def drain_poll_queue(self) -> None: except temporalio.bridge.worker.PollShutdownError: return + # Only call this after run()/drain_poll_queue() have returned. This will not + # raise an exception. async def wait_all_completed(self) -> None: - await asyncio.gather( - *self._running_operations.values(), return_exceptions=False - ) + await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) # TODO(nexus-prerelease): stack trace pruning. See sdk-typescript NexusHandler.execute # "Any call up to this function and including this one will be trimmed out of stack traces."" - async def _run_nexus_operation( + async def _handle_cancel_operation_task( + self, request: temporalio.api.nexus.v1.CancelOperationRequest, task_token: bytes + ) -> None: + """ + Handle a cancel operation task. + + Attempt to execute the user cancel_operation method. Handle errors and send the + task completion. + """ + ctx = CancelOperationContext( + service=request.service, + operation=request.operation, + ) + TemporalNexusOperationContext.set( + TemporalNexusOperationContext( + nexus_operation_context=ctx, + client=self._client, + task_queue=self._task_queue, + ) + ) + # TODO(nexus-prerelease): headers + try: + await self._handler.cancel_operation(ctx, request.operation_token) + except Exception as err: + temporalio.nexus.handler.logger.exception( + "Failed to execute Nexus cancel operation method" + ) + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + error=await self._handler_error_to_proto( + _exception_to_handler_error(err) + ), + ) + else: + # TODO(nexus-prerelease): when do we use ack_cancel? + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + completed=temporalio.api.nexus.v1.Response( + cancel_operation=temporalio.api.nexus.v1.CancelOperationResponse() + ), + ) + try: + await self._bridge_worker().complete_nexus_task(completion) + except Exception: + temporalio.nexus.handler.logger.exception( + "Failed to send Nexus task completion" + ) + + async def _handle_start_operation_task( self, task_token: bytes, start_request: temporalio.api.nexus.v1.StartOperationRequest, - header: dict[str, str], + headers: dict[str, str], ) -> None: - async def run() -> temporalio.bridge.proto.nexus.NexusTaskCompletion: - temporalio.nexus.handler._current_context.set( - temporalio.nexus.handler._Context( - client=self._client, - task_queue=self._task_queue, - service=start_request.service, - operation=start_request.operation, - ) - ) - try: - ctx = nexusrpc.handler.StartOperationContext( - service=start_request.service, - operation=start_request.operation, - headers=header, - request_id=start_request.request_id, - callback_url=start_request.callback, - inbound_links=[ - nexusrpc.handler.Link(url=l.url, type=l.type) - for l in start_request.links - ], - callback_headers=dict(start_request.callback_header), - ) - input = nexusrpc.handler.LazyValue( - serializer=_DummyPayloadSerializer( - data_converter=self._data_converter, - payload=start_request.payload, - ), - headers={}, - stream=None, - ) - try: - result = await self._handler.start_operation(ctx, input) - except ( - nexusrpc.handler.UnknownServiceError, - nexusrpc.handler.UnknownOperationError, - ) as err: - # TODO(nexus-prerelease): error message - raise nexusrpc.handler.HandlerError( - "No matching operation handler", - type=nexusrpc.handler.HandlerErrorType.NOT_FOUND, - cause=err, - retryable=False, - ) from err - - except nexusrpc.handler.OperationError as err: - return temporalio.bridge.proto.nexus.NexusTaskCompletion( - task_token=task_token, - completed=temporalio.api.nexus.v1.Response( - start_operation=temporalio.api.nexus.v1.StartOperationResponse( - operation_error=await self._operation_error_to_proto(err), - ), - ), - ) - except BaseException as err: - handler_err = _exception_to_handler_error(err) - return temporalio.bridge.proto.nexus.NexusTaskCompletion( - task_token=task_token, - error=temporalio.api.nexus.v1.HandlerError( - error_type=handler_err.type.value, - failure=await self._exception_to_failure_proto( - handler_err.__cause__ - ), - retry_behavior=( - temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE - if handler_err.retryable - else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE - ), - ), - ) - else: - if isinstance(result, nexusrpc.handler.StartOperationResultAsync): - op_resp = temporalio.api.nexus.v1.StartOperationResponse( - async_success=temporalio.api.nexus.v1.StartOperationResponse.Async( - operation_token=result.token, - links=[ - temporalio.api.nexus.v1.Link(url=l.url, type=l.type) - for l in ctx.outbound_links - ], - ) - ) - elif isinstance(result, nexusrpc.handler.StartOperationResultSync): - # TODO(nexus-prerelease): error handling here; what error type should it be? - [payload] = await self._data_converter.encode([result.value]) - op_resp = temporalio.api.nexus.v1.StartOperationResponse( - sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync( - payload=payload - ) - ) - else: - # TODO(nexus-prerelease): what should the error response be when the user has failed to wrap their return type? - # TODO(nexus-prerelease): unify this failure completion with the path above - err = TypeError( - "Operation start method must return either nexusrpc.handler.StartOperationResultSync " - "or nexusrpc.handler.StartOperationResultAsync" - ) - handler_err = _exception_to_handler_error(err) - return temporalio.bridge.proto.nexus.NexusTaskCompletion( - task_token=task_token, - error=temporalio.api.nexus.v1.HandlerError( - error_type=handler_err.type.value, - failure=await self._exception_to_failure_proto( - handler_err.__cause__ - ), - retry_behavior=( - temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE - if handler_err.retryable - else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE - ), - ), - ) + """ + Handle a start operation task. - return temporalio.bridge.proto.nexus.NexusTaskCompletion( - task_token=task_token, - completed=temporalio.api.nexus.v1.Response(start_operation=op_resp), - ) + Attempt to execute the user start_operation method and invoke the data converter + on the result. Handle errors and send the task completion. + """ + + try: + start_response = await self._start_operation(start_request, headers) + # TODO(nexus-prerelease): handle BrokenExecutor by failing the worker + except BaseException as err: + handler_err = _exception_to_handler_error(err) + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + error=await self._handler_error_to_proto(handler_err), + ) + else: + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + completed=temporalio.api.nexus.v1.Response( + start_operation=start_response + ), + ) try: - completion = await run() await self._bridge_worker().complete_nexus_task(completion) except Exception: - temporalio.nexus.logger.exception("Failed completing Nexus operation") + temporalio.nexus.handler.logger.exception( + "Failed to send Nexus task completion" + ) finally: try: - del self._running_operations[task_token] + del self._running_tasks[task_token] except KeyError: - temporalio.nexus.logger.exception( + temporalio.nexus.handler.logger.exception( "Failed to remove completed Nexus operation" ) - async def _handle_cancel_operation( - self, request: temporalio.api.nexus.v1.CancelOperationRequest, task_token: bytes - ) -> None: - temporalio.nexus.handler._current_context.set( - temporalio.nexus.handler._Context( + async def _start_operation( + self, + start_request: temporalio.api.nexus.v1.StartOperationRequest, + headers: dict[str, str], + ) -> temporalio.api.nexus.v1.StartOperationResponse: + """ + Invoke the Nexus handler's start_operation method and construct the StartOperationResponse. + + OperationError is handled by this function, since it results in a StartOperationResponse. + + All other exceptions are handled by a caller of this function. + """ + ctx = StartOperationContext( + service=start_request.service, + operation=start_request.operation, + headers=headers, + request_id=start_request.request_id, + callback_url=start_request.callback, + inbound_links=[ + nexusrpc.Link(url=link.url, type=link.type) + for link in start_request.links + ], + callback_headers=dict(start_request.callback_header), + ) + TemporalNexusOperationContext.set( + TemporalNexusOperationContext( + nexus_operation_context=ctx, client=self._client, task_queue=self._task_queue, - service=request.service, - operation=request.operation, ) ) - ctx = nexusrpc.handler.CancelOperationContext( - service=request.service, - operation=request.operation, - ) - # TODO(nexus-prerelease): header - try: - await self._handler.cancel_operation(ctx, request.operation_token) - except Exception as err: - temporalio.nexus.logger.exception( - "Failed to execute Nexus operation cancel method", err - ) - # TODO(nexus-prerelease): when do we use ack_cancel? - completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( - task_token=task_token, - completed=temporalio.api.nexus.v1.Response( - cancel_operation=temporalio.api.nexus.v1.CancelOperationResponse() + input = LazyValue( + serializer=_DummyPayloadSerializer( + data_converter=self._data_converter, + payload=start_request.payload, ), + headers={}, + stream=None, ) try: - await self._bridge_worker().complete_nexus_task(completion) - except Exception as err: - temporalio.nexus.logger.exception( - "Failed to send Nexus task completion", err + result = await self._handler.start_operation(ctx, input) + if isinstance(result, nexusrpc.handler.StartOperationResultAsync): + return temporalio.api.nexus.v1.StartOperationResponse( + async_success=temporalio.api.nexus.v1.StartOperationResponse.Async( + operation_token=result.token, + links=[ + temporalio.api.nexus.v1.Link(url=link.url, type=link.type) + for link in ctx.outbound_links + ], + ) + ) + elif isinstance(result, nexusrpc.handler.StartOperationResultSync): + [payload] = await self._data_converter.encode([result.value]) + return temporalio.api.nexus.v1.StartOperationResponse( + sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync( + payload=payload + ) + ) + else: + raise _exception_to_handler_error( + TypeError( + "Operation start method must return either " + "nexusrpc.handler.StartOperationResultSync or " + "nexusrpc.handler.StartOperationResultAsync." + ) + ) + except nexusrpc.handler.OperationError as err: + return temporalio.api.nexus.v1.StartOperationResponse( + operation_error=await self._operation_error_to_proto(err), ) async def _exception_to_failure_proto( @@ -319,7 +335,6 @@ async def _exception_to_failure_proto( api_failure = temporalio.api.failure.v1.Failure() await self._data_converter.encode_failure(err, api_failure) api_failure = google.protobuf.json_format.MessageToDict(api_failure) - # TODO(nexus-prerelease): is metadata correct and playing intended role here? return temporalio.api.nexus.v1.Failure( message=api_failure.pop("message", ""), metadata={"type": "temporal.api.failure.v1.Failure"}, @@ -358,14 +373,14 @@ class _DummyPayloadSerializer: data_converter: temporalio.converter.DataConverter payload: temporalio.api.common.v1.Payload - async def serialize(self, value: Any) -> nexusrpc.handler.Content: + async def serialize(self, value: Any) -> nexusrpc.Content: raise NotImplementedError( "The serialize method of the Serializer is not used by handlers" ) async def deserialize( self, - content: nexusrpc.handler.Content, + content: nexusrpc.Content, as_type: Optional[Type[Any]] = None, ) -> Any: try: @@ -373,6 +388,7 @@ async def deserialize( [self.payload], type_hints=[as_type] if as_type else None, ) + return input except Exception as err: raise nexusrpc.handler.HandlerError( "Data converter failed to decode Nexus operation input", @@ -380,7 +396,6 @@ async def deserialize( cause=err, retryable=False, ) from err - return input # TODO(nexus-prerelease): tests for this function diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 66e1060f4..038ce38aa 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -107,7 +107,10 @@ def __init__( *, task_queue: str, activities: Sequence[Callable] = [], - nexus_services: Sequence[Any] = [], + # TODO(nexus-prerelease): for naming consistency this should be named + # nexus_service_handlers. That will prevent users from mistakenly trying to add + # their service definitions here. + nexus_service_handlers: Sequence[Any] = [], workflows: Sequence[Type] = [], activity_executor: Optional[concurrent.futures.Executor] = None, workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, @@ -159,8 +162,8 @@ def __init__( activities: Activity callables decorated with :py:func:`@activity.defn`. Activities may be async functions or non-async functions. - nexus_services: Nexus service instances decorated with - :py:func:`@nexusrpc.handler.service_handler`. + nexus_service_handlers: Nexus service handler instances decorated with + :py:func:`@nexusrpc.handler.service_handler`. workflows: Workflow classes decorated with :py:func:`@workflow.defn`. activity_executor: Concurrent executor to use for non-async @@ -316,7 +319,7 @@ def __init__( # is issued. # max_concurrent_nexus_operations: Maximum number of Nexus operations that # will ever be given to the Nexus worker concurrently. Mutually exclusive with ``tuner``. - if not (activities or nexus_services or workflows): + if not (activities or nexus_service_handlers or workflows): raise ValueError( "At least one activity, Nexus service, or workflow must be specified" ) @@ -415,14 +418,14 @@ def __init__( metric_meter=self._runtime.metric_meter, ) self._nexus_worker: Optional[_NexusWorker] = None - if nexus_services: + if nexus_service_handlers: # TODO(nexus-prerelease): consider not allowing / warning on max_workers < # max_concurrent_nexus_operations? See warning above for activity worker. self._nexus_worker = _NexusWorker( bridge_worker=lambda: self._bridge_worker, client=client, task_queue=task_queue, - nexus_services=nexus_services, + service_handlers=nexus_service_handlers, data_converter=client_config["data_converter"], interceptors=interceptors, metric_meter=self._runtime.metric_meter, diff --git a/tests/conftest.py b/tests/conftest.py index f3baa1b72..48df7285e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,13 +123,13 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) + env._http_port = http_port # type: ignore elif env_type == "time-skipping": env = await WorkflowEnvironment.start_time_skipping() else: env = WorkflowEnvironment.from_client(await Client.connect(env_type)) # TODO(nexus-prerelease): expose this in a principled way - env._http_port = http_port # type: ignore yield env await env.shutdown() diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index 878111438..c1225136c 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -1,13 +1,11 @@ +from dataclasses import dataclass +from typing import Any, Mapping, Optional + +import httpx + import temporalio.api -import temporalio.api.common -import temporalio.api.common.v1 -import temporalio.api.enums.v1 -import temporalio.api.nexus import temporalio.api.nexus.v1 -import temporalio.api.operatorservice import temporalio.api.operatorservice.v1 -import temporalio.nexus -import temporalio.nexus.handler from temporalio.client import Client @@ -35,3 +33,62 @@ async def create_nexus_endpoint( ) ) ) + + +@dataclass +class ServiceClient: + server_address: str # E.g. http://127.0.0.1:7243 + endpoint: str + service: str + + async def start_operation( + self, + operation: str, + body: Optional[dict[str, Any]] = None, + headers: Mapping[str, str] = {}, + ) -> httpx.Response: + """ + Start a Nexus operation. + """ + async with httpx.AsyncClient() as http_client: + return await http_client.post( + f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}", + json=body, + headers=headers, + ) + + async def fetch_operation_info( + self, + operation: str, + token: str, + ) -> httpx.Response: + async with httpx.AsyncClient() as http_client: + return await http_client.get( + f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}", + # Token can also be sent as "Nexus-Operation-Token" header + params={"token": token}, + ) + + async def fetch_operation_result( + self, + operation: str, + token: str, + ) -> httpx.Response: + async with httpx.AsyncClient() as http_client: + return await http_client.get( + f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/result", + # Token can also be sent as "Nexus-Operation-Token" header + params={"token": token}, + ) + + async def cancel_operation( + self, + operation: str, + token: str, + ) -> httpx.Response: + async with httpx.AsyncClient() as http_client: + return await http_client.post( + f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/cancel", + # Token can also be sent as "Nexus-Operation-Token" header + params={"token": token}, + ) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index da3925e80..c9c24a8f9 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -1,4 +1,5 @@ import uuid +from typing import Any import httpx import nexusrpc.handler @@ -32,7 +33,7 @@ def make_incrementer_user_service_definition_and_service_handler_classes( # service handler # async def _increment_op( - self, + self: Any, ctx: nexusrpc.handler.StartOperationContext, input: int, ) -> int: @@ -42,6 +43,7 @@ async def _increment_op( # TODO(nexus-prerelease): check that name=name should be required here. Should the op factory # name not default to the name of the method attribute (i.e. key), as opposed to # the name of the method object (i.e. value.__name__)? + # TODO(nexus-prerelease): type error name: nexusrpc.handler.sync_operation_handler(_increment_op, name=name) for name in op_names } @@ -71,7 +73,7 @@ async def test_dynamic_creation_of_user_handler_classes(client: Client): async with Worker( client, task_queue=task_queue, - nexus_services=[handler_cls()], + nexus_service_handlers=[handler_cls()], ): async with httpx.AsyncClient() as http_client: response = await http_client.post( diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 981df21e0..b54198050 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -18,6 +18,7 @@ import dataclasses import json import logging +import pprint import uuid from concurrent.futures.thread import ThreadPoolExecutor from dataclasses import dataclass @@ -29,19 +30,25 @@ import nexusrpc.handler import pytest from google.protobuf import json_format -from nexusrpc.testing.client import ServiceClient +from nexusrpc.handler import ( + CancelOperationContext, + StartOperationContext, +) import temporalio.api.failure.v1 import temporalio.nexus from temporalio import workflow from temporalio.client import Client, WorkflowHandle +from temporalio.common import WorkflowIDReusePolicy from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError -from temporalio.nexus import logger -from temporalio.nexus.handler import start_workflow +from temporalio.nexus.handler import ( + logger, +) +from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import create_nexus_endpoint +from tests.helpers.nexus import ServiceClient, create_nexus_endpoint HTTP_PORT = 7243 @@ -56,6 +63,19 @@ class Output: value: str +@dataclass +class NonSerializableOutput: + callable: Callable[[], Any] = lambda: None + + +@dataclass +class TestContext: + workflow_id: Optional[str] = None + + +test_context = TestContext() + + # TODO: type check nexus implementation under mypy # TODO(nexus-prerelease): test dynamic creation of a service from unsugared definition @@ -73,8 +93,8 @@ class MyService: # ) hang: nexusrpc.Operation[Input, Output] log: nexusrpc.Operation[Input, Output] - async_operation: nexusrpc.Operation[Input, Output] - async_operation_without_type_annotations: nexusrpc.Operation[Input, Output] + workflow_run_operation: nexusrpc.Operation[Input, Output] + workflow_run_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_with_non_async_def: nexusrpc.Operation[Input, Output] sync_operation_with_non_async_callable_instance: nexusrpc.Operation[Input, Output] @@ -87,6 +107,8 @@ class MyService: workflow_run_op_link_test: nexusrpc.Operation[Input, Output] handler_error_internal: nexusrpc.Operation[Input, Output] operation_error_failed: nexusrpc.Operation[Input, Output] + idempotency_check: nexusrpc.Operation[None, Output] + non_serializable_output: nexusrpc.Operation[Input, NonSerializableOutput] @workflow.defn @@ -116,9 +138,7 @@ async def run(self, input: Input) -> Output: # The service_handler decorator is applied by the test class MyServiceHandler: @nexusrpc.handler.sync_operation_handler - async def echo( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> Output: + async def echo(self, ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) return Output( @@ -126,15 +146,13 @@ async def echo( ) @nexusrpc.handler.sync_operation_handler - async def hang( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> Output: + async def hang(self, ctx: StartOperationContext, input: Input) -> Output: await asyncio.Future() return Output(value="won't reach here") @nexusrpc.handler.sync_operation_handler async def non_retryable_application_error( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + self, ctx: StartOperationContext, input: Input ) -> Output: raise ApplicationError( "non-retryable application error", @@ -146,7 +164,7 @@ async def non_retryable_application_error( @nexusrpc.handler.sync_operation_handler async def retryable_application_error( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + self, ctx: StartOperationContext, input: Input ) -> Output: raise ApplicationError( "retryable application error", @@ -157,7 +175,7 @@ async def retryable_application_error( @nexusrpc.handler.sync_operation_handler async def handler_error_internal( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + self, ctx: StartOperationContext, input: Input ) -> Output: raise nexusrpc.handler.HandlerError( message="deliberate internal handler error", @@ -168,7 +186,7 @@ async def handler_error_internal( @nexusrpc.handler.sync_operation_handler async def operation_error_failed( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + self, ctx: StartOperationContext, input: Input ) -> Output: raise nexusrpc.handler.OperationError( message="deliberate operation error", @@ -177,7 +195,7 @@ async def operation_error_failed( @nexusrpc.handler.sync_operation_handler async def check_operation_timeout_header( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + self, ctx: StartOperationContext, input: Input ) -> Output: assert "operation-timeout" in ctx.headers return Output( @@ -185,27 +203,26 @@ async def check_operation_timeout_header( ) @nexusrpc.handler.sync_operation_handler - async def log( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> Output: + async def log(self, ctx: StartOperationContext, input: Input) -> Output: logger.info("Logging from start method", extra={"input_value": input.value}) return Output(value=f"logged: {input.value}") @temporalio.nexus.handler.workflow_run_operation_handler - async def async_operation( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + async def workflow_run_operation( + self, ctx: StartOperationContext, input: Input ) -> WorkflowHandle[Any, Output]: - assert "operation-timeout" in ctx.headers - return await start_workflow( - ctx, + tctx = TemporalNexusOperationContext.current() + return await tctx.client.start_workflow( MyWorkflow.run, input, - id=str(uuid.uuid4()), + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) @nexusrpc.handler.sync_operation_handler def sync_operation_with_non_async_def( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + self, ctx: StartOperationContext, input: Input ) -> Output: return Output( value=f"from start method on {self.__class__.__name__}: {input.value}" @@ -215,7 +232,7 @@ class sync_operation_with_non_async_callable_instance: def __call__( self, _handler: "MyServiceHandler", - ctx: nexusrpc.handler.StartOperationContext, + ctx: StartOperationContext, input: Input, ) -> Output: return Output( @@ -238,28 +255,31 @@ async def sync_operation_without_type_annotations(self, ctx, input): ) @temporalio.nexus.handler.workflow_run_operation_handler - async def async_operation_without_type_annotations(self, ctx, input): - return await start_workflow( - ctx, + async def workflow_run_operation_without_type_annotations(self, ctx, input): + tctx = TemporalNexusOperationContext.current() + return await tctx.client.start_workflow( WorkflowWithoutTypeAnnotations.run, input, - id=str(uuid.uuid4()), + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, ) @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_op_link_test( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input + self, ctx: StartOperationContext, input: Input ) -> WorkflowHandle[Any, Output]: assert any( link.url == "http://inbound-link/" for link in ctx.inbound_links ), "Inbound link not found" assert ctx.request_id == "test-request-id-123", "Request ID mismatch" ctx.outbound_links.extend(ctx.inbound_links) - return await start_workflow( - ctx, + + tctx = TemporalNexusOperationContext.current() + return await tctx.client.start_workflow( MyLinkTestWorkflow.run, input, - id=f"link-test-{uuid.uuid4()}", + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, ) class OperationHandlerReturningUnwrappedResult( @@ -267,7 +287,7 @@ class OperationHandlerReturningUnwrappedResult( ): async def start( self, - ctx: nexusrpc.handler.StartOperationContext, + ctx: StartOperationContext, input: Input, # This return type is a type error, but VSCode doesn't flag it unless # "python.analysis.typeCheckingMode" is set to "strict" @@ -282,18 +302,39 @@ def operation_returning_unwrapped_result_at_runtime_error( ) -> nexusrpc.handler.OperationHandler[Input, Output]: return MyServiceHandler.OperationHandlerReturningUnwrappedResult() + @nexusrpc.handler.sync_operation_handler + async def idempotency_check( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> Output: + return Output(value=f"request_id: {ctx.request_id}") + + @nexusrpc.handler.sync_operation_handler + async def non_serializable_output( + self, ctx: StartOperationContext, input: Input + ) -> NonSerializableOutput: + return NonSerializableOutput() + @dataclass class Failure: + """A Nexus Failure object, with details parsed into an exception. + + https://github.com/nexus-rpc/api/blob/main/SPEC.md#failure + """ + message: str = "" metadata: Optional[dict[str, str]] = None details: Optional[dict[str, Any]] = None - exception: Optional[BaseException] = dataclasses.field(init=False, default=None) + exception_from_details: Optional[BaseException] = dataclasses.field( + init=False, default=None + ) def __post_init__(self) -> None: if self.metadata and (error_type := self.metadata.get("type")): - self.exception = self._instantiate_exception(error_type, self.details) + self.exception_from_details = self._instantiate_exception( + error_type, self.details + ) def _instantiate_exception( self, error_type: str, details: Optional[dict[str, Any]] @@ -334,6 +375,8 @@ class UnsuccessfulResponse: # Expected value of Nexus-Request-Retryable header retryable_header: Optional[bool] failure_message: Union[str, Callable[[str], bool]] + # Is the Nexus Failure expected to have the details field populated? + failure_details: bool = True # Expected value of inverse of non_retryable attribute of exception. retryable_exception: bool = True # TODO(nexus-prerelease): the body of a successful response need not be JSON; test non-JSON-parseable string @@ -358,7 +401,8 @@ def check_response( ) -> None: assert response.status_code == cls.expected.status_code, ( f"expected status code {cls.expected.status_code} " - f"but got {response.status_code} for response content {response.content.decode()}" + f"but got {response.status_code} for response content" + f"{pprint.pformat(response.content.decode())}" ) if not with_service_definition and cls.expected_without_service_definition: expected = cls.expected_without_service_definition @@ -397,13 +441,25 @@ def check_response( else: assert cls.expected.retryable_header is None - if failure.exception: - assert isinstance(failure.exception, ApplicationError) - assert failure.exception.non_retryable == ( + if cls.expected.failure_details: + assert ( + failure.exception_from_details is not None + ), "Expected exception details, but found none." + assert isinstance(failure.exception_from_details, ApplicationError) + + exception_from_failure_details = failure.exception_from_details + if ( + exception_from_failure_details.type == "HandlerError" + and exception_from_failure_details.__cause__ + ): + exception_from_failure_details = ( + exception_from_failure_details.__cause__ + ) + assert isinstance(exception_from_failure_details, ApplicationError) + + assert exception_from_failure_details.non_retryable == ( not cls.expected.retryable_exception ) - else: - print(f"TODO(dan): {cls} did not yield a Failure with exception details") class SyncHandlerHappyPath(_TestCase): @@ -470,7 +526,7 @@ class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): class AsyncHandlerHappyPath(_TestCase): - operation = "async_operation" + operation = "workflow_run_operation" input = Input("hello") headers = {"Operation-Timeout": "777s"} expected = SuccessfulResponse( @@ -479,7 +535,7 @@ class AsyncHandlerHappyPath(_TestCase): class AsyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): - operation = "async_operation_without_type_annotations" + operation = "workflow_run_operation_without_type_annotations" input = Input("hello") expected = SuccessfulResponse( status_code=201, @@ -527,7 +583,7 @@ class OperationHandlerReturningUnwrappedResultError(_FailureTestCase): retryable_header=False, failure_message=( "Operation start method must return either " - "nexusrpc.handler.StartOperationResultSync or nexusrpc.handler.StartOperationResultAsync" + "nexusrpc.handler.StartOperationResultSync or nexusrpc.handler.StartOperationResultAsync." ), ) @@ -541,6 +597,7 @@ class UpstreamTimeoutViaRequestTimeout(_FailureTestCase): retryable_header=None, # This error is returned by the server; it doesn't populate metadata or details, and it # doesn't set temporal-nexus-failure-source. + failure_details=False, failure_message="upstream timeout", headers={ "content-type": "application/json", @@ -565,18 +622,14 @@ class BadRequest(_FailureTestCase): expected = UnsuccessfulResponse( status_code=400, retryable_header=False, - failure_message=lambda s: s.startswith("Failed converting field"), + failure_message=lambda s: s.startswith( + "Data converter failed to decode Nexus operation input" + ), ) -class NonRetryableApplicationError(_FailureTestCase): - operation = "non_retryable_application_error" - expected = UnsuccessfulResponse( - status_code=500, - retryable_header=False, - retryable_exception=False, - failure_message="non-retryable application error", - ) +class _ApplicationErrorTestCase(_FailureTestCase): + """Test cases in which the operation raises an ApplicationError.""" @classmethod def check_response( @@ -584,14 +637,25 @@ def check_response( ) -> None: super().check_response(response, with_service_definition) failure = Failure(**response.json()) - err = failure.exception + assert failure.exception_from_details + assert isinstance(failure.exception_from_details, ApplicationError) + err = failure.exception_from_details.__cause__ assert isinstance(err, ApplicationError) - assert err.non_retryable assert err.type == "TestFailureType" assert err.details == ("details arg",) -class RetryableApplicationError(_FailureTestCase): +class NonRetryableApplicationError(_ApplicationErrorTestCase): + operation = "non_retryable_application_error" + expected = UnsuccessfulResponse( + status_code=500, + retryable_header=False, + retryable_exception=False, + failure_message="non-retryable application error", + ) + + +class RetryableApplicationError(_ApplicationErrorTestCase): operation = "retryable_application_error" expected = UnsuccessfulResponse( status_code=500, @@ -606,7 +670,7 @@ class HandlerErrorInternal(_FailureTestCase): status_code=500, # TODO(nexus-prerelease): check this assertion retryable_header=False, - failure_message="cause message", + failure_message="deliberate internal handler error", ) @@ -642,6 +706,15 @@ class UnknownOperation(_FailureTestCase): ) +class NonSerializableOutputFailure(_FailureTestCase): + operation = "non_serializable_output" + expected = UnsuccessfulResponse( + status_code=500, + retryable_header=False, + failure_message="Object of type function is not JSON serializable", + ) + + @pytest.mark.parametrize( "test_case", [ @@ -676,6 +749,7 @@ async def test_start_operation_happy_path( HandlerErrorInternal, UnknownService, UnknownOperation, + NonSerializableOutputFailure, ], ) async def test_start_operation_protocol_level_failures( @@ -708,7 +782,7 @@ async def _test_start_operation( task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - server_address=f"http://127.0.0.1:{env._http_port}", # type: ignore + server_address=server_address(env), endpoint=endpoint, service=( test_case.service_defn @@ -727,7 +801,7 @@ async def _test_start_operation( async with Worker( env.client, task_queue=task_queue, - nexus_services=[service_handler], + nexus_service_handlers=[service_handler], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): response = await service_client.start_operation( @@ -745,7 +819,7 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A resp = await create_nexus_endpoint(task_queue, env.client) endpoint = resp.endpoint.id service_client = ServiceClient( - server_address=f"http://127.0.0.1:{env._http_port}", # type: ignore + server_address=server_address(env), endpoint=endpoint, service=service_name, ) @@ -754,7 +828,7 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A async with Worker( env.client, task_queue=task_queue, - nexus_services=[MyServiceHandler()], + nexus_service_handlers=[MyServiceHandler()], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): response = await service_client.start_operation( @@ -774,7 +848,7 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A ( record for record in caplog.records - if record.name == "temporalio.nexus" + if record.name == "temporalio.nexus.handler" and record.getMessage() == "Logging from start method" ), None, @@ -813,7 +887,7 @@ class EchoService: @nexusrpc.handler.service_handler(service=EchoService) class SyncStartHandler: @nexusrpc.handler.sync_operation_handler - def echo(self, ctx: nexusrpc.handler.StartOperationContext, input: Input) -> Output: + def echo(self, ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) return Output( @@ -824,9 +898,7 @@ def echo(self, ctx: nexusrpc.handler.StartOperationContext, input: Input) -> Out @nexusrpc.handler.service_handler(service=EchoService) class DefaultCancelHandler: @nexusrpc.handler.sync_operation_handler - async def echo( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> Output: + async def echo(self, ctx: StartOperationContext, input: Input) -> Output: return Output( value=f"from start method on {self.__class__.__name__}: {input.value}" ) @@ -837,7 +909,7 @@ class SyncCancelHandler: class SyncCancel(nexusrpc.handler.SyncOperationHandler[Input, Output]): async def start( self, - ctx: nexusrpc.handler.StartOperationContext, + ctx: StartOperationContext, input: Input, # This return type is a type error, but VSCode doesn't flag it unless # "python.analysis.typeCheckingMode" is set to "strict" @@ -846,9 +918,7 @@ async def start( # or StartOperationResultAsync return Output(value="Hello") # type: ignore - def cancel( - self, ctx: nexusrpc.handler.CancelOperationContext, token: str - ) -> Output: + def cancel(self, ctx: CancelOperationContext, token: str) -> Output: return Output(value="Hello") # type: ignore @nexusrpc.handler.operation_handler @@ -890,7 +960,7 @@ async def test_handler_instantiation( Worker( client, task_queue=task_queue, - nexus_services=[test_case.handler()], + nexus_service_handlers=[test_case.handler()], nexus_task_executor=ThreadPoolExecutor() if test_case.executor else None, @@ -899,6 +969,123 @@ async def test_handler_instantiation( Worker( client, task_queue=task_queue, - nexus_services=[test_case.handler()], + nexus_service_handlers=[test_case.handler()], nexus_task_executor=ThreadPoolExecutor() if test_case.executor else None, ) + + +async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): + """Verify that canceling an operation with an invalid token fails correctly.""" + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyService.__name__, + ) + + decorator = nexusrpc.handler.service_handler(service=MyService) + service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + cancel_response = await service_client.cancel_operation( + "workflow_run_operation", + token="this-is-not-a-valid-token", + ) + assert cancel_response.status_code == 404 + failure = Failure(**cancel_response.json()) + assert "failed to decode workflow operation token" in failure.message.lower() + + +async def test_request_id_is_received_by_sync_operation_handler( + env: WorkflowEnvironment, +): + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyService.__name__, + ) + + decorator = nexusrpc.handler.service_handler(service=MyService) + service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + request_id = str(uuid.uuid4()) + resp = await service_client.start_operation( + "idempotency_check", None, {"Nexus-Request-Id": request_id} + ) + assert resp.status_code == 200 + assert resp.json() == {"value": f"request_id: {request_id}"} + + +async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): + # We send two Nexus requests that would start a workflow with the same workflow ID, + # using reuse_policy=REJECT_DUPLICATE. This would fail if they used different + # request IDs. However, when we use the same request ID, it does not fail, + # demonstrating that the Nexus Start Operation request ID has become the + # StartWorkflow request ID. + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyService.__name__, + ) + + decorator = nexusrpc.handler.service_handler(service=MyService) + service_handler = decorator(MyServiceHandler)() + + async def start_two_workflows_with_conflicting_workflow_ids( + request_ids: tuple[tuple[str, int], tuple[str, int]], + ): + test_context.workflow_id = str(uuid.uuid4()) + for request_id, status_code in request_ids: + resp = await service_client.start_operation( + "workflow_run_operation", + dataclass_as_dict(Input("")), + {"Nexus-Request-Id": request_id}, + ) + assert resp.status_code == status_code, ( + f"expected status code {status_code} " + f"but got {resp.status_code} for response content " + f"{pprint.pformat(resp.content.decode())}" + ) + if status_code == 201: + op_info = resp.json() + assert op_info["token"] + assert op_info["state"] == nexusrpc.OperationState.RUNNING.value + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + request_id_1, request_id_2 = str(uuid.uuid4()), str(uuid.uuid4()) + # Reusing the same request ID does not fail + await start_two_workflows_with_conflicting_workflow_ids( + ((request_id_1, 201), (request_id_1, 201)) + ) + # Using a different request ID does fail + # TODO(nexus-prerelease) I think that this should be a 409 per the spec. Go and + # Java are not doing that. + await start_two_workflows_with_conflicting_workflow_ids( + ((request_id_1, 201), (request_id_2, 500)) + ) + + +def server_address(env: WorkflowEnvironment) -> str: + http_port = getattr(env, "_http_port", 7243) + return f"http://127.0.0.1:{http_port}" diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py index dc7fc0dec..bfe850cbb 100644 --- a/tests/nexus/test_handler_async_operation.py +++ b/tests/nexus/test_handler_async_operation.py @@ -24,11 +24,10 @@ StartOperationContext, StartOperationResultAsync, ) -from nexusrpc.testing.client import ServiceClient from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import create_nexus_endpoint +from tests.helpers.nexus import ServiceClient, create_nexus_endpoint @dataclass @@ -155,7 +154,7 @@ async def test_async_operation_lifecycle( async with Worker( env.client, task_queue=task_queue, - nexus_services=[service_handler_cls(task_executor)], + nexus_service_handlers=[service_handler_cls(task_executor)], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): start_response = await service_client.start_operation( @@ -209,16 +208,16 @@ def add_task_sync(self, task_id: str, coro: Coroutine[Any, Any, Any]) -> None: self.add_task(task_id, coro), self.event_loop ).result() - def get_task_status(self, task_id: str) -> nexusrpc.handler.OperationState: + def get_task_status(self, task_id: str) -> nexusrpc.OperationState: task = self.tasks[task_id] if not task.done(): - return nexusrpc.handler.OperationState.RUNNING + return nexusrpc.OperationState.RUNNING elif task.cancelled(): - return nexusrpc.handler.OperationState.CANCELED + return nexusrpc.OperationState.CANCELED elif task.exception(): - return nexusrpc.handler.OperationState.FAILED + return nexusrpc.OperationState.FAILED else: - return nexusrpc.handler.OperationState.SUCCEEDED + return nexusrpc.OperationState.SUCCEEDED async def get_task_result(self, task_id: str) -> Any: """ diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index fce864f20..51b4e66d3 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -38,7 +38,8 @@ async def workflow_run_operation_handler( ) -> WorkflowHandle[Any, Output]: ... expected_operations = { - "workflow_run_operation_handler": nexusrpc.Operation._create( + "workflow_run_operation_handler": nexusrpc.Operation( + name="workflow_run_operation_handler", method_name="workflow_run_operation_handler", input_type=Input, output_type=Output, @@ -66,7 +67,7 @@ async def workflow_run_operation_with_name_override( ) -> WorkflowHandle[Any, Output]: ... expected_operations = { - "workflow_run_operation_with_name_override": nexusrpc.Operation._create( + "workflow_run_operation_with_name_override": nexusrpc.Operation( name="operation-name", method_name="workflow_run_operation_with_name_override", input_type=Input, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index cee54d4b7..fb00316d3 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -2,11 +2,16 @@ import uuid from dataclasses import dataclass from enum import IntEnum -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import nexusrpc import nexusrpc.handler import pytest +from nexusrpc.handler import ( + CancelOperationContext, + FetchOperationInfoContext, + StartOperationContext, +) import temporalio.api import temporalio.api.common @@ -24,11 +29,12 @@ WithStartWorkflowOperation, WorkflowExecutionStatus, WorkflowFailureError, + WorkflowHandle, ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus.handler import WorkflowHandle -from temporalio.nexus.token import WorkflowOperationToken +from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext +from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -88,7 +94,6 @@ class OpInput: @dataclass class OpOutput: value: str - start_options_received_by_handler: Optional[nexusrpc.handler.StartOperationContext] @dataclass @@ -99,7 +104,6 @@ class HandlerWfInput: @dataclass class HandlerWfOutput: value: str - start_options_received_by_handler: Optional[nexusrpc.handler.StartOperationContext] @nexusrpc.service @@ -120,14 +124,12 @@ class HandlerWorkflow: async def run( self, input: HandlerWfInput, - start_options_received_by_handler: nexusrpc.handler.StartOperationContext, ) -> HandlerWfOutput: assert isinstance(input.op_input.response_type, AsyncResponse) if input.op_input.response_type.block_forever_waiting_for_cancellation: await asyncio.Future() return HandlerWfOutput( value="workflow result", - start_options_received_by_handler=start_options_received_by_handler, ) @@ -136,7 +138,7 @@ async def run( class SyncOrAsyncOperation(nexusrpc.handler.OperationHandler[OpInput, OpOutput]): async def start( - self, ctx: nexusrpc.handler.StartOperationContext, input: OpInput + self, ctx: StartOperationContext, input: OpInput ) -> Union[ nexusrpc.handler.StartOperationResultSync[OpOutput], nexusrpc.handler.StartOperationResultAsync, @@ -150,17 +152,15 @@ async def start( ) if isinstance(input.response_type, SyncResponse): return nexusrpc.handler.StartOperationResultSync( - value=OpOutput( - value="sync response", - start_options_received_by_handler=ctx, - ) + value=OpOutput(value="sync response") ) elif isinstance(input.response_type, AsyncResponse): - wf_handle = await temporalio.nexus.handler.start_workflow( - ctx, + tctx = TemporalNexusOperationContext.current() + wf_handle = await tctx.client.start_workflow( HandlerWorkflow.run, - args=[HandlerWfInput(op_input=input), ctx], + args=[HandlerWfInput(op_input=input)], id=input.response_type.operation_workflow_id, + task_queue=tctx.task_queue, ) return nexusrpc.handler.StartOperationResultAsync( WorkflowOperationToken.from_workflow_handle(wf_handle).encode() @@ -168,13 +168,11 @@ async def start( else: raise TypeError - async def cancel( - self, ctx: nexusrpc.handler.CancelOperationContext, token: str - ) -> None: + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: return await temporalio.nexus.handler.cancel_workflow(ctx, token) async def fetch_info( - self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str + self, ctx: FetchOperationInfoContext, token: str ) -> nexusrpc.handler.OperationInfo: raise NotImplementedError @@ -194,7 +192,7 @@ def sync_or_async_operation( @nexusrpc.handler.sync_operation_handler async def sync_operation( - self, ctx: nexusrpc.handler.StartOperationContext, input: OpInput + self, ctx: StartOperationContext, input: OpInput ) -> OpOutput: assert isinstance(input.response_type, SyncResponse) if input.response_type.exception_in_operation_start: @@ -203,14 +201,11 @@ async def sync_operation( RPCStatusCode.INVALID_ARGUMENT, b"", ) - return OpOutput( - value="sync response", - start_options_received_by_handler=ctx, - ) + return OpOutput(value="sync response") @temporalio.nexus.handler.workflow_run_operation_handler async def async_operation( - self, ctx: nexusrpc.handler.StartOperationContext, input: OpInput + self, ctx: StartOperationContext, input: OpInput ) -> WorkflowHandle[HandlerWorkflow, HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: @@ -219,11 +214,12 @@ async def async_operation( RPCStatusCode.INVALID_ARGUMENT, b"", ) - return await temporalio.nexus.handler.start_workflow( - ctx, + tctx = TemporalNexusOperationContext.current() + return await tctx.client.start_workflow( HandlerWorkflow.run, - args=[HandlerWfInput(op_input=input), ctx], + args=[HandlerWfInput(op_input=input)], id=input.response_type.operation_workflow_id, + task_queue=tctx.task_queue, ) @@ -295,12 +291,7 @@ async def run( # transition doesn't happen until the handle is awaited. assert op_handle.cancel() op_output = await op_handle - return CallerWfOutput( - op_output=OpOutput( - value=op_output.value, - start_options_received_by_handler=op_output.start_options_received_by_handler, - ) - ) + return CallerWfOutput(op_output=OpOutput(value=op_output.value)) @workflow.update async def wait_nexus_operation_started(self) -> None: @@ -420,12 +411,7 @@ async def run( headers=op_input.headers, output_type=OpOutput, ) - return CallerWfOutput( - op_output=OpOutput( - value=op_output.value, - start_options_received_by_handler=op_output.start_options_received_by_handler, - ) - ) + return CallerWfOutput(op_output=OpOutput(value=op_output.value)) # ----------------------------------------------------------------------------- @@ -455,7 +441,7 @@ async def test_sync_response( task_queue = str(uuid.uuid4()) async with Worker( client, - nexus_services=[ServiceImpl()], + nexus_service_handlers=[ServiceImpl()], workflows=[CallerWorkflow, HandlerWorkflow], task_queue=task_queue, # TODO(dan): enable sandbox @@ -508,7 +494,6 @@ async def test_sync_response( else: result = await caller_wf_handle.result() assert result.op_output.value == "sync response" - assert result.op_output.start_options_received_by_handler @pytest.mark.parametrize("exception_in_operation_start", [False, True]) @@ -531,7 +516,7 @@ async def test_async_response( task_queue = str(uuid.uuid4()) async with Worker( client, - nexus_services=[ServiceImpl()], + nexus_service_handlers=[ServiceImpl()], workflows=[CallerWorkflow, HandlerWorkflow], task_queue=task_queue, workflow_runner=UnsandboxedWorkflowRunner(), @@ -608,7 +593,6 @@ async def test_async_response( assert handler_wf_info.status == WorkflowExecutionStatus.COMPLETED result = await caller_wf_handle.result() assert result.op_output.value == "workflow result" - assert result.op_output.start_options_received_by_handler async def _start_wf_and_nexus_op( @@ -689,7 +673,7 @@ async def test_untyped_caller( async with Worker( client, workflows=[UntypedCallerWorkflow, HandlerWorkflow], - nexus_services=[ServiceImpl()], + nexus_service_handlers=[ServiceImpl()], task_queue=task_queue, workflow_runner=UnsandboxedWorkflowRunner(), workflow_failure_exception_types=[Exception], @@ -738,7 +722,6 @@ async def test_untyped_caller( if isinstance(response_type, SyncResponse) else "workflow result" ) - assert result.op_output.start_options_received_by_handler # @@ -825,6 +808,7 @@ async def run( elif (caller_reference, name_override) == (C.IMPL_WITH_INTERFACE, N.NO): service_cls = ServiceImplInterfaceWithoutNameOverride elif (caller_reference, name_override) == (C.IMPL_WITHOUT_INTERFACE, N.NO): + service_cls = ServiceImplInterfaceWithNameOverride service_cls = ServiceImplInterfaceWithNeitherInterfaceNorNameOverride else: raise ValueError( @@ -861,7 +845,7 @@ async def test_service_interface_and_implementation_names(client: Client): task_queue = str(uuid.uuid4()) async with Worker( client, - nexus_services=[ + nexus_service_handlers=[ ServiceImplWithNameOverride(), ServiceImplInterfaceWithNameOverride(), ServiceImplInterfaceWithoutNameOverride(), From 7692510f6ea421f01478004308d163c07fa4cf3f Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 13:46:16 -0400 Subject: [PATCH 004/183] Option 1 for workflow_run_operation_handler Pass nexus_operation to start_workflow. But this seems problematic. The user may forget. But it seems that we have no way to prevent them from forgetting since as far as start_workflow knows this may be a valid start_workflow call in a Nexus operation handler, precdeding the "final" / "backing" start_workflow call. --- temporalio/client.py | 32 +++++++++++++++++++++-------- tests/nexus/test_workflow_caller.py | 2 ++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 5ab8b7c0b..6bfeeb253 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -318,6 +318,7 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, + nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -353,6 +354,7 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, + nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -390,6 +392,7 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, + nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -427,6 +430,7 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, + nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -462,6 +466,7 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, + nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -505,6 +510,13 @@ async def start_workflow( UI/CLI. This can be in Temporal markdown format and can span multiple lines. This is a fixed value on the workflow that cannot be updated. For details that can be updated, use :py:meth:`temporalio.workflow.get_current_details` within the workflow. + nexus_operation: An optional + :py:class:`temporalio.nexus.handler.TemporalNexusOperationContext`. If supplied, + it means that the started workflow is backing that Nexus operation. This means that + the workflow result is the Nexus operation result, and will be delivered to the Nexus + caller on workflow completion, and that Nexus bidirectional links will be established + between the caller and the workflow. Do not supply this argument if the workflow is + not backing a Nexus operation. start_delay: Amount of time to wait before starting the workflow. This does not work with ``cron_schedule``. start_signal: If present, this signal is sent as signal-with-start @@ -535,15 +547,17 @@ async def start_workflow( temporalio.workflow._Definition.get_name_and_result_type(workflow) ) nexus_start_ctx = None - if nexus_ctx := TemporalNexusOperationContext.try_current(): - # TODO(prerelease): I think this is too magical: what if a user implements a - # nexus handler by running one workflow to completion, and then starting a - # second workflow to act as the async operation itself? - # TODO(prerelease): What do we do if the Temporal Nexus context client - # (namespace) is not the same as the one being used to start this workflow? - if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: - nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() - workflow_event_links = nexus_start_ctx.get_workflow_event_links() + if nexus_operation: + # TODO(prerelease): check what sdk-typescript does regarding workflow + # options for workflows being started by Nexus operations. + # https://github.com/temporalio/sdk-typescript/blob/nexus/packages/nexus/src/context.ts#L96 + nexus_start_ctx = nexus_operation.temporal_nexus_start_operation_context + if not nexus_start_ctx: + raise RuntimeError( + f"Nexus operation context {nexus_operation} is not a start operation context" + ) + nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() + workflow_event_links = nexus_start_ctx.get_workflow_event_links() else: nexus_completion_callbacks = [] workflow_event_links = [] diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index fb00316d3..210bcc3c0 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -161,6 +161,7 @@ async def start( args=[HandlerWfInput(op_input=input)], id=input.response_type.operation_workflow_id, task_queue=tctx.task_queue, + nexus_operation=tctx, ) return nexusrpc.handler.StartOperationResultAsync( WorkflowOperationToken.from_workflow_handle(wf_handle).encode() @@ -220,6 +221,7 @@ async def async_operation( args=[HandlerWfInput(op_input=input)], id=input.response_type.operation_workflow_id, task_queue=tctx.task_queue, + nexus_operation=tctx, ) From 54dc18864abb2cf3b2b2ba113b2214cd71cfd808 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 13:46:32 -0400 Subject: [PATCH 005/183] Revert "Option 1 for workflow_run_operation_handler" This reverts commit 334b889fd196d1412b264c9197762645027cea22. --- temporalio/client.py | 32 ++++++++--------------------- tests/nexus/test_workflow_caller.py | 2 -- 2 files changed, 9 insertions(+), 25 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 6bfeeb253..5ab8b7c0b 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -318,7 +318,6 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, - nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -354,7 +353,6 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, - nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -392,7 +390,6 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, - nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -430,7 +427,6 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, - nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -466,7 +462,6 @@ async def start_workflow( ] = None, static_summary: Optional[str] = None, static_details: Optional[str] = None, - nexus_operation: Optional[TemporalNexusOperationContext] = None, start_delay: Optional[timedelta] = None, start_signal: Optional[str] = None, start_signal_args: Sequence[Any] = [], @@ -510,13 +505,6 @@ async def start_workflow( UI/CLI. This can be in Temporal markdown format and can span multiple lines. This is a fixed value on the workflow that cannot be updated. For details that can be updated, use :py:meth:`temporalio.workflow.get_current_details` within the workflow. - nexus_operation: An optional - :py:class:`temporalio.nexus.handler.TemporalNexusOperationContext`. If supplied, - it means that the started workflow is backing that Nexus operation. This means that - the workflow result is the Nexus operation result, and will be delivered to the Nexus - caller on workflow completion, and that Nexus bidirectional links will be established - between the caller and the workflow. Do not supply this argument if the workflow is - not backing a Nexus operation. start_delay: Amount of time to wait before starting the workflow. This does not work with ``cron_schedule``. start_signal: If present, this signal is sent as signal-with-start @@ -547,17 +535,15 @@ async def start_workflow( temporalio.workflow._Definition.get_name_and_result_type(workflow) ) nexus_start_ctx = None - if nexus_operation: - # TODO(prerelease): check what sdk-typescript does regarding workflow - # options for workflows being started by Nexus operations. - # https://github.com/temporalio/sdk-typescript/blob/nexus/packages/nexus/src/context.ts#L96 - nexus_start_ctx = nexus_operation.temporal_nexus_start_operation_context - if not nexus_start_ctx: - raise RuntimeError( - f"Nexus operation context {nexus_operation} is not a start operation context" - ) - nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() - workflow_event_links = nexus_start_ctx.get_workflow_event_links() + if nexus_ctx := TemporalNexusOperationContext.try_current(): + # TODO(prerelease): I think this is too magical: what if a user implements a + # nexus handler by running one workflow to completion, and then starting a + # second workflow to act as the async operation itself? + # TODO(prerelease): What do we do if the Temporal Nexus context client + # (namespace) is not the same as the one being used to start this workflow? + if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: + nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() + workflow_event_links = nexus_start_ctx.get_workflow_event_links() else: nexus_completion_callbacks = [] workflow_event_links = [] diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 210bcc3c0..fb00316d3 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -161,7 +161,6 @@ async def start( args=[HandlerWfInput(op_input=input)], id=input.response_type.operation_workflow_id, task_queue=tctx.task_queue, - nexus_operation=tctx, ) return nexusrpc.handler.StartOperationResultAsync( WorkflowOperationToken.from_workflow_handle(wf_handle).encode() @@ -221,7 +220,6 @@ async def async_operation( args=[HandlerWfInput(op_input=input)], id=input.response_type.operation_workflow_id, task_queue=tctx.task_queue, - nexus_operation=tctx, ) From 75cb096536bddd51785950568d9eb7c44ab15ae1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 13:59:55 -0400 Subject: [PATCH 006/183] Adjust imports --- temporalio/client.py | 3 +- .../nexus/handler/_operation_handlers.py | 33 ++++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 5ab8b7c0b..6469c40a2 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -53,12 +53,11 @@ import temporalio.common import temporalio.converter import temporalio.exceptions -import temporalio.nexus.handler import temporalio.runtime import temporalio.service import temporalio.workflow from temporalio.activity import ActivityCancellationDetails -from temporalio.nexus.handler import ( +from temporalio.nexus.handler._operation_context import ( TemporalNexusOperationContext, ) from temporalio.service import ( diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 56b2ccb51..e27ab9674 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -35,16 +35,13 @@ ) if TYPE_CHECKING: - from temporalio.client import ( - Client, - WorkflowHandle, - ) + import temporalio.client async def cancel_workflow( ctx: CancelOperationContext, token: str, - client: Optional[Client] = None, # noqa + client: Optional[temporalio.client.Client] = None, # noqa **kwargs: Any, ) -> None: client = client or TemporalNexusOperationContext.current().client @@ -76,7 +73,7 @@ def __init__( service: ServiceHandlerT, start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], ], output_type: Optional[Type] = None, ): @@ -129,14 +126,18 @@ class WorkflowRunOperationResult(nexusrpc.handler.StartOperationResultAsync): """ @classmethod - def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: + def from_workflow_handle( + cls, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] + ) -> Self: """ Create a :class:`WorkflowRunOperationResult` from a :py:class:`~temporalio.client.WorkflowHandle`. """ token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() return cls(token=token) - def to_workflow_handle(self, client: Client) -> WorkflowHandle: + def to_workflow_handle( + self, client: temporalio.client.Client + ) -> temporalio.client.WorkflowHandle[Any, Any]: """ Create a :py:class:`~temporalio.client.WorkflowHandle` from a :class:`WorkflowRunOperationResult`. """ @@ -154,7 +155,7 @@ def to_workflow_handle(self, client: Client) -> WorkflowHandle: def workflow_run_operation_handler( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -169,7 +170,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], ] ], Callable[ @@ -182,7 +183,7 @@ def workflow_run_operation_handler( start_method: Optional[ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], ] ] = None, *, @@ -195,7 +196,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], ] ], Callable[ @@ -207,7 +208,7 @@ def workflow_run_operation_handler( def decorator( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -253,7 +254,7 @@ def factory( def _get_workflow_run_start_method_input_and_output_type_annotations( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], ], ) -> tuple[ Optional[Type[InputT]], @@ -265,7 +266,7 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( :py:class:`WorkflowHandle`. """ # TODO(nexus-preview) circular import - from temporalio.client import WorkflowHandle + import temporalio.client input_type, output_type = ( nexusrpc.handler.get_start_method_input_and_output_types_annotations( @@ -273,7 +274,7 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( ) ) origin_type = typing.get_origin(output_type) - if not origin_type or not issubclass(origin_type, WorkflowHandle): + if not origin_type or not issubclass(origin_type, temporalio.client.WorkflowHandle): warnings.warn( f"Expected return type of {start_method.__name__} to be a subclass of WorkflowHandle, " f"but is {output_type}" From 31ce1fc52bb7f1dab916ea7ad7bf1f10b89ffa0e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 16:05:52 -0400 Subject: [PATCH 007/183] Revert "Adjust imports" This reverts commit 1a74d64b829a0e08336d1ada3b6afbf139076217. --- temporalio/client.py | 3 +- .../nexus/handler/_operation_handlers.py | 33 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 6469c40a2..5ab8b7c0b 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -53,11 +53,12 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.handler import temporalio.runtime import temporalio.service import temporalio.workflow from temporalio.activity import ActivityCancellationDetails -from temporalio.nexus.handler._operation_context import ( +from temporalio.nexus.handler import ( TemporalNexusOperationContext, ) from temporalio.service import ( diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index e27ab9674..56b2ccb51 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -35,13 +35,16 @@ ) if TYPE_CHECKING: - import temporalio.client + from temporalio.client import ( + Client, + WorkflowHandle, + ) async def cancel_workflow( ctx: CancelOperationContext, token: str, - client: Optional[temporalio.client.Client] = None, # noqa + client: Optional[Client] = None, # noqa **kwargs: Any, ) -> None: client = client or TemporalNexusOperationContext.current().client @@ -73,7 +76,7 @@ def __init__( service: ServiceHandlerT, start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], + Awaitable[WorkflowHandle[Any, OutputT]], ], output_type: Optional[Type] = None, ): @@ -126,18 +129,14 @@ class WorkflowRunOperationResult(nexusrpc.handler.StartOperationResultAsync): """ @classmethod - def from_workflow_handle( - cls, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] - ) -> Self: + def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: """ Create a :class:`WorkflowRunOperationResult` from a :py:class:`~temporalio.client.WorkflowHandle`. """ token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() return cls(token=token) - def to_workflow_handle( - self, client: temporalio.client.Client - ) -> temporalio.client.WorkflowHandle[Any, Any]: + def to_workflow_handle(self, client: Client) -> WorkflowHandle: """ Create a :py:class:`~temporalio.client.WorkflowHandle` from a :class:`WorkflowRunOperationResult`. """ @@ -155,7 +154,7 @@ def to_workflow_handle( def workflow_run_operation_handler( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], + Awaitable[WorkflowHandle[Any, OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -170,7 +169,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], + Awaitable[WorkflowHandle[Any, OutputT]], ] ], Callable[ @@ -183,7 +182,7 @@ def workflow_run_operation_handler( start_method: Optional[ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], + Awaitable[WorkflowHandle[Any, OutputT]], ] ] = None, *, @@ -196,7 +195,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], + Awaitable[WorkflowHandle[Any, OutputT]], ] ], Callable[ @@ -208,7 +207,7 @@ def workflow_run_operation_handler( def decorator( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], + Awaitable[WorkflowHandle[Any, OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -254,7 +253,7 @@ def factory( def _get_workflow_run_start_method_input_and_output_type_annotations( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[temporalio.client.WorkflowHandle[Any, OutputT]], + Awaitable[WorkflowHandle[Any, OutputT]], ], ) -> tuple[ Optional[Type[InputT]], @@ -266,7 +265,7 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( :py:class:`WorkflowHandle`. """ # TODO(nexus-preview) circular import - import temporalio.client + from temporalio.client import WorkflowHandle input_type, output_type = ( nexusrpc.handler.get_start_method_input_and_output_types_annotations( @@ -274,7 +273,7 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( ) ) origin_type = typing.get_origin(output_type) - if not origin_type or not issubclass(origin_type, temporalio.client.WorkflowHandle): + if not origin_type or not issubclass(origin_type, WorkflowHandle): warnings.warn( f"Expected return type of {start_method.__name__} to be a subclass of WorkflowHandle, " f"but is {output_type}" From 8ce108fb9e5418b73122a69d46fa21400f3526a7 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 15:28:27 -0400 Subject: [PATCH 008/183] TODO --- temporalio/nexus/handler/_operation_handlers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 56b2ccb51..7cfc94173 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -86,6 +86,20 @@ def __init__( async def start( self, ctx: StartOperationContext, input: InputT ) -> WorkflowRunOperationResult: + # TODO(nexus-prerelease) It must be possible to start "normal" workflows in + # here, and then finish up with a "nexusified" workflow. + # TODO(nexus-prerelease) It should not be possible to construct a Nexus + # token for a non-nexusified workflow. + # TODO(nexus-prerelease) When `start` returns, must the workflow have been + # started? The answer is yes, but that's yes regarding the + # OperationHandler.start() method that is created by the decorator: it's OK + # for the shorthand method to return a lazily evaluated start_workflow; it + # will only ever be used in its transformed form. Note that in a + # `OperationHandler.start` method, a user should be able to create a token + # for a nexusified workflow and return it as a Nexus response: + # + # token = WorkflowOperationToken.from_workflow_handle(wf_handle).encode() + # return StartOperationResultAsync(token) wf_handle = await start_method(service, ctx, input) return WorkflowRunOperationResult.from_workflow_handle(wf_handle) From ecc08762d5404c39307d4460a6df84e3c738a0bb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 16:12:53 -0400 Subject: [PATCH 009/183] Option 2 for workflow_run_operation_handler The only way to start a nexusified workflow is to create a WorkflowStartOperationResult, passing it an Awaitable[WorkflowHandle]. Accordingly, you must go via this type to create a token encoding a nexusified worfklow handle. The method decorated by the shorthand decorator @workflow_run_operation_handler must return WorkflowStartOperationResult. A manually-implemented OperationHandler start method may create a token this way. --- temporalio/nexus/handler/__init__.py | 4 +- .../nexus/handler/_operation_handlers.py | 111 ++++++++++-------- tests/nexus/test_handler.py | 45 ++++--- .../test_handler_interface_implementation.py | 4 +- .../test_handler_operation_definitions.py | 8 +- tests/nexus/test_workflow_caller.py | 15 ++- 6 files changed, 104 insertions(+), 83 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 9750b876a..6d1c1b8e0 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -23,10 +23,10 @@ TemporalNexusOperationContext as TemporalNexusOperationContext, ) from ._operation_handlers import ( - WorkflowRunOperationHandler as WorkflowRunOperationHandler, + NexusStartWorkflowRequest as NexusStartWorkflowRequest, ) from ._operation_handlers import ( - WorkflowRunOperationResult as WorkflowRunOperationResult, + WorkflowRunOperationHandler as WorkflowRunOperationHandler, ) from ._operation_handlers import cancel_workflow as cancel_workflow from ._operation_handlers import ( diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 7cfc94173..7f7374928 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -9,6 +9,7 @@ Any, Awaitable, Callable, + Coroutine, Generic, Optional, Type, @@ -21,13 +22,14 @@ HandlerError, HandlerErrorType, StartOperationContext, + StartOperationResultAsync, ) from nexusrpc.types import ( InputT, OutputT, ServiceHandlerT, ) -from typing_extensions import Self, overload +from typing_extensions import overload from ._operation_context import TemporalNexusOperationContext from ._token import ( @@ -67,6 +69,47 @@ async def cancel_workflow( await handle.cancel(**kwargs) +class NexusStartWorkflowRequest(Generic[OutputT]): + """ + A request to start a workflow that will handle the Nexus operation. + """ + + def __init__( + self, start_workflow: Coroutine[Any, Any, WorkflowHandle[Any, OutputT]], / + ): + if start_workflow.__qualname__ != "Client.start_workflow": + raise ValueError( + "NexusStartWorkflowRequest must be initialized with the coroutine " + "object obtained by calling Client.start_workflow." + ) + self._start_workflow = start_workflow + + async def start_workflow(self) -> WorkflowHandle[Any, OutputT]: + # TODO(nexus-prerelease) set context such that nexus metadata is injected into request + return await self._start_workflow + + # @classmethod + # def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: + # """ + # Create a :class:`WorkflowRunOperationResult` from a :py:class:`~temporalio.client.WorkflowHandle`. + # """ + # token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() + # return cls(token=token) + + # def to_workflow_handle(self, client: Client) -> WorkflowHandle: + # """ + # Create a :py:class:`~temporalio.client.WorkflowHandle` from a :class:`WorkflowRunOperationResult`. + # """ + # workflow_operation_token = WorkflowOperationToken.decode(self.token) + # if workflow_operation_token.namespace != client.namespace: + # raise ValueError( + # "Cannot create a workflow handle from a workflow operation result " + # "with a client whose namespace is not the same as the namespace of the " + # "workflow operation token." + # ) + # return WorkflowOperationToken.decode(self.token).to_workflow_handle(client) + + class WorkflowRunOperationHandler( nexusrpc.handler.OperationHandler[InputT, OutputT], Generic[InputT, OutputT, ServiceHandlerT], @@ -76,7 +119,7 @@ def __init__( service: ServiceHandlerT, start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[NexusStartWorkflowRequest[OutputT]], ], output_type: Optional[Type] = None, ): @@ -85,7 +128,7 @@ def __init__( @wraps(start_method) async def start( self, ctx: StartOperationContext, input: InputT - ) -> WorkflowRunOperationResult: + ) -> StartOperationResultAsync: # TODO(nexus-prerelease) It must be possible to start "normal" workflows in # here, and then finish up with a "nexusified" workflow. # TODO(nexus-prerelease) It should not be possible to construct a Nexus @@ -100,8 +143,10 @@ async def start( # # token = WorkflowOperationToken.from_workflow_handle(wf_handle).encode() # return StartOperationResultAsync(token) - wf_handle = await start_method(service, ctx, input) - return WorkflowRunOperationResult.from_workflow_handle(wf_handle) + start_wf_request = await start_method(service, ctx, input) + wf_handle = await start_wf_request.start_workflow() + token = WorkflowOperationToken.from_workflow_handle(wf_handle).encode() + return StartOperationResultAsync(token) self.start = types.MethodType(start, self) @@ -133,42 +178,11 @@ def fetch_result( ) -class WorkflowRunOperationResult(nexusrpc.handler.StartOperationResultAsync): - """ - A value returned by the start method of a :class:`WorkflowRunOperation`. - - It indicates that the operation is responding asynchronously, and contains a token - that the handler can use to construct a :class:`~temporalio.client.WorkflowHandle` to - interact with the workflow. - """ - - @classmethod - def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: - """ - Create a :class:`WorkflowRunOperationResult` from a :py:class:`~temporalio.client.WorkflowHandle`. - """ - token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() - return cls(token=token) - - def to_workflow_handle(self, client: Client) -> WorkflowHandle: - """ - Create a :py:class:`~temporalio.client.WorkflowHandle` from a :class:`WorkflowRunOperationResult`. - """ - workflow_operation_token = WorkflowOperationToken.decode(self.token) - if workflow_operation_token.namespace != client.namespace: - raise ValueError( - "Cannot create a workflow handle from a workflow operation result " - "with a client whose namespace is not the same as the namespace of the " - "workflow operation token." - ) - return WorkflowOperationToken.decode(self.token).to_workflow_handle(client) - - @overload def workflow_run_operation_handler( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[NexusStartWorkflowRequest[OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -183,7 +197,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[NexusStartWorkflowRequest[OutputT]], ] ], Callable[ @@ -196,7 +210,7 @@ def workflow_run_operation_handler( start_method: Optional[ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[NexusStartWorkflowRequest[OutputT]], ] ] = None, *, @@ -209,7 +223,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[NexusStartWorkflowRequest[OutputT]], ] ], Callable[ @@ -221,7 +235,7 @@ def workflow_run_operation_handler( def decorator( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[NexusStartWorkflowRequest[OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -267,7 +281,7 @@ def factory( def _get_workflow_run_start_method_input_and_output_type_annotations( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowHandle[Any, OutputT]], + Awaitable[NexusStartWorkflowRequest[OutputT]], ], ) -> tuple[ Optional[Type[InputT]], @@ -278,29 +292,26 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( `start_method` must be a type-annotated start method that returns a :py:class:`WorkflowHandle`. """ - # TODO(nexus-preview) circular import - from temporalio.client import WorkflowHandle - input_type, output_type = ( nexusrpc.handler.get_start_method_input_and_output_types_annotations( start_method ) ) origin_type = typing.get_origin(output_type) - if not origin_type or not issubclass(origin_type, WorkflowHandle): + if not origin_type or not issubclass(origin_type, NexusStartWorkflowRequest): warnings.warn( - f"Expected return type of {start_method.__name__} to be a subclass of WorkflowHandle, " + f"Expected return type of {start_method.__name__} to be a subclass of NexusStartWorkflowRequest, " f"but is {output_type}" ) output_type = None args = typing.get_args(output_type) - if len(args) != 2: + if len(args) != 1: warnings.warn( - f"Expected return type of {start_method.__name__} to have exactly two type parameters, " + f"Expected return type of {start_method.__name__} to have exactly one type parameter, " f"but has {len(args)}: {args}" ) output_type = None else: - _wf_type, output_type = args + [output_type] = args return input_type, output_type diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index b54198050..d4f7386aa 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -38,7 +38,7 @@ import temporalio.api.failure.v1 import temporalio.nexus from temporalio import workflow -from temporalio.client import Client, WorkflowHandle +from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError @@ -46,6 +46,7 @@ logger, ) from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext +from temporalio.nexus.handler._operation_handlers import NexusStartWorkflowRequest from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint @@ -210,14 +211,16 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation( self, ctx: StartOperationContext, input: Input - ) -> WorkflowHandle[Any, Output]: + ) -> NexusStartWorkflowRequest[Output]: tctx = TemporalNexusOperationContext.current() - return await tctx.client.start_workflow( - MyWorkflow.run, - input, - id=test_context.workflow_id or str(uuid.uuid4()), - task_queue=tctx.task_queue, - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + return NexusStartWorkflowRequest( + tctx.client.start_workflow( + MyWorkflow.run, + input, + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) ) @nexusrpc.handler.sync_operation_handler @@ -257,17 +260,19 @@ async def sync_operation_without_type_annotations(self, ctx, input): @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation_without_type_annotations(self, ctx, input): tctx = TemporalNexusOperationContext.current() - return await tctx.client.start_workflow( - WorkflowWithoutTypeAnnotations.run, - input, - id=test_context.workflow_id or str(uuid.uuid4()), - task_queue=tctx.task_queue, + return NexusStartWorkflowRequest( + tctx.client.start_workflow( + WorkflowWithoutTypeAnnotations.run, + input, + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, + ) ) @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_op_link_test( self, ctx: StartOperationContext, input: Input - ) -> WorkflowHandle[Any, Output]: + ) -> NexusStartWorkflowRequest[Output]: assert any( link.url == "http://inbound-link/" for link in ctx.inbound_links ), "Inbound link not found" @@ -275,11 +280,13 @@ async def workflow_run_op_link_test( ctx.outbound_links.extend(ctx.inbound_links) tctx = TemporalNexusOperationContext.current() - return await tctx.client.start_workflow( - MyLinkTestWorkflow.run, - input, - id=test_context.workflow_id or str(uuid.uuid4()), - task_queue=tctx.task_queue, + return NexusStartWorkflowRequest( + tctx.client.start_workflow( + MyLinkTestWorkflow.run, + input, + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, + ) ) class OperationHandlerReturningUnwrappedResult( diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index f688ca791..f33b8e9db 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -6,7 +6,7 @@ import temporalio.api.failure.v1 import temporalio.nexus -from temporalio.client import WorkflowHandle +from temporalio.nexus.handler import NexusStartWorkflowRequest HTTP_PORT = 7243 @@ -40,7 +40,7 @@ class Impl: @temporalio.nexus.handler.workflow_run_operation_handler async def op( self, ctx: nexusrpc.handler.StartOperationContext, input: str - ) -> WorkflowHandle[Any, int]: ... + ) -> NexusStartWorkflowRequest[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 51b4e66d3..7ba9dad10 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -10,7 +10,7 @@ import pytest import temporalio.nexus.handler -from temporalio.client import WorkflowHandle +from temporalio.nexus.handler import NexusStartWorkflowRequest @dataclass @@ -35,7 +35,7 @@ class Service: @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation_handler( self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowHandle[Any, Output]: ... + ) -> NexusStartWorkflowRequest[Output]: ... expected_operations = { "workflow_run_operation_handler": nexusrpc.Operation( @@ -53,7 +53,7 @@ class Service: @temporalio.nexus.handler.workflow_run_operation_handler() async def workflow_run_operation_handler( self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowHandle[Any, Output]: ... + ) -> NexusStartWorkflowRequest[Output]: ... expected_operations = NotCalled.expected_operations @@ -64,7 +64,7 @@ class Service: @temporalio.nexus.handler.workflow_run_operation_handler(name="operation-name") async def workflow_run_operation_with_name_override( self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowHandle[Any, Output]: ... + ) -> NexusStartWorkflowRequest[Output]: ... expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index fb00316d3..58b7b6dd0 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -34,6 +34,7 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext +from temporalio.nexus.handler._operation_handlers import NexusStartWorkflowRequest from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker @@ -206,7 +207,7 @@ async def sync_operation( @temporalio.nexus.handler.workflow_run_operation_handler async def async_operation( self, ctx: StartOperationContext, input: OpInput - ) -> WorkflowHandle[HandlerWorkflow, HandlerWfOutput]: + ) -> NexusStartWorkflowRequest[HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: raise RPCError( @@ -215,11 +216,13 @@ async def async_operation( b"", ) tctx = TemporalNexusOperationContext.current() - return await tctx.client.start_workflow( - HandlerWorkflow.run, - args=[HandlerWfInput(op_input=input)], - id=input.response_type.operation_workflow_id, - task_queue=tctx.task_queue, + return NexusStartWorkflowRequest( + tctx.client.start_workflow( + HandlerWorkflow.run, + args=[HandlerWfInput(op_input=input)], + id=input.response_type.operation_workflow_id, + task_queue=tctx.task_queue, + ) ) From d5c1184e7c82c5d88270db7d6ad487bbe21639ef Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 21:43:00 -0400 Subject: [PATCH 010/183] Failing test: first of two workflows incorrectly delivers result --- tests/nexus/test_workflow_caller.py | 79 +++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 58b7b6dd0..84f420e98 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -906,6 +906,85 @@ async def test_service_interface_and_implementation_names(client: Client): ) +@nexusrpc.service +class ServiceWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: + my_workflow_run_operation: nexusrpc.Operation[None, None] + my_manual_async_operation: nexusrpc.Operation[None, None] + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: str) -> str: + return input + + +@nexusrpc.handler.service_handler +class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: + @temporalio.nexus.handler.workflow_run_operation_handler + async def my_workflow_run_operation( + self, ctx: StartOperationContext, input: None + ) -> NexusStartWorkflowRequest[str]: + tctx = TemporalNexusOperationContext.current() + result_1 = await tctx.client.execute_workflow( + EchoWorkflow.run, + "result-1", + id=str(uuid.uuid4()), + task_queue=tctx.task_queue, + ) + # In case result_1 is incorrectly being delivered to the caller as the operation + # result, give time for that incorrect behavior to occur. + await asyncio.sleep(0.5) + return NexusStartWorkflowRequest( + tctx.client.start_workflow( + EchoWorkflow.run, + f"{result_1}-result-2", + id=str(uuid.uuid4()), + task_queue=tctx.task_queue, + ) + ) + + +@workflow.defn +class WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow: + @workflow.run + async def run(self, input: str, task_queue: str) -> str: + nexus_client = workflow.NexusClient( + service=ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow, + endpoint=make_nexus_endpoint_name(task_queue), + ) + return await nexus_client.execute_operation( + ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow.my_workflow_run_operation, + None, + ) + + +async def test_workflow_run_operation_can_execute_workflow_before_starting_backing_workflow( + client: Client, +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + workflows=[ + EchoWorkflow, + WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow, + ], + nexus_service_handlers=[ + ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow(), + ], + task_queue=task_queue, + workflow_runner=UnsandboxedWorkflowRunner(), + ): + await create_nexus_endpoint(task_queue, client) + result = await client.execute_workflow( + WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow.run, + args=("result-1", task_queue), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert result == "result-1-result-2" + + # TODO(dan): test invalid service interface implementations # TODO(dan): test caller passing output_type From 20e82fc5e0f7720b3b585d4927c82a4984cd9fcb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 21 Jun 2025 22:24:13 -0400 Subject: [PATCH 011/183] WIP: Option 3 for workflow_run_operation_handler NexusStartWorkflowRequest holds start_workflow request params, and exposes a method to start the workflow that injects the required Nexus metadata, and populates outbound links. --- temporalio/client.py | 27 ++--- temporalio/nexus/handler/__init__.py | 3 + .../nexus/handler/_operation_handlers.py | 39 +------ temporalio/nexus/handler/_start_workflow.py | 104 ++++++++++++++++++ tests/nexus/test_handler.py | 35 +++--- tests/nexus/test_workflow_caller.py | 28 ++--- 6 files changed, 150 insertions(+), 86 deletions(-) create mode 100644 temporalio/nexus/handler/_start_workflow.py diff --git a/temporalio/client.py b/temporalio/client.py index 5ab8b7c0b..e6d0f28cc 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -471,6 +471,12 @@ async def start_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, # The following options are deliberately not exposed in overloads + nexus_completion_callbacks: Sequence[ + temporalio.common.NexusCompletionCallback + ] = [], + workflow_event_links: Sequence[ + temporalio.api.common.v1.Link.WorkflowEvent + ] = [], stack_level: int = 2, ) -> WorkflowHandle[Any, Any]: """Start a workflow and return its handle. @@ -534,21 +540,7 @@ async def start_workflow( name, result_type_from_type_hint = ( temporalio.workflow._Definition.get_name_and_result_type(workflow) ) - nexus_start_ctx = None - if nexus_ctx := TemporalNexusOperationContext.try_current(): - # TODO(prerelease): I think this is too magical: what if a user implements a - # nexus handler by running one workflow to completion, and then starting a - # second workflow to act as the async operation itself? - # TODO(prerelease): What do we do if the Temporal Nexus context client - # (namespace) is not the same as the one being used to start this workflow? - if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: - nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() - workflow_event_links = nexus_start_ctx.get_workflow_event_links() - else: - nexus_completion_callbacks = [] - workflow_event_links = [] - - wf_handle = await self._impl.start_workflow( + return await self._impl.start_workflow( StartWorkflowInput( workflow=name, args=temporalio.common._arg_or_args(arg, args), @@ -580,11 +572,6 @@ async def start_workflow( ) ) - if nexus_start_ctx: - nexus_start_ctx.add_outbound_links(wf_handle) - - return wf_handle - # Overload for no-param workflow @overload async def execute_workflow( diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 6d1c1b8e0..a99446da8 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -32,6 +32,9 @@ from ._operation_handlers import ( workflow_run_operation_handler as workflow_run_operation_handler, ) +from ._start_workflow import ( + start_workflow as start_workflow, +) from ._token import ( WorkflowOperationToken as WorkflowOperationToken, ) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 7f7374928..364654578 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -9,7 +9,6 @@ Any, Awaitable, Callable, - Coroutine, Generic, Optional, Type, @@ -31,6 +30,8 @@ ) from typing_extensions import overload +import temporalio.nexus.handler + from ._operation_context import TemporalNexusOperationContext from ._token import ( WorkflowOperationToken as WorkflowOperationToken, @@ -74,40 +75,12 @@ class NexusStartWorkflowRequest(Generic[OutputT]): A request to start a workflow that will handle the Nexus operation. """ - def __init__( - self, start_workflow: Coroutine[Any, Any, WorkflowHandle[Any, OutputT]], / - ): - if start_workflow.__qualname__ != "Client.start_workflow": - raise ValueError( - "NexusStartWorkflowRequest must be initialized with the coroutine " - "object obtained by calling Client.start_workflow." - ) - self._start_workflow = start_workflow + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs async def start_workflow(self) -> WorkflowHandle[Any, OutputT]: - # TODO(nexus-prerelease) set context such that nexus metadata is injected into request - return await self._start_workflow - - # @classmethod - # def from_workflow_handle(cls, workflow_handle: WorkflowHandle) -> Self: - # """ - # Create a :class:`WorkflowRunOperationResult` from a :py:class:`~temporalio.client.WorkflowHandle`. - # """ - # token = WorkflowOperationToken.from_workflow_handle(workflow_handle).encode() - # return cls(token=token) - - # def to_workflow_handle(self, client: Client) -> WorkflowHandle: - # """ - # Create a :py:class:`~temporalio.client.WorkflowHandle` from a :class:`WorkflowRunOperationResult`. - # """ - # workflow_operation_token = WorkflowOperationToken.decode(self.token) - # if workflow_operation_token.namespace != client.namespace: - # raise ValueError( - # "Cannot create a workflow handle from a workflow operation result " - # "with a client whose namespace is not the same as the namespace of the " - # "workflow operation token." - # ) - # return WorkflowOperationToken.decode(self.token).to_workflow_handle(client) + return await temporalio.nexus.handler.start_workflow(*self.args, **self.kwargs) class WorkflowRunOperationHandler( diff --git a/temporalio/nexus/handler/_start_workflow.py b/temporalio/nexus/handler/_start_workflow.py new file mode 100644 index 000000000..dda4aac35 --- /dev/null +++ b/temporalio/nexus/handler/_start_workflow.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + Optional, + Sequence, + Union, +) + +import temporalio.common +from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext +from temporalio.types import ( + MethodAsyncSingleParam, + ParamType, + ReturnType, + SelfType, +) + +if TYPE_CHECKING: + from temporalio.client import Client, WorkflowHandle + + +# Overload for single-param workflow +async def start_workflow( + client: Client, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, +) -> WorkflowHandle[SelfType, ReturnType]: + if nexus_ctx := TemporalNexusOperationContext.try_current(): + if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: + nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() + workflow_event_links = nexus_start_ctx.get_workflow_event_links() + else: + raise RuntimeError( + "temporalio.nexus.handler.start_workflow() must be called from within a Nexus start operation context" + ) + else: + raise RuntimeError( + "temporalio.nexus.handler.start_workflow() must be called from within a Nexus operation context" + ) + + # We must pass nexus_completion_callbacks and workflow_event_links, but these are + # deliberately not exposed in overloads, hence the type check violation. + wf_handle = await client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + id=id, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + nexus_completion_callbacks=nexus_completion_callbacks, + workflow_event_links=workflow_event_links, + ) + + nexus_start_ctx.add_outbound_links(wf_handle) + + return wf_handle diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index d4f7386aa..ab891070a 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -214,13 +214,12 @@ async def workflow_run_operation( ) -> NexusStartWorkflowRequest[Output]: tctx = TemporalNexusOperationContext.current() return NexusStartWorkflowRequest( - tctx.client.start_workflow( - MyWorkflow.run, - input, - id=test_context.workflow_id or str(uuid.uuid4()), - task_queue=tctx.task_queue, - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, - ) + tctx.client, + MyWorkflow.run, + input, + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) @nexusrpc.handler.sync_operation_handler @@ -261,12 +260,11 @@ async def sync_operation_without_type_annotations(self, ctx, input): async def workflow_run_operation_without_type_annotations(self, ctx, input): tctx = TemporalNexusOperationContext.current() return NexusStartWorkflowRequest( - tctx.client.start_workflow( - WorkflowWithoutTypeAnnotations.run, - input, - id=test_context.workflow_id or str(uuid.uuid4()), - task_queue=tctx.task_queue, - ) + tctx.client, + WorkflowWithoutTypeAnnotations.run, + input, + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, ) @temporalio.nexus.handler.workflow_run_operation_handler @@ -281,12 +279,11 @@ async def workflow_run_op_link_test( tctx = TemporalNexusOperationContext.current() return NexusStartWorkflowRequest( - tctx.client.start_workflow( - MyLinkTestWorkflow.run, - input, - id=test_context.workflow_id or str(uuid.uuid4()), - task_queue=tctx.task_queue, - ) + tctx.client, + MyLinkTestWorkflow.run, + input, + id=test_context.workflow_id or str(uuid.uuid4()), + task_queue=tctx.task_queue, ) class OperationHandlerReturningUnwrappedResult( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 84f420e98..54e57c658 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -157,12 +157,14 @@ async def start( ) elif isinstance(input.response_type, AsyncResponse): tctx = TemporalNexusOperationContext.current() - wf_handle = await tctx.client.start_workflow( + start_request = NexusStartWorkflowRequest( # type: ignore + tctx.client, HandlerWorkflow.run, - args=[HandlerWfInput(op_input=input)], + HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, task_queue=tctx.task_queue, ) + wf_handle = await start_request.start_workflow() return nexusrpc.handler.StartOperationResultAsync( WorkflowOperationToken.from_workflow_handle(wf_handle).encode() ) @@ -217,12 +219,11 @@ async def async_operation( ) tctx = TemporalNexusOperationContext.current() return NexusStartWorkflowRequest( - tctx.client.start_workflow( - HandlerWorkflow.run, - args=[HandlerWfInput(op_input=input)], - id=input.response_type.operation_workflow_id, - task_queue=tctx.task_queue, - ) + tctx.client, + HandlerWorkflow.run, + HandlerWfInput(op_input=input), + id=input.response_type.operation_workflow_id, + task_queue=tctx.task_queue, ) @@ -936,12 +937,11 @@ async def my_workflow_run_operation( # result, give time for that incorrect behavior to occur. await asyncio.sleep(0.5) return NexusStartWorkflowRequest( - tctx.client.start_workflow( - EchoWorkflow.run, - f"{result_1}-result-2", - id=str(uuid.uuid4()), - task_queue=tctx.task_queue, - ) + tctx.client, + EchoWorkflow.run, + f"{result_1}-result-2", + id=str(uuid.uuid4()), + task_queue=tctx.task_queue, ) From 26543aec1767c6b764b8c453629b3f8980c9a077 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 07:52:39 -0400 Subject: [PATCH 012/183] TemporalNexusOperationContext should not be an ABC --- temporalio/nexus/handler/_operation_context.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 2b6dbd9cc..b012c9e60 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -4,7 +4,6 @@ import logging import re import urllib.parse -from abc import ABC from contextvars import ContextVar from dataclasses import dataclass from typing import ( @@ -37,7 +36,7 @@ @dataclass -class TemporalNexusOperationContext(ABC): +class TemporalNexusOperationContext: """ Context for a Nexus operation being handled by a Temporal Nexus Worker. """ From e512834ca29b714babc49cce8584135c69f045eb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 08:54:06 -0400 Subject: [PATCH 013/183] Remove unused output_type --- temporalio/nexus/handler/_operation_handlers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 364654578..61ac5211b 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -94,7 +94,6 @@ def __init__( [ServiceHandlerT, StartOperationContext, InputT], Awaitable[NexusStartWorkflowRequest[OutputT]], ], - output_type: Optional[Type] = None, ): self.service = service @@ -222,9 +221,8 @@ def decorator( def factory( service: ServiceHandlerT, ) -> WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT]: - return WorkflowRunOperationHandler( - service, start_method, output_type=output_type - ) + # TODO(nexus-prerelease) I was passing output_type here; why? + return WorkflowRunOperationHandler(service, start_method) # TODO(nexus-prerelease): handle callable instances: __class__.__name__ as in sync_operation_handler method_name = getattr(start_method, "__name__", None) From 80cf8aaa285ce4f272058dbbe27492ecdba95259 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 08:54:38 -0400 Subject: [PATCH 014/183] Cleanup --- .../nexus/handler/_operation_context.py | 6 ++-- .../nexus/handler/_operation_handlers.py | 28 +++++++++++-------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index b012c9e60..28a6cae4e 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -61,11 +61,13 @@ def current() -> TemporalNexusOperationContext: return context @staticmethod - def set(context: TemporalNexusOperationContext) -> contextvars.Token: + def set( + context: TemporalNexusOperationContext, + ) -> contextvars.Token[TemporalNexusOperationContext]: return _current_context.set(context) @staticmethod - def reset(token: contextvars.Token) -> None: + def reset(token: contextvars.Token[TemporalNexusOperationContext]) -> None: _current_context.reset(token) @property diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 61ac5211b..f61a1a391 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -99,7 +99,7 @@ def __init__( @wraps(start_method) async def start( - self, ctx: StartOperationContext, input: InputT + _, ctx: StartOperationContext, input: InputT ) -> StartOperationResultAsync: # TODO(nexus-prerelease) It must be possible to start "normal" workflows in # here, and then finish up with a "nexusified" workflow. @@ -212,12 +212,6 @@ def decorator( ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] ]: - input_type, output_type = ( - _get_workflow_run_start_method_input_and_output_type_annotations( - start_method - ) - ) - def factory( service: ServiceHandlerT, ) -> WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT]: @@ -234,11 +228,21 @@ def factory( f"expected {start_method} to be a function or callable instance." ) - factory.__nexus_operation__ = nexusrpc.Operation( - name=name or method_name, - method_name=method_name, - input_type=input_type, - output_type=output_type, + input_type, output_type = ( + _get_workflow_run_start_method_input_and_output_type_annotations( + start_method + ) + ) + + setattr( + factory, + "__nexus_operation__", + nexusrpc.Operation( + name=name or method_name, + method_name=method_name, + input_type=input_type, + output_type=output_type, + ), ) return factory From 32b1604763f4db0dfdcd572c847089980720d4e8 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 08:59:09 -0400 Subject: [PATCH 015/183] Make WorkflowOperationToken generic, parameterized by output type --- temporalio/nexus/handler/_operation_handlers.py | 6 ++++-- temporalio/nexus/handler/_token.py | 17 +++++++++++------ tests/nexus/test_workflow_caller.py | 6 ++++-- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index f61a1a391..8602fc253 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -52,7 +52,7 @@ async def cancel_workflow( ) -> None: client = client or TemporalNexusOperationContext.current().client try: - decoded = WorkflowOperationToken.decode(token) + decoded = WorkflowOperationToken[Any].decode(token) except Exception as err: raise HandlerError( "Failed to decode workflow operation token", @@ -117,7 +117,9 @@ async def start( # return StartOperationResultAsync(token) start_wf_request = await start_method(service, ctx, input) wf_handle = await start_wf_request.start_workflow() - token = WorkflowOperationToken.from_workflow_handle(wf_handle).encode() + token = ( + WorkflowOperationToken[OutputT].from_workflow_handle(wf_handle).encode() + ) return StartOperationResultAsync(token) self.start = types.MethodType(start, self) diff --git a/temporalio/nexus/handler/_token.py b/temporalio/nexus/handler/_token.py index bf08198e4..a41be459f 100644 --- a/temporalio/nexus/handler/_token.py +++ b/temporalio/nexus/handler/_token.py @@ -3,7 +3,9 @@ import base64 import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional + +from nexusrpc.types import OutputT if TYPE_CHECKING: from temporalio.client import Client, WorkflowHandle @@ -13,7 +15,7 @@ @dataclass(frozen=True) -class WorkflowOperationToken: +class WorkflowOperationToken(Generic[OutputT]): """Represents the structured data of a Nexus workflow operation token.""" namespace: str @@ -23,17 +25,20 @@ class WorkflowOperationToken: # serialized token; it's only used to reject newer token versions on load. version: Optional[int] = None + # TODO(nexus-preview): Is it helpful to parameterize WorkflowOperationToken by + # OutputT? The return type here should be dictated by the input workflow handle + # type. @classmethod def from_workflow_handle( - cls, workflow_handle: WorkflowHandle[Any, Any] - ) -> WorkflowOperationToken: + cls, workflow_handle: WorkflowHandle[Any, OutputT] + ) -> WorkflowOperationToken[OutputT]: """Creates a token from a workflow handle.""" return cls( namespace=workflow_handle._client.namespace, workflow_id=workflow_handle.id, ) - def to_workflow_handle(self, client: Client) -> WorkflowHandle[Any, Any]: + def to_workflow_handle(self, client: Client) -> WorkflowHandle[Any, OutputT]: """Creates a workflow handle from this token.""" if client.namespace != self.namespace: raise ValueError( @@ -54,7 +59,7 @@ def encode(self) -> str: ) @classmethod - def decode(cls, token: str) -> WorkflowOperationToken: + def decode(cls, token: str) -> WorkflowOperationToken[OutputT]: """Decodes and validates a token from its base64url-encoded string representation.""" if not token: raise TypeError("invalid workflow token: token is empty") diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 54e57c658..22902e627 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -157,7 +157,7 @@ async def start( ) elif isinstance(input.response_type, AsyncResponse): tctx = TemporalNexusOperationContext.current() - start_request = NexusStartWorkflowRequest( # type: ignore + start_request = NexusStartWorkflowRequest[HandlerWfOutput]( # type: ignore tctx.client, HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -166,7 +166,9 @@ async def start( ) wf_handle = await start_request.start_workflow() return nexusrpc.handler.StartOperationResultAsync( - WorkflowOperationToken.from_workflow_handle(wf_handle).encode() + WorkflowOperationToken[HandlerWfOutput] + .from_workflow_handle(wf_handle) + .encode() ) else: raise TypeError From e1b996d41b83a3e8c4582b635fa0364c87155b6d Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 09:14:41 -0400 Subject: [PATCH 016/183] Option 4 for WorkflowRunOperationHandler tctx.start_workflow starts the workflow, injects Nexus metadata, but returns a WorkflowOperationToken instead of a WorkflowHandle. This means that users accidentally trying to start a Nexusified workflow via the standard client.start_workflow will get both a type-check-time error, and a run-time error. --- temporalio/nexus/handler/__init__.py | 6 - .../nexus/handler/_operation_context.py | 100 ++++++++++++++++- .../nexus/handler/_operation_handlers.py | 42 ++----- temporalio/nexus/handler/_start_workflow.py | 104 ------------------ temporalio/nexus/handler/_token.py | 25 +++-- tests/nexus/test_handler.py | 18 +-- .../test_handler_interface_implementation.py | 4 +- .../test_handler_operation_definitions.py | 8 +- tests/nexus/test_workflow_caller.py | 24 ++-- 9 files changed, 146 insertions(+), 185 deletions(-) delete mode 100644 temporalio/nexus/handler/_start_workflow.py diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index a99446da8..fbf144e9b 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -22,9 +22,6 @@ from ._operation_context import ( TemporalNexusOperationContext as TemporalNexusOperationContext, ) -from ._operation_handlers import ( - NexusStartWorkflowRequest as NexusStartWorkflowRequest, -) from ._operation_handlers import ( WorkflowRunOperationHandler as WorkflowRunOperationHandler, ) @@ -32,9 +29,6 @@ from ._operation_handlers import ( workflow_run_operation_handler as workflow_run_operation_handler, ) -from ._start_workflow import ( - start_workflow as start_workflow, -) from ._token import ( WorkflowOperationToken as WorkflowOperationToken, ) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 28a6cae4e..5695cc585 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -6,10 +6,13 @@ import urllib.parse from contextvars import ContextVar from dataclasses import dataclass +from datetime import timedelta from typing import ( TYPE_CHECKING, Any, + Mapping, Optional, + Sequence, Union, ) @@ -19,12 +22,16 @@ import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.common +from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.types import ( + MethodAsyncSingleParam, + ParamType, + ReturnType, + SelfType, +) if TYPE_CHECKING: - from temporalio.client import ( - Client, - WorkflowHandle, - ) + from temporalio.client import Client, WorkflowHandle logger = logging.getLogger(__name__) @@ -88,6 +95,91 @@ def temporal_nexus_cancel_operation_context( return None return _TemporalNexusCancelOperationContext(ctx) + # Overload for single-param workflow + # TODO(nexus-preview): support other overloads? + async def start_workflow( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + # TODO(nexus-preview): Allow client and task queue to be omitted, defaulting to worker's? + task_queue: str, + client: Client, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowOperationToken[ReturnType]: + if nexus_ctx := TemporalNexusOperationContext.try_current(): + if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: + nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() + workflow_event_links = nexus_start_ctx.get_workflow_event_links() + else: + raise RuntimeError( + "temporalio.nexus.handler.start_workflow() must be called from within a Nexus start operation context" + ) + else: + raise RuntimeError( + "temporalio.nexus.handler.start_workflow() must be called from within a Nexus operation context" + ) + + # We must pass nexus_completion_callbacks and workflow_event_links, but these are + # deliberately not exposed in overloads, hence the type-check violation. + wf_handle = await client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + id=id, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + nexus_completion_callbacks=nexus_completion_callbacks, + workflow_event_links=workflow_event_links, + ) + + nexus_start_ctx.add_outbound_links(wf_handle) + + return WorkflowOperationToken[ReturnType]._unsafe_from_workflow_handle( + wf_handle + ) + @dataclass class _TemporalNexusStartOperationContext: diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 8602fc253..10aadd3d8 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -30,8 +30,6 @@ ) from typing_extensions import overload -import temporalio.nexus.handler - from ._operation_context import TemporalNexusOperationContext from ._token import ( WorkflowOperationToken as WorkflowOperationToken, @@ -40,7 +38,6 @@ if TYPE_CHECKING: from temporalio.client import ( Client, - WorkflowHandle, ) @@ -70,19 +67,6 @@ async def cancel_workflow( await handle.cancel(**kwargs) -class NexusStartWorkflowRequest(Generic[OutputT]): - """ - A request to start a workflow that will handle the Nexus operation. - """ - - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - async def start_workflow(self) -> WorkflowHandle[Any, OutputT]: - return await temporalio.nexus.handler.start_workflow(*self.args, **self.kwargs) - - class WorkflowRunOperationHandler( nexusrpc.handler.OperationHandler[InputT, OutputT], Generic[InputT, OutputT, ServiceHandlerT], @@ -92,7 +76,7 @@ def __init__( service: ServiceHandlerT, start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[NexusStartWorkflowRequest[OutputT]], + Awaitable[WorkflowOperationToken[OutputT]], ], ): self.service = service @@ -115,12 +99,8 @@ async def start( # # token = WorkflowOperationToken.from_workflow_handle(wf_handle).encode() # return StartOperationResultAsync(token) - start_wf_request = await start_method(service, ctx, input) - wf_handle = await start_wf_request.start_workflow() - token = ( - WorkflowOperationToken[OutputT].from_workflow_handle(wf_handle).encode() - ) - return StartOperationResultAsync(token) + token = await start_method(service, ctx, input) + return StartOperationResultAsync(token.encode()) self.start = types.MethodType(start, self) @@ -156,7 +136,7 @@ def fetch_result( def workflow_run_operation_handler( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[NexusStartWorkflowRequest[OutputT]], + Awaitable[WorkflowOperationToken[OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -171,7 +151,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[NexusStartWorkflowRequest[OutputT]], + Awaitable[WorkflowOperationToken[OutputT]], ] ], Callable[ @@ -184,7 +164,7 @@ def workflow_run_operation_handler( start_method: Optional[ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[NexusStartWorkflowRequest[OutputT]], + Awaitable[WorkflowOperationToken[OutputT]], ] ] = None, *, @@ -197,7 +177,7 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[NexusStartWorkflowRequest[OutputT]], + Awaitable[WorkflowOperationToken[OutputT]], ] ], Callable[ @@ -209,7 +189,7 @@ def workflow_run_operation_handler( def decorator( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[NexusStartWorkflowRequest[OutputT]], + Awaitable[WorkflowOperationToken[OutputT]], ], ) -> Callable[ [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] @@ -258,7 +238,7 @@ def factory( def _get_workflow_run_start_method_input_and_output_type_annotations( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[NexusStartWorkflowRequest[OutputT]], + Awaitable[WorkflowOperationToken[OutputT]], ], ) -> tuple[ Optional[Type[InputT]], @@ -275,9 +255,9 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( ) ) origin_type = typing.get_origin(output_type) - if not origin_type or not issubclass(origin_type, NexusStartWorkflowRequest): + if not origin_type or not issubclass(origin_type, WorkflowOperationToken): warnings.warn( - f"Expected return type of {start_method.__name__} to be a subclass of NexusStartWorkflowRequest, " + f"Expected return type of {start_method.__name__} to be a subclass of WorkflowOperationToken, " f"but is {output_type}" ) output_type = None diff --git a/temporalio/nexus/handler/_start_workflow.py b/temporalio/nexus/handler/_start_workflow.py deleted file mode 100644 index dda4aac35..000000000 --- a/temporalio/nexus/handler/_start_workflow.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -from datetime import timedelta -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - Optional, - Sequence, - Union, -) - -import temporalio.common -from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext -from temporalio.types import ( - MethodAsyncSingleParam, - ParamType, - ReturnType, - SelfType, -) - -if TYPE_CHECKING: - from temporalio.client import Client, WorkflowHandle - - -# Overload for single-param workflow -async def start_workflow( - client: Client, - workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], - arg: ParamType, - *, - id: str, - task_queue: str, - execution_timeout: Optional[timedelta] = None, - run_timeout: Optional[timedelta] = None, - task_timeout: Optional[timedelta] = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, - retry_policy: Optional[temporalio.common.RetryPolicy] = None, - cron_schedule: str = "", - memo: Optional[Mapping[str, Any]] = None, - search_attributes: Optional[ - Union[ - temporalio.common.TypedSearchAttributes, - temporalio.common.SearchAttributes, - ] - ] = None, - static_summary: Optional[str] = None, - static_details: Optional[str] = None, - start_delay: Optional[timedelta] = None, - start_signal: Optional[str] = None, - start_signal_args: Sequence[Any] = [], - rpc_metadata: Mapping[str, str] = {}, - rpc_timeout: Optional[timedelta] = None, - request_eager_start: bool = False, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: Optional[temporalio.common.VersioningOverride] = None, -) -> WorkflowHandle[SelfType, ReturnType]: - if nexus_ctx := TemporalNexusOperationContext.try_current(): - if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: - nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() - workflow_event_links = nexus_start_ctx.get_workflow_event_links() - else: - raise RuntimeError( - "temporalio.nexus.handler.start_workflow() must be called from within a Nexus start operation context" - ) - else: - raise RuntimeError( - "temporalio.nexus.handler.start_workflow() must be called from within a Nexus operation context" - ) - - # We must pass nexus_completion_callbacks and workflow_event_links, but these are - # deliberately not exposed in overloads, hence the type check violation. - wf_handle = await client.start_workflow( # type: ignore - workflow=workflow, - arg=arg, - id=id, - task_queue=task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - start_signal=start_signal, - start_signal_args=start_signal_args, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - request_eager_start=request_eager_start, - priority=priority, - versioning_override=versioning_override, - nexus_completion_callbacks=nexus_completion_callbacks, - workflow_event_links=workflow_event_links, - ) - - nexus_start_ctx.add_outbound_links(wf_handle) - - return wf_handle diff --git a/temporalio/nexus/handler/_token.py b/temporalio/nexus/handler/_token.py index a41be459f..47a696e79 100644 --- a/temporalio/nexus/handler/_token.py +++ b/temporalio/nexus/handler/_token.py @@ -25,27 +25,32 @@ class WorkflowOperationToken(Generic[OutputT]): # serialized token; it's only used to reject newer token versions on load. version: Optional[int] = None + def to_workflow_handle(self, client: Client) -> WorkflowHandle[Any, OutputT]: + """Create a :py:class:`temporalio.client.WorkflowHandle` from the token.""" + if client.namespace != self.namespace: + raise ValueError( + f"Client namespace {client.namespace} does not match token namespace {self.namespace}" + ) + return client.get_workflow_handle(self.workflow_id) + # TODO(nexus-preview): Is it helpful to parameterize WorkflowOperationToken by # OutputT? The return type here should be dictated by the input workflow handle # type. @classmethod - def from_workflow_handle( + def _unsafe_from_workflow_handle( cls, workflow_handle: WorkflowHandle[Any, OutputT] ) -> WorkflowOperationToken[OutputT]: - """Creates a token from a workflow handle.""" + """Create a :py:class:`WorkflowOperationToken` from a workflow handle. + + This is a private method not intended to be used by users. It does not check + that the supplied WorkflowHandle references a workflow that has been + instrumented to supply the result of a Nexus operation. + """ return cls( namespace=workflow_handle._client.namespace, workflow_id=workflow_handle.id, ) - def to_workflow_handle(self, client: Client) -> WorkflowHandle[Any, OutputT]: - """Creates a workflow handle from this token.""" - if client.namespace != self.namespace: - raise ValueError( - f"Client namespace {client.namespace} does not match token namespace {self.namespace}" - ) - return client.get_workflow_handle(self.workflow_id) - def encode(self) -> str: return _base64url_encode_no_padding( json.dumps( diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index ab891070a..a77c577e0 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -46,7 +46,7 @@ logger, ) from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext -from temporalio.nexus.handler._operation_handlers import NexusStartWorkflowRequest +from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint @@ -211,13 +211,13 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation( self, ctx: StartOperationContext, input: Input - ) -> NexusStartWorkflowRequest[Output]: + ) -> WorkflowOperationToken[Output]: tctx = TemporalNexusOperationContext.current() - return NexusStartWorkflowRequest( - tctx.client, + return await tctx.start_workflow( MyWorkflow.run, input, id=test_context.workflow_id or str(uuid.uuid4()), + client=tctx.client, task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) @@ -259,18 +259,18 @@ async def sync_operation_without_type_annotations(self, ctx, input): @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation_without_type_annotations(self, ctx, input): tctx = TemporalNexusOperationContext.current() - return NexusStartWorkflowRequest( - tctx.client, + return await tctx.start_workflow( WorkflowWithoutTypeAnnotations.run, input, id=test_context.workflow_id or str(uuid.uuid4()), + client=tctx.client, task_queue=tctx.task_queue, ) @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_op_link_test( self, ctx: StartOperationContext, input: Input - ) -> NexusStartWorkflowRequest[Output]: + ) -> WorkflowOperationToken[Output]: assert any( link.url == "http://inbound-link/" for link in ctx.inbound_links ), "Inbound link not found" @@ -278,11 +278,11 @@ async def workflow_run_op_link_test( ctx.outbound_links.extend(ctx.inbound_links) tctx = TemporalNexusOperationContext.current() - return NexusStartWorkflowRequest( - tctx.client, + return await tctx.start_workflow( MyLinkTestWorkflow.run, input, id=test_context.workflow_id or str(uuid.uuid4()), + client=tctx.client, task_queue=tctx.task_queue, ) diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index f33b8e9db..d62e0e581 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -6,7 +6,7 @@ import temporalio.api.failure.v1 import temporalio.nexus -from temporalio.nexus.handler import NexusStartWorkflowRequest +from temporalio.nexus.handler._token import WorkflowOperationToken HTTP_PORT = 7243 @@ -40,7 +40,7 @@ class Impl: @temporalio.nexus.handler.workflow_run_operation_handler async def op( self, ctx: nexusrpc.handler.StartOperationContext, input: str - ) -> NexusStartWorkflowRequest[int]: ... + ) -> WorkflowOperationToken[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 7ba9dad10..85c10a68c 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -10,7 +10,7 @@ import pytest import temporalio.nexus.handler -from temporalio.nexus.handler import NexusStartWorkflowRequest +from temporalio.nexus.handler._token import WorkflowOperationToken @dataclass @@ -35,7 +35,7 @@ class Service: @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation_handler( self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> NexusStartWorkflowRequest[Output]: ... + ) -> WorkflowOperationToken[Output]: ... expected_operations = { "workflow_run_operation_handler": nexusrpc.Operation( @@ -53,7 +53,7 @@ class Service: @temporalio.nexus.handler.workflow_run_operation_handler() async def workflow_run_operation_handler( self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> NexusStartWorkflowRequest[Output]: ... + ) -> WorkflowOperationToken[Output]: ... expected_operations = NotCalled.expected_operations @@ -64,7 +64,7 @@ class Service: @temporalio.nexus.handler.workflow_run_operation_handler(name="operation-name") async def workflow_run_operation_with_name_override( self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> NexusStartWorkflowRequest[Output]: ... + ) -> WorkflowOperationToken[Output]: ... expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 22902e627..52a5a2cb8 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -34,7 +34,6 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext -from temporalio.nexus.handler._operation_handlers import NexusStartWorkflowRequest from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker @@ -157,19 +156,14 @@ async def start( ) elif isinstance(input.response_type, AsyncResponse): tctx = TemporalNexusOperationContext.current() - start_request = NexusStartWorkflowRequest[HandlerWfOutput]( # type: ignore - tctx.client, + token = await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, task_queue=tctx.task_queue, + client=tctx.client, ) - wf_handle = await start_request.start_workflow() - return nexusrpc.handler.StartOperationResultAsync( - WorkflowOperationToken[HandlerWfOutput] - .from_workflow_handle(wf_handle) - .encode() - ) + return nexusrpc.handler.StartOperationResultAsync(token.encode()) else: raise TypeError @@ -211,7 +205,7 @@ async def sync_operation( @temporalio.nexus.handler.workflow_run_operation_handler async def async_operation( self, ctx: StartOperationContext, input: OpInput - ) -> NexusStartWorkflowRequest[HandlerWfOutput]: + ) -> WorkflowOperationToken[HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: raise RPCError( @@ -220,11 +214,11 @@ async def async_operation( b"", ) tctx = TemporalNexusOperationContext.current() - return NexusStartWorkflowRequest( - tctx.client, + return await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, + client=tctx.client, task_queue=tctx.task_queue, ) @@ -927,7 +921,7 @@ class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: @temporalio.nexus.handler.workflow_run_operation_handler async def my_workflow_run_operation( self, ctx: StartOperationContext, input: None - ) -> NexusStartWorkflowRequest[str]: + ) -> WorkflowOperationToken[str]: tctx = TemporalNexusOperationContext.current() result_1 = await tctx.client.execute_workflow( EchoWorkflow.run, @@ -938,11 +932,11 @@ async def my_workflow_run_operation( # In case result_1 is incorrectly being delivered to the caller as the operation # result, give time for that incorrect behavior to occur. await asyncio.sleep(0.5) - return NexusStartWorkflowRequest( - tctx.client, + return await tctx.start_workflow( EchoWorkflow.run, f"{result_1}-result-2", id=str(uuid.uuid4()), + client=tctx.client, task_queue=tctx.task_queue, ) From 3f3ab83f91375bddcf167fb33703727e43427dfc Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 10:45:10 -0400 Subject: [PATCH 017/183] Cleanup --- temporalio/nexus/handler/__init__.py | 3 -- .../nexus/handler/_operation_context.py | 24 ++++++--------- .../nexus/handler/_operation_handlers.py | 1 + temporalio/worker/_nexus.py | 8 ++--- temporalio/worker/_worker.py | 29 ++++++------------- tests/helpers/nexus.py | 1 + 6 files changed, 23 insertions(+), 43 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index fbf144e9b..17b03413c 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -62,9 +62,6 @@ def process( """Logger that emits additional data describing the current Nexus operation.""" -# TODO(nexus-preview): demonstrate obtaining Temporal client in sync operation. - - # TODO(nexus-prerelease): support request_id # See e.g. TS # packages/nexus/src/context.ts attachRequestId diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 5695cc585..ada48963f 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -96,14 +96,14 @@ def temporal_nexus_cancel_operation_context( return _TemporalNexusCancelOperationContext(ctx) # Overload for single-param workflow - # TODO(nexus-preview): support other overloads? + # TODO(nexus-prerelease): support other overloads? async def start_workflow( self, workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], arg: ParamType, *, id: str, - # TODO(nexus-preview): Allow client and task queue to be omitted, defaulting to worker's? + # TODO(nexus-prerelease): Allow client and task queue to be omitted, defaulting to worker's? task_queue: str, client: Client, execution_timeout: Optional[timedelta] = None, @@ -131,17 +131,10 @@ async def start_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, ) -> WorkflowOperationToken[ReturnType]: - if nexus_ctx := TemporalNexusOperationContext.try_current(): - if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: - nexus_completion_callbacks = nexus_start_ctx.get_completion_callbacks() - workflow_event_links = nexus_start_ctx.get_workflow_event_links() - else: - raise RuntimeError( - "temporalio.nexus.handler.start_workflow() must be called from within a Nexus start operation context" - ) - else: + start_operation_context = self.temporal_nexus_start_operation_context + if not start_operation_context: raise RuntimeError( - "temporalio.nexus.handler.start_workflow() must be called from within a Nexus operation context" + "temporalio.nexus.handler.start_workflow() must be called from within a Nexus start operation context" ) # We must pass nexus_completion_callbacks and workflow_event_links, but these are @@ -170,11 +163,11 @@ async def start_workflow( request_eager_start=request_eager_start, priority=priority, versioning_override=versioning_override, - nexus_completion_callbacks=nexus_completion_callbacks, - workflow_event_links=workflow_event_links, + nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), + workflow_event_links=start_operation_context.get_workflow_event_links(), ) - nexus_start_ctx.add_outbound_links(wf_handle) + start_operation_context.add_outbound_links(wf_handle) return WorkflowOperationToken[ReturnType]._unsafe_from_workflow_handle( wf_handle @@ -260,6 +253,7 @@ def _workflow_handle_to_workflow_execution_started_event_link( event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ), + # TODO(nexus-prerelease): RequestIdReference? ) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 10aadd3d8..104cb8d2c 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -41,6 +41,7 @@ ) +# TODO(nexus-prerelease): revise cancel implementation async def cancel_workflow( ctx: CancelOperationContext, token: str, diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index fdb41c762..8a40b07a3 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -74,9 +74,9 @@ def __init__( ) self._handler = Handler(service_handlers, executor) self._data_converter = data_converter - # TODO(nexus-prerelease): interceptors + # TODO(nexus-preview): interceptors self._interceptors = interceptors - # TODO(nexus-prerelease): metric_meter + # TODO(nexus-preview): metric_meter self._metric_meter = metric_meter self._running_tasks: dict[bytes, asyncio.Task[Any]] = {} self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue() @@ -162,7 +162,7 @@ async def drain_poll_queue(self) -> None: async def wait_all_completed(self) -> None: await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) - # TODO(nexus-prerelease): stack trace pruning. See sdk-typescript NexusHandler.execute + # TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute # "Any call up to this function and including this one will be trimmed out of stack traces."" async def _handle_cancel_operation_task( @@ -359,7 +359,6 @@ async def _handler_error_to_proto( return temporalio.api.nexus.v1.HandlerError( error_type=err.type.value, failure=await self._exception_to_failure_proto(err), - # TODO(nexus-prerelease): is there a reason to support retryable=None? retry_behavior=( temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE if err.retryable @@ -410,7 +409,6 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.handler.HandlerE err.message, type=nexusrpc.handler.HandlerErrorType.INTERNAL, cause=err, - # TODO(nexus-prerelease): is there a reason to support retryable=None? retryable=not err.non_retryable, ) elif isinstance(err, RPCError): diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 038ce38aa..91063388b 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -107,14 +107,11 @@ def __init__( *, task_queue: str, activities: Sequence[Callable] = [], - # TODO(nexus-prerelease): for naming consistency this should be named - # nexus_service_handlers. That will prevent users from mistakenly trying to add - # their service definitions here. nexus_service_handlers: Sequence[Any] = [], workflows: Sequence[Type] = [], activity_executor: Optional[concurrent.futures.Executor] = None, workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, - nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, + nexus_task_executor: Optional[concurrent.futures.Executor] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(), interceptors: Sequence[Interceptor] = [], @@ -183,6 +180,10 @@ def __init__( otherwise. The default one will be properly shutdown, but if one is provided, the caller is responsible for shutting it down after the worker is shut down. + nexus_operation_executor: Executor to use for non-async + Nexus operations. This is required if any operation start methods + are non-`async def`. :py:class:`concurrent.futures.ThreadPoolExecutor` + is recommended. workflow_runner: Runner for workflows. unsandboxed_workflow_runner: Runner for workflows that opt-out of sandboxing. @@ -206,8 +207,6 @@ def __init__( will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``. max_concurrent_local_activities: Maximum number of local activity tasks that will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``. - max_concurrent_workflow_tasks: Maximum allowed number of - tasks that will ever be given to the workflow worker at one time. Mutually exclusive with ``tuner``. tuner: Provide a custom :py:class:`WorkerTuner`. Mutually exclusive with the ``max_concurrent_workflow_tasks``, ``max_concurrent_activities``, and ``max_concurrent_local_activities`` arguments. @@ -307,18 +306,8 @@ def __init__( activity_task_poller_behavior: Specify the behavior of activity task polling. Defaults to a 5-poller maximum. """ - # TODO(nexus-prerelease): non-async (executor-based) Nexus worker; honor - # max_concurrent_nexus_operations and nexus_operation_executor. - # nexus_operation_executor: Concurrent executor to use for non-async - # Nexus operations. This is required if any operation start methods - # are non-async. :py:class:`concurrent.futures.ThreadPoolExecutor` - # is recommended. If this is a - # :py:class:`concurrent.futures.ProcessPoolExecutor`, all non-async - # start methods must be picklable. ``max_workers`` on the executor - # should at least be ``max_concurrent_nexus_operations`` or a warning - # is issued. - # max_concurrent_nexus_operations: Maximum number of Nexus operations that - # will ever be given to the Nexus worker concurrently. Mutually exclusive with ``tuner``. + # TODO(nexus-prerelease): Support `nexus_task_poller_behavior` in bridge worker, + # with max_concurrent_nexus_tasks and max_concurrent_nexus_tasks if not (activities or nexus_service_handlers or workflows): raise ValueError( "At least one activity, Nexus service, or workflow must be specified" @@ -470,7 +459,7 @@ def check_activity(activity): ) if tuner is not None: - # TODO(nexus-prerelease): Nexus tuner support + # TODO(nexus-preview): Nexus tuner support if ( max_concurrent_workflow_tasks or max_concurrent_activities @@ -731,7 +720,7 @@ async def raise_on_shutdown(): if self._nexus_worker: await self._nexus_worker.wait_all_completed() - # TODO(nexus-prerelease): check that we do all appropriate things for nexus worker that we do for activity worker + # TODO(nexus-preview): check that we do all appropriate things for nexus worker that we do for activity worker # Do final shutdown try: diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index c1225136c..57fbe35d2 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -50,6 +50,7 @@ async def start_operation( """ Start a Nexus operation. """ + # TODO(nexus-preview): Support callback URL as query param async with httpx.AsyncClient() as http_client: return await http_client.post( f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}", From 2e3181e45efc9a8bfea325bca9543f8955d2e14d Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 11:19:23 -0400 Subject: [PATCH 018/183] Refactor test --- tests/nexus/test_handler.py | 51 +++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index a77c577e0..e6f3225e7 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -69,14 +69,6 @@ class NonSerializableOutput: callable: Callable[[], Any] = lambda: None -@dataclass -class TestContext: - workflow_id: Optional[str] = None - - -test_context = TestContext() - - # TODO: type check nexus implementation under mypy # TODO(nexus-prerelease): test dynamic creation of a service from unsugared definition @@ -216,7 +208,7 @@ async def workflow_run_operation( return await tctx.start_workflow( MyWorkflow.run, input, - id=test_context.workflow_id or str(uuid.uuid4()), + id=str(uuid.uuid4()), client=tctx.client, task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, @@ -262,7 +254,7 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): return await tctx.start_workflow( WorkflowWithoutTypeAnnotations.run, input, - id=test_context.workflow_id or str(uuid.uuid4()), + id=str(uuid.uuid4()), client=tctx.client, task_queue=tctx.task_queue, ) @@ -281,7 +273,7 @@ async def workflow_run_op_link_test( return await tctx.start_workflow( MyLinkTestWorkflow.run, input, - id=test_context.workflow_id or str(uuid.uuid4()), + id=str(uuid.uuid4()), client=tctx.client, task_queue=tctx.task_queue, ) @@ -1034,6 +1026,30 @@ async def test_request_id_is_received_by_sync_operation_handler( assert resp.json() == {"value": f"request_id: {request_id}"} +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: Input) -> Output: + return Output(value=input.value) + + +@nexusrpc.handler.service_handler +class ServiceHandlerForRequestIdTest: + @temporalio.nexus.handler.workflow_run_operation_handler + async def operation_backed_by_a_workflow( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: + tctx = TemporalNexusOperationContext.current() + return await tctx.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + client=tctx.client, + task_queue=tctx.task_queue, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) + + async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): # We send two Nexus requests that would start a workflow with the same workflow ID, # using reuse_policy=REJECT_DUPLICATE. This would fail if they used different @@ -1045,20 +1061,17 @@ async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnviron service_client = ServiceClient( server_address=server_address(env), endpoint=endpoint, - service=MyService.__name__, + service=ServiceHandlerForRequestIdTest.__name__, ) - decorator = nexusrpc.handler.service_handler(service=MyService) - service_handler = decorator(MyServiceHandler)() - async def start_two_workflows_with_conflicting_workflow_ids( request_ids: tuple[tuple[str, int], tuple[str, int]], ): - test_context.workflow_id = str(uuid.uuid4()) + workflow_id = str(uuid.uuid4()) for request_id, status_code in request_ids: resp = await service_client.start_operation( - "workflow_run_operation", - dataclass_as_dict(Input("")), + "operation_backed_by_a_workflow", + dataclass_as_dict(Input(workflow_id)), {"Nexus-Request-Id": request_id}, ) assert resp.status_code == status_code, ( @@ -1074,7 +1087,7 @@ async def start_two_workflows_with_conflicting_workflow_ids( async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[service_handler], + nexus_service_handlers=[ServiceHandlerForRequestIdTest()], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): request_id_1, request_id_2 = str(uuid.uuid4()), str(uuid.uuid4()) From 8aa937f58b6ce5dd44f65ae71867dd2c24f524c0 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 11:40:40 -0400 Subject: [PATCH 019/183] Failing test: request ID is not used for non-backing workflow --- tests/nexus/test_handler.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index e6f3225e7..b0abe0fb1 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -1049,6 +1049,28 @@ async def operation_backed_by_a_workflow( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) + @temporalio.nexus.handler.workflow_run_operation_handler + async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: + tctx = TemporalNexusOperationContext.current() + await tctx.client.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + task_queue=tctx.task_queue, + ) + # This should fail. It will not fail if the Nexus request ID was incorrectly + # propagated to both StartWorkflow requests. + return await tctx.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + client=tctx.client, + task_queue=tctx.task_queue, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) + async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): # We send two Nexus requests that would start a workflow with the same workflow ID, @@ -1084,6 +1106,16 @@ async def start_two_workflows_with_conflicting_workflow_ids( assert op_info["token"] assert op_info["state"] == nexusrpc.OperationState.RUNNING.value + async def start_two_workflows_in_a_single_operation( + request_id: str, status_code: int + ): + resp = await service_client.start_operation( + "operation_that_executes_a_workflow_before_starting_the_backing_workflow", + dataclass_as_dict(Input("test-workflow-id")), + {"Nexus-Request-Id": request_id}, + ) + assert resp.status_code == status_code + async with Worker( env.client, task_queue=task_queue, @@ -1101,6 +1133,8 @@ async def start_two_workflows_with_conflicting_workflow_ids( await start_two_workflows_with_conflicting_workflow_ids( ((request_id_1, 201), (request_id_2, 500)) ) + # Two workflows started in the same operation should fail + await start_two_workflows_in_a_single_operation(request_id_1, 500) def server_address(env: WorkflowEnvironment) -> str: From 93d047b88d0fb8fea41f3fc8c4749fd59d6d3643 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 11:47:52 -0400 Subject: [PATCH 020/183] Bug fix: wire request_id through as top-level start_workflow param --- temporalio/client.py | 19 ++++++------------- .../nexus/handler/_operation_context.py | 1 + 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index e6d0f28cc..887d9386e 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -58,9 +58,6 @@ import temporalio.service import temporalio.workflow from temporalio.activity import ActivityCancellationDetails -from temporalio.nexus.handler import ( - TemporalNexusOperationContext, -) from temporalio.service import ( HttpConnectProxyConfig, KeepAliveConfig, @@ -477,6 +474,7 @@ async def start_workflow( workflow_event_links: Sequence[ temporalio.api.common.v1.Link.WorkflowEvent ] = [], + request_id: Optional[str] = None, stack_level: int = 2, ) -> WorkflowHandle[Any, Any]: """Start a workflow and return its handle. @@ -569,6 +567,7 @@ async def start_workflow( priority=priority, nexus_completion_callbacks=nexus_completion_callbacks, workflow_event_links=workflow_event_links, + request_id=request_id, ) ) @@ -5207,6 +5206,7 @@ class StartWorkflowInput: priority: temporalio.common.Priority nexus_completion_callbacks: Sequence[temporalio.common.NexusCompletionCallback] workflow_event_links: Sequence[temporalio.api.common.v1.Link.WorkflowEvent] + request_id: Optional[str] versioning_override: Optional[temporalio.common.VersioningOverride] = None @@ -5822,6 +5822,9 @@ async def _build_start_workflow_execution_request( ) -> temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest: req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest() req.request_eager_execution = input.request_eager_start + if input.request_id: + req.request_id = input.request_id + await self._populate_start_workflow_execution_request(req, input) for callback in input.nexus_completion_callbacks: c = temporalio.api.common.v1.Callback() @@ -5879,16 +5882,6 @@ async def _populate_start_workflow_execution_request( if input.task_timeout is not None: req.workflow_task_timeout.FromTimedelta(input.task_timeout) req.identity = self._client.identity - # Use Nexus request ID if we're handling a Nexus Start operation - # TODO(prerelease): confirm that we should do this for every workflow started - # TODO(prerelease): add test coverage for multiple workflows started by a Nexus operation - if nexus_ctx := TemporalNexusOperationContext.try_current(): - if nexus_start_ctx := nexus_ctx.temporal_nexus_start_operation_context: - if ( - nexus_request_id - := nexus_start_ctx.nexus_operation_context.request_id - ): - req.request_id = nexus_request_id if not req.request_id: req.request_id = str(uuid.uuid4()) req.workflow_id_reuse_policy = cast( diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index ada48963f..4997ae6c0 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -165,6 +165,7 @@ async def start_workflow( versioning_override=versioning_override, nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), workflow_event_links=start_operation_context.get_workflow_event_links(), + request_id=start_operation_context.nexus_operation_context.request_id, ) start_operation_context.add_outbound_links(wf_handle) From 2b77eb132300e821b43181e3ba345b0b8dc321eb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 14:34:40 -0400 Subject: [PATCH 021/183] Rename: TemporalOperationContext --- temporalio/nexus/handler/__init__.py | 6 ++-- .../nexus/handler/_operation_context.py | 36 +++++++++---------- .../nexus/handler/_operation_handlers.py | 4 +-- temporalio/worker/_nexus.py | 12 +++---- tests/nexus/test_handler.py | 12 +++---- tests/nexus/test_workflow_caller.py | 9 +++-- 6 files changed, 37 insertions(+), 42 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 17b03413c..151081d98 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -19,9 +19,7 @@ HandlerErrorType as HandlerErrorType, ) -from ._operation_context import ( - TemporalNexusOperationContext as TemporalNexusOperationContext, -) +from ._operation_context import TemporalOperationContext as TemporalOperationContext from ._operation_handlers import ( WorkflowRunOperationHandler as WorkflowRunOperationHandler, ) @@ -50,7 +48,7 @@ def process( self, msg: Any, kwargs: MutableMapping[str, Any] ) -> tuple[Any, MutableMapping[str, Any]]: extra = dict(self.extra or {}) - if tctx := TemporalNexusOperationContext.current(): + if tctx := TemporalOperationContext.current(): extra["service"] = tctx.nexus_operation_context.service extra["operation"] = tctx.nexus_operation_context.operation extra["task_queue"] = tctx.task_queue diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 4997ae6c0..fd79ca172 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -37,13 +37,13 @@ logger = logging.getLogger(__name__) -_current_context: ContextVar[TemporalNexusOperationContext] = ContextVar( - "temporal-nexus-operation-context" +_current_context: ContextVar[TemporalOperationContext] = ContextVar( + "temporal-operation-context" ) @dataclass -class TemporalNexusOperationContext: +class TemporalOperationContext: """ Context for a Nexus operation being handled by a Temporal Nexus Worker. """ @@ -57,43 +57,43 @@ class TemporalNexusOperationContext: """The task queue of the worker handling this Nexus operation.""" @staticmethod - def try_current() -> Optional[TemporalNexusOperationContext]: + def try_current() -> Optional[TemporalOperationContext]: return _current_context.get(None) @staticmethod - def current() -> TemporalNexusOperationContext: - context = TemporalNexusOperationContext.try_current() + def current() -> TemporalOperationContext: + context = TemporalOperationContext.try_current() if not context: raise RuntimeError("Not in Nexus operation context") return context @staticmethod def set( - context: TemporalNexusOperationContext, - ) -> contextvars.Token[TemporalNexusOperationContext]: + context: TemporalOperationContext, + ) -> contextvars.Token[TemporalOperationContext]: return _current_context.set(context) @staticmethod - def reset(token: contextvars.Token[TemporalNexusOperationContext]) -> None: + def reset(token: contextvars.Token[TemporalOperationContext]) -> None: _current_context.reset(token) @property - def temporal_nexus_start_operation_context( + def temporal_start_operation_context( self, - ) -> Optional[_TemporalNexusStartOperationContext]: + ) -> Optional[_TemporalStartOperationContext]: ctx = self.nexus_operation_context if not isinstance(ctx, StartOperationContext): return None - return _TemporalNexusStartOperationContext(ctx) + return _TemporalStartOperationContext(ctx) @property - def temporal_nexus_cancel_operation_context( + def temporal_cancel_operation_context( self, - ) -> Optional[_TemporalNexusCancelOperationContext]: + ) -> Optional[_TemporalCancelOperationContext]: ctx = self.nexus_operation_context if not isinstance(ctx, CancelOperationContext): return None - return _TemporalNexusCancelOperationContext(ctx) + return _TemporalCancelOperationContext(ctx) # Overload for single-param workflow # TODO(nexus-prerelease): support other overloads? @@ -131,7 +131,7 @@ async def start_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, ) -> WorkflowOperationToken[ReturnType]: - start_operation_context = self.temporal_nexus_start_operation_context + start_operation_context = self.temporal_start_operation_context if not start_operation_context: raise RuntimeError( "temporalio.nexus.handler.start_workflow() must be called from within a Nexus start operation context" @@ -176,7 +176,7 @@ async def start_workflow( @dataclass -class _TemporalNexusStartOperationContext: +class _TemporalStartOperationContext: nexus_operation_context: StartOperationContext def get_completion_callbacks( @@ -233,7 +233,7 @@ def add_outbound_links(self, workflow_handle: WorkflowHandle[Any, Any]): @dataclass -class _TemporalNexusCancelOperationContext: +class _TemporalCancelOperationContext: nexus_operation_context: CancelOperationContext diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 104cb8d2c..c14ebeacb 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -30,7 +30,7 @@ ) from typing_extensions import overload -from ._operation_context import TemporalNexusOperationContext +from ._operation_context import TemporalOperationContext from ._token import ( WorkflowOperationToken as WorkflowOperationToken, ) @@ -48,7 +48,7 @@ async def cancel_workflow( client: Optional[Client] = None, # noqa **kwargs: Any, ) -> None: - client = client or TemporalNexusOperationContext.current().client + client = client or TemporalOperationContext.current().client try: decoded = WorkflowOperationToken[Any].decode(token) except Exception as err: diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 8a40b07a3..bf597b433 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -37,9 +37,7 @@ import temporalio.nexus import temporalio.nexus.handler from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import ( - TemporalNexusOperationContext, -) +from temporalio.nexus.handler import TemporalOperationContext from temporalio.service import RPCError, RPCStatusCode from ._interceptor import Interceptor @@ -178,8 +176,8 @@ async def _handle_cancel_operation_task( service=request.service, operation=request.operation, ) - TemporalNexusOperationContext.set( - TemporalNexusOperationContext( + TemporalOperationContext.set( + TemporalOperationContext( nexus_operation_context=ctx, client=self._client, task_queue=self._task_queue, @@ -281,8 +279,8 @@ async def _start_operation( ], callback_headers=dict(start_request.callback_header), ) - TemporalNexusOperationContext.set( - TemporalNexusOperationContext( + TemporalOperationContext.set( + TemporalOperationContext( nexus_operation_context=ctx, client=self._client, task_queue=self._task_queue, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index b0abe0fb1..3535622e4 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -45,7 +45,7 @@ from temporalio.nexus.handler import ( logger, ) -from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext +from temporalio.nexus.handler._operation_context import TemporalOperationContext from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -204,7 +204,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: async def workflow_run_operation( self, ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() return await tctx.start_workflow( MyWorkflow.run, input, @@ -250,7 +250,7 @@ async def sync_operation_without_type_annotations(self, ctx, input): @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation_without_type_annotations(self, ctx, input): - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() return await tctx.start_workflow( WorkflowWithoutTypeAnnotations.run, input, @@ -269,7 +269,7 @@ async def workflow_run_op_link_test( assert ctx.request_id == "test-request-id-123", "Request ID mismatch" ctx.outbound_links.extend(ctx.inbound_links) - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() return await tctx.start_workflow( MyLinkTestWorkflow.run, input, @@ -1039,7 +1039,7 @@ class ServiceHandlerForRequestIdTest: async def operation_backed_by_a_workflow( self, ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() return await tctx.start_workflow( EchoWorkflow.run, input, @@ -1053,7 +1053,7 @@ async def operation_backed_by_a_workflow( async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() await tctx.client.start_workflow( EchoWorkflow.run, input, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 52a5a2cb8..58aaad48c 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -33,8 +33,7 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext -from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.nexus.handler import TemporalOperationContext, WorkflowOperationToken from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -155,7 +154,7 @@ async def start( value=OpOutput(value="sync response") ) elif isinstance(input.response_type, AsyncResponse): - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() token = await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -213,7 +212,7 @@ async def async_operation( RPCStatusCode.INVALID_ARGUMENT, b"", ) - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() return await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -922,7 +921,7 @@ class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: async def my_workflow_run_operation( self, ctx: StartOperationContext, input: None ) -> WorkflowOperationToken[str]: - tctx = TemporalNexusOperationContext.current() + tctx = TemporalOperationContext.current() result_1 = await tctx.client.execute_workflow( EchoWorkflow.run, "result-1", From a7fc5430f841d9c3551745ecdb2191f421317779 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 15:13:51 -0400 Subject: [PATCH 022/183] Rename: cancel_operation --- temporalio/nexus/handler/__init__.py | 2 +- .../nexus/handler/_operation_handlers.py | 26 ++++++++++++------- temporalio/nexus/handler/_token.py | 8 +++--- tests/nexus/test_handler.py | 2 +- tests/nexus/test_workflow_caller.py | 3 ++- 5 files changed, 24 insertions(+), 17 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 151081d98..d0e1df8b6 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -23,7 +23,7 @@ from ._operation_handlers import ( WorkflowRunOperationHandler as WorkflowRunOperationHandler, ) -from ._operation_handlers import cancel_workflow as cancel_workflow +from ._operation_handlers import cancel_operation as cancel_operation from ._operation_handlers import ( workflow_run_operation_handler as workflow_run_operation_handler, ) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index c14ebeacb..cce7484da 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -30,7 +30,8 @@ ) from typing_extensions import overload -from ._operation_context import TemporalOperationContext +from temporalio.nexus.handler._operation_context import TemporalOperationContext + from ._token import ( WorkflowOperationToken as WorkflowOperationToken, ) @@ -41,24 +42,28 @@ ) -# TODO(nexus-prerelease): revise cancel implementation -async def cancel_workflow( - ctx: CancelOperationContext, +async def cancel_operation( token: str, - client: Optional[Client] = None, # noqa + client: Client, **kwargs: Any, ) -> None: - client = client or TemporalOperationContext.current().client + """Cancel a Nexus operation. + + Args: + token: The token of the operation to cancel. + client: The client to use to cancel the operation. + """ try: - decoded = WorkflowOperationToken[Any].decode(token) + workflow_token = WorkflowOperationToken[Any].decode(token) except Exception as err: raise HandlerError( - "Failed to decode workflow operation token", + "Failed to decode operation token as workflow operation token. " + "Canceling non-workflow operations is not supported.", type=HandlerErrorType.NOT_FOUND, cause=err, ) try: - handle = decoded.to_workflow_handle(client) + handle = workflow_token.to_workflow_handle(client) except Exception as err: raise HandlerError( "Failed to construct workflow handle from workflow operation token", @@ -114,7 +119,8 @@ async def start( ) async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - await cancel_workflow(ctx, token) + tctx = TemporalOperationContext.current() + await cancel_operation(token, tctx.client) def fetch_info( self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str diff --git a/temporalio/nexus/handler/_token.py b/temporalio/nexus/handler/_token.py index 47a696e79..ecb5d06cf 100644 --- a/temporalio/nexus/handler/_token.py +++ b/temporalio/nexus/handler/_token.py @@ -29,13 +29,13 @@ def to_workflow_handle(self, client: Client) -> WorkflowHandle[Any, OutputT]: """Create a :py:class:`temporalio.client.WorkflowHandle` from the token.""" if client.namespace != self.namespace: raise ValueError( - f"Client namespace {client.namespace} does not match token namespace {self.namespace}" + f"Client namespace {client.namespace} does not match " + f"operation token namespace {self.namespace}" ) return client.get_workflow_handle(self.workflow_id) - # TODO(nexus-preview): Is it helpful to parameterize WorkflowOperationToken by - # OutputT? The return type here should be dictated by the input workflow handle - # type. + # TODO(nexus-preview): The return type here should be dictated by the input workflow + # handle type. @classmethod def _unsafe_from_workflow_handle( cls, workflow_handle: WorkflowHandle[Any, OutputT] diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 3535622e4..32c428513 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -995,7 +995,7 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): ) assert cancel_response.status_code == 404 failure = Failure(**cancel_response.json()) - assert "failed to decode workflow operation token" in failure.message.lower() + assert "failed to decode operation token" in failure.message.lower() async def test_request_id_is_received_by_sync_operation_handler( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 58aaad48c..5f93ffceb 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -167,7 +167,8 @@ async def start( raise TypeError async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - return await temporalio.nexus.handler.cancel_workflow(ctx, token) + tctx = TemporalOperationContext.current() + return await temporalio.nexus.handler.cancel_operation(token, tctx.client) async def fetch_info( self, ctx: FetchOperationInfoContext, token: str From dd7ecff559b4bc03ce8f43206065b6135f604a27 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 15:59:16 -0400 Subject: [PATCH 023/183] Cleanup --- temporalio/nexus/handler/__init__.py | 6 ---- .../nexus/handler/_operation_context.py | 34 ++++++++++++++++++- temporalio/worker/_worker.py | 2 +- tests/helpers/nexus.py | 2 +- tests/nexus/test_handler.py | 3 +- 5 files changed, 37 insertions(+), 10 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index d0e1df8b6..cc400da49 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -58,9 +58,3 @@ def process( logger = LoggerAdapter(logging.getLogger(__name__), None) """Logger that emits additional data describing the current Nexus operation.""" - - -# TODO(nexus-prerelease): support request_id -# See e.g. TS -# packages/nexus/src/context.ts attachRequestId -# packages/test/src/test-nexus-handler.ts ctx.requestId diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index fd79ca172..d5128e93c 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -131,12 +131,44 @@ async def start_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, ) -> WorkflowOperationToken[ReturnType]: + """Start a workflow that will deliver the result of the Nexus operation. + + The workflow will be started as usual, with the following modifications: + + - On workflow completion, Temporal server will deliver the workflow result to + the Nexus operation caller, using the callback from the Nexus operation start + request. + + - The request ID from the Nexus operation start request will be used as the + request ID for the start workflow request. + + - Inbound links to the caller that were submitted in the Nexus start operation + request will be attached to the started workflow and, outbound links to the + started workflow will be added to the Nexus start operation response. If the + Nexus caller is itself a workflow, this means that the workflow in the caller + namespace web UI will contain links to the started workflow, and vice versa. + + Args: + client: The client to use to start the workflow. + + See :py:meth:`temporalio.client.Client.start_workflow` for all other arguments. + """ start_operation_context = self.temporal_start_operation_context if not start_operation_context: raise RuntimeError( - "temporalio.nexus.handler.start_workflow() must be called from within a Nexus start operation context" + "temporalio.nexus.handler.start_workflow() must be called from " + "within a Nexus start operation context" ) + # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: + # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { + # internalOptions.onConflictOptions = { + # attachLinks: true, + # attachCompletionCallbacks: true, + # attachRequestId: true, + # }; + # } + # We must pass nexus_completion_callbacks and workflow_event_links, but these are # deliberately not exposed in overloads, hence the type-check violation. wf_handle = await client.start_workflow( # type: ignore diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 91063388b..90c1790b3 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -802,7 +802,7 @@ class WorkerConfig(TypedDict, total=False): workflows: Sequence[Type] activity_executor: Optional[concurrent.futures.Executor] workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] - nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] + nexus_task_executor: Optional[concurrent.futures.Executor] workflow_runner: WorkflowRunner unsandboxed_workflow_runner: WorkflowRunner interceptors: Sequence[Interceptor] diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index 57fbe35d2..88aadc7c5 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -14,7 +14,7 @@ def make_nexus_endpoint_name(task_queue: str) -> str: return f"nexus-endpoint-{task_queue}" -# TODO(nexus-prerelease): How do we recommend that users create endpoints in their own tests? +# TODO(nexus-preview): How do we recommend that users create endpoints in their own tests? # See https://github.com/temporalio/sdk-typescript/pull/1708/files?show-viewed-files=true&file-filters%5B%5D=&w=0#r2082549085 async def create_nexus_endpoint( task_queue: str, client: Client diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 32c428513..5d6effe67 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -1133,7 +1133,8 @@ async def start_two_workflows_in_a_single_operation( await start_two_workflows_with_conflicting_workflow_ids( ((request_id_1, 201), (request_id_2, 500)) ) - # Two workflows started in the same operation should fail + # Two workflows started in the same operation should fail, since the Nexus + # request ID should be propagated to the backing workflow only. await start_two_workflows_in_a_single_operation(request_id_1, 500) From 95f8b4a1a840dcdc497f9bc334b3f92f7e5d0d41 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 16:26:05 -0400 Subject: [PATCH 024/183] Do not allow Nexus operation to set client used for starting workflow --- temporalio/nexus/handler/_operation_context.py | 5 ++--- tests/nexus/test_handler.py | 5 ----- tests/nexus/test_workflow_caller.py | 3 --- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index d5128e93c..0164d188b 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -103,9 +103,8 @@ async def start_workflow( arg: ParamType, *, id: str, - # TODO(nexus-prerelease): Allow client and task queue to be omitted, defaulting to worker's? + # TODO(nexus-prerelease): Allow task queue to be omitted, defaulting to worker's? task_queue: str, - client: Client, execution_timeout: Optional[timedelta] = None, run_timeout: Optional[timedelta] = None, task_timeout: Optional[timedelta] = None, @@ -171,7 +170,7 @@ async def start_workflow( # We must pass nexus_completion_callbacks and workflow_event_links, but these are # deliberately not exposed in overloads, hence the type-check violation. - wf_handle = await client.start_workflow( # type: ignore + wf_handle = await self.client.start_workflow( # type: ignore workflow=workflow, arg=arg, id=id, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 5d6effe67..432d24209 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -209,7 +209,6 @@ async def workflow_run_operation( MyWorkflow.run, input, id=str(uuid.uuid4()), - client=tctx.client, task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) @@ -255,7 +254,6 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): WorkflowWithoutTypeAnnotations.run, input, id=str(uuid.uuid4()), - client=tctx.client, task_queue=tctx.task_queue, ) @@ -274,7 +272,6 @@ async def workflow_run_op_link_test( MyLinkTestWorkflow.run, input, id=str(uuid.uuid4()), - client=tctx.client, task_queue=tctx.task_queue, ) @@ -1044,7 +1041,6 @@ async def operation_backed_by_a_workflow( EchoWorkflow.run, input, id=input.value, - client=tctx.client, task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) @@ -1066,7 +1062,6 @@ async def operation_that_executes_a_workflow_before_starting_the_backing_workflo EchoWorkflow.run, input, id=input.value, - client=tctx.client, task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 5f93ffceb..3a71f4dbd 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -160,7 +160,6 @@ async def start( HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, task_queue=tctx.task_queue, - client=tctx.client, ) return nexusrpc.handler.StartOperationResultAsync(token.encode()) else: @@ -218,7 +217,6 @@ async def async_operation( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, - client=tctx.client, task_queue=tctx.task_queue, ) @@ -936,7 +934,6 @@ async def my_workflow_run_operation( EchoWorkflow.run, f"{result_1}-result-2", id=str(uuid.uuid4()), - client=tctx.client, task_queue=tctx.task_queue, ) From 6d6887a80769ce75bdd68d635b21a45ecd08f20e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 16:44:24 -0400 Subject: [PATCH 025/183] Make task queue optional when starting workflows --- temporalio/nexus/handler/_operation_context.py | 13 ++++++------- tests/nexus/test_handler.py | 5 ----- tests/nexus/test_workflow_caller.py | 3 --- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 0164d188b..eb29ca900 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -103,8 +103,7 @@ async def start_workflow( arg: ParamType, *, id: str, - # TODO(nexus-prerelease): Allow task queue to be omitted, defaulting to worker's? - task_queue: str, + task_queue: Optional[str] = None, execution_timeout: Optional[timedelta] = None, run_timeout: Optional[timedelta] = None, task_timeout: Optional[timedelta] = None, @@ -132,6 +131,9 @@ async def start_workflow( ) -> WorkflowOperationToken[ReturnType]: """Start a workflow that will deliver the result of the Nexus operation. + The workflow will be started in the same namespace as the Nexus worker, using + the same client as the worker. If task queue is not specified, the worker's task queue will be used. + The workflow will be started as usual, with the following modifications: - On workflow completion, Temporal server will deliver the workflow result to @@ -147,10 +149,7 @@ async def start_workflow( Nexus caller is itself a workflow, this means that the workflow in the caller namespace web UI will contain links to the started workflow, and vice versa. - Args: - client: The client to use to start the workflow. - - See :py:meth:`temporalio.client.Client.start_workflow` for all other arguments. + See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. """ start_operation_context = self.temporal_start_operation_context if not start_operation_context: @@ -174,7 +173,7 @@ async def start_workflow( workflow=workflow, arg=arg, id=id, - task_queue=task_queue, + task_queue=task_queue or self.task_queue, execution_timeout=execution_timeout, run_timeout=run_timeout, task_timeout=task_timeout, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 432d24209..3df09f103 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -209,7 +209,6 @@ async def workflow_run_operation( MyWorkflow.run, input, id=str(uuid.uuid4()), - task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) @@ -254,7 +253,6 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): WorkflowWithoutTypeAnnotations.run, input, id=str(uuid.uuid4()), - task_queue=tctx.task_queue, ) @temporalio.nexus.handler.workflow_run_operation_handler @@ -272,7 +270,6 @@ async def workflow_run_op_link_test( MyLinkTestWorkflow.run, input, id=str(uuid.uuid4()), - task_queue=tctx.task_queue, ) class OperationHandlerReturningUnwrappedResult( @@ -1041,7 +1038,6 @@ async def operation_backed_by_a_workflow( EchoWorkflow.run, input, id=input.value, - task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) @@ -1062,7 +1058,6 @@ async def operation_that_executes_a_workflow_before_starting_the_backing_workflo EchoWorkflow.run, input, id=input.value, - task_queue=tctx.task_queue, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3a71f4dbd..8188c87d2 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -159,7 +159,6 @@ async def start( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, - task_queue=tctx.task_queue, ) return nexusrpc.handler.StartOperationResultAsync(token.encode()) else: @@ -217,7 +216,6 @@ async def async_operation( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, - task_queue=tctx.task_queue, ) @@ -934,7 +932,6 @@ async def my_workflow_run_operation( EchoWorkflow.run, f"{result_1}-result-2", id=str(uuid.uuid4()), - task_queue=tctx.task_queue, ) From b0d7a600c67d174b533eac519e2085514ba67844 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 17:19:47 -0400 Subject: [PATCH 026/183] Add nexus_task_poller_behavior --- temporalio/bridge/src/worker.rs | 2 ++ temporalio/bridge/worker.py | 1 + temporalio/worker/_replayer.py | 3 +++ temporalio/worker/_worker.py | 12 +++++++----- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 4fb3085ed..930acedd3 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -60,6 +60,7 @@ pub struct WorkerConfig { graceful_shutdown_period_millis: u64, nondeterminism_as_workflow_fail: bool, nondeterminism_as_workflow_fail_for_types: HashSet, + nexus_task_poller_behavior: PollerBehavior, } #[derive(FromPyObject)] @@ -722,6 +723,7 @@ fn convert_worker_config( }) .collect::>>(), ) + .nexus_task_poller_behavior(conf.nexus_task_poller_behavior) .build() .map_err(|err| PyValueError::new_err(format!("Invalid worker config: {err}"))) } diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index e98a54470..e97563bf1 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -61,6 +61,7 @@ class WorkerConfig: graceful_shutdown_period_millis: int nondeterminism_as_workflow_fail: bool nondeterminism_as_workflow_fail_for_types: Set[str] + nexus_task_poller_behavior: PollerBehavior @dataclass diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 7600118d6..c016495c7 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -261,6 +261,9 @@ def on_eviction_hook( activity_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 1 ), + nexus_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( + 1 + ), ), ) # Start worker diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 90c1790b3..7d30b3511 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -146,6 +146,9 @@ def __init__( activity_task_poller_behavior: PollerBehavior = PollerBehaviorSimpleMaximum( maximum=5 ), + nexus_task_poller_behavior: PollerBehavior = PollerBehaviorSimpleMaximum( + maximum=5 + ), ) -> None: """Create a worker to process workflows and/or activities. @@ -305,9 +308,10 @@ def __init__( Defaults to a 5-poller maximum. activity_task_poller_behavior: Specify the behavior of activity task polling. Defaults to a 5-poller maximum. + nexus_task_poller_behavior: Specify the behavior of Nexus task polling. + Defaults to a 5-poller maximum. """ - # TODO(nexus-prerelease): Support `nexus_task_poller_behavior` in bridge worker, - # with max_concurrent_nexus_tasks and max_concurrent_nexus_tasks + # TODO(nexus-prerelease): max_concurrent_nexus_tasks / tuner support if not (activities or nexus_service_handlers or workflows): raise ValueError( "At least one activity, Nexus service, or workflow must be specified" @@ -408,8 +412,6 @@ def __init__( ) self._nexus_worker: Optional[_NexusWorker] = None if nexus_service_handlers: - # TODO(nexus-prerelease): consider not allowing / warning on max_workers < - # max_concurrent_nexus_operations? See warning above for activity worker. self._nexus_worker = _NexusWorker( bridge_worker=lambda: self._bridge_worker, client=client, @@ -459,7 +461,6 @@ def check_activity(activity): ) if tuner is not None: - # TODO(nexus-preview): Nexus tuner support if ( max_concurrent_workflow_tasks or max_concurrent_activities @@ -552,6 +553,7 @@ def check_activity(activity): versioning_strategy=versioning_strategy, workflow_task_poller_behavior=workflow_task_poller_behavior._to_bridge(), activity_task_poller_behavior=activity_task_poller_behavior._to_bridge(), + nexus_task_poller_behavior=nexus_task_poller_behavior._to_bridge(), ), ) From 0322767c3c1314154aca275f27f4e1bfadfb8d44 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 17:29:54 -0400 Subject: [PATCH 027/183] Handle PollShutdownError --- temporalio/worker/_nexus.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index bf597b433..ce81e8f7c 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -135,8 +135,9 @@ async def raise_from_exception_queue() -> NoReturn: else: raise NotImplementedError(f"Invalid Nexus task: {task}") - # TODO(nexus-prerelease): handle poller shutdown - # except temporalio.bridge.worker.PollShutdownError + except temporalio.bridge.worker.PollShutdownError: + exception_task.cancel() + return except Exception as err: raise RuntimeError("Nexus worker failed") from err From 8c2de6e3de9f95997cf7fb00581fdbdc247459ad Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 22:12:15 -0400 Subject: [PATCH 028/183] Respond to upstream: default to async --- temporalio/worker/_nexus.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index ce81e8f7c..cd02d6520 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -18,12 +18,8 @@ import google.protobuf.json_format import nexusrpc.handler -from nexusrpc import LazyValueAsync as LazyValue -from nexusrpc.handler import ( - CancelOperationContext, - StartOperationContext, -) -from nexusrpc.handler import HandlerAsync as Handler +from nexusrpc import LazyValue +from nexusrpc.handler import CancelOperationContext, Handler, StartOperationContext import temporalio.api.common.v1 import temporalio.api.enums.v1 From 7bc369237f8dfceb9c0056a9a28caab7677b6eb4 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 22 Jun 2025 19:19:53 -0400 Subject: [PATCH 029/183] Cleanup; changes from review comments - Move NexusCompletionCallback into client.py --- temporalio/client.py | 18 ++-- temporalio/common.py | 31 ------- .../nexus/handler/_operation_context.py | 35 ++++---- .../nexus/handler/_operation_handlers.py | 83 +++++++------------ temporalio/types.py | 1 + temporalio/worker/_workflow_instance.py | 6 -- temporalio/workflow.py | 15 ++-- tests/nexus/test_workflow_caller.py | 17 ---- 8 files changed, 75 insertions(+), 131 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 887d9386e..2a65f88e3 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -53,7 +53,6 @@ import temporalio.common import temporalio.converter import temporalio.exceptions -import temporalio.nexus.handler import temporalio.runtime import temporalio.service import temporalio.workflow @@ -468,9 +467,7 @@ async def start_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, # The following options are deliberately not exposed in overloads - nexus_completion_callbacks: Sequence[ - temporalio.common.NexusCompletionCallback - ] = [], + nexus_completion_callbacks: Sequence[NexusCompletionCallback] = [], workflow_event_links: Sequence[ temporalio.api.common.v1.Link.WorkflowEvent ] = [], @@ -5204,7 +5201,7 @@ class StartWorkflowInput: rpc_timeout: Optional[timedelta] request_eager_start: bool priority: temporalio.common.Priority - nexus_completion_callbacks: Sequence[temporalio.common.NexusCompletionCallback] + nexus_completion_callbacks: Sequence[NexusCompletionCallback] workflow_event_links: Sequence[temporalio.api.common.v1.Link.WorkflowEvent] request_id: Optional[str] versioning_override: Optional[temporalio.common.VersioningOverride] = None @@ -7259,6 +7256,17 @@ def api_key(self, value: Optional[str]) -> None: self.service_client.update_api_key(value) +@dataclass(frozen=True) +class NexusCompletionCallback: + """Nexus callback to attach to events such as workflow completion.""" + + url: str + """Callback URL.""" + + header: Mapping[str, str] + """Header to attach to callback request.""" + + async def _encode_user_metadata( converter: temporalio.converter.DataConverter, summary: Optional[Union[str, temporalio.api.common.v1.Payload]], diff --git a/temporalio/common.py b/temporalio/common.py index dbc04a3b1..b9b088e86 100644 --- a/temporalio/common.py +++ b/temporalio/common.py @@ -197,37 +197,6 @@ def __setstate__(self, state: object) -> None: ) -@dataclass(frozen=True) -class NexusCompletionCallback: - """Nexus callback to attach to events such as workflow completion.""" - - url: str - """Callback URL.""" - - header: Mapping[str, str] - """Header to attach to callback request.""" - - -@dataclass(frozen=True) -class WorkflowEventLink: - """A link to a history event that can be attached to a different history event.""" - - namespace: str - """Namespace of the workflow to link to.""" - - workflow_id: str - """ID of the workflow to link to.""" - - run_id: str - """Run ID of the workflow to link to.""" - - event_type: temporalio.api.enums.v1.EventType - """Type of the event to link to.""" - - event_id: int - """ID of the event to link to.""" - - # We choose to make this a list instead of an sequence so we can catch if people # are not sending lists each time but maybe accidentally sending a string (which # is a sequence) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index eb29ca900..54f8a7edd 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -8,7 +8,6 @@ from dataclasses import dataclass from datetime import timedelta from typing import ( - TYPE_CHECKING, Any, Mapping, Optional, @@ -22,6 +21,7 @@ import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.common +from temporalio.client import Client, NexusCompletionCallback, WorkflowHandle from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.types import ( MethodAsyncSingleParam, @@ -30,10 +30,6 @@ SelfType, ) -if TYPE_CHECKING: - from temporalio.client import Client, WorkflowHandle - - logger = logging.getLogger(__name__) @@ -56,6 +52,8 @@ class TemporalOperationContext: task_queue: str """The task queue of the worker handling this Nexus operation.""" + # TODO(nexus-prerelease): I don't think I like these names. Perhaps .get(), or + # expose the contextvar directly in the public API. @staticmethod def try_current() -> Optional[TemporalOperationContext]: return _current_context.get(None) @@ -132,7 +130,15 @@ async def start_workflow( """Start a workflow that will deliver the result of the Nexus operation. The workflow will be started in the same namespace as the Nexus worker, using - the same client as the worker. If task queue is not specified, the worker's task queue will be used. + the same client as the worker. If task queue is not specified, the worker's task + queue will be used. + + See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. + + The return value is :py:class:`temporalio.nexus.handler.WorkflowOperationToken`. + Use :py:meth:`temporalio.nexus.handler.WorkflowOperationToken.to_workflow_handle` + to get a :py:class:`temporalio.client.WorkflowHandle` for interacting with the + workflow. The workflow will be started as usual, with the following modifications: @@ -148,8 +154,6 @@ async def start_workflow( started workflow will be added to the Nexus start operation response. If the Nexus caller is itself a workflow, this means that the workflow in the caller namespace web UI will contain links to the started workflow, and vice versa. - - See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. """ start_operation_context = self.temporal_start_operation_context if not start_operation_context: @@ -167,8 +171,9 @@ async def start_workflow( # }; # } - # We must pass nexus_completion_callbacks and workflow_event_links, but these are - # deliberately not exposed in overloads, hence the type-check violation. + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. wf_handle = await self.client.start_workflow( # type: ignore workflow=workflow, arg=arg, @@ -211,7 +216,7 @@ class _TemporalStartOperationContext: def get_completion_callbacks( self, - ) -> list[temporalio.common.NexusCompletionCallback]: + ) -> list[NexusCompletionCallback]: ctx = self.nexus_operation_context return ( [ @@ -220,7 +225,7 @@ def get_completion_callbacks( # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links # (for backwards compatibility). PR reference in Go SDK: # https://github.com/temporalio/sdk-go/pull/1945 - temporalio.common.NexusCompletionCallback( + NexusCompletionCallback( url=ctx.callback_url, header=ctx.callback_headers, ) @@ -267,8 +272,6 @@ class _TemporalCancelOperationContext: nexus_operation_context: CancelOperationContext -# TODO(nexus-prerelease): confirm that it is correct not to use event_id in the following functions. -# Should the proto say explicitly that it's optional or how it behaves when it's missing? def _workflow_handle_to_workflow_execution_started_event_link( handle: WorkflowHandle[Any, Any], ) -> temporalio.api.common.v1.Link.WorkflowEvent: @@ -282,6 +285,8 @@ def _workflow_handle_to_workflow_execution_started_event_link( workflow_id=handle.id, run_id=handle.first_execution_run_id, event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + # TODO(nexus-prerelease): confirm that it is correct not to use event_id. + # Should the proto say explicitly that it's optional or how it behaves when it's missing? event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ), # TODO(nexus-prerelease): RequestIdReference? @@ -334,6 +339,8 @@ def _nexus_link_to_workflow_event( ) [event_type_name] = query_params.get("eventType", []) event_ref = temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + # TODO(nexus-prerelease): confirm that it is correct not to use event_id. + # Should the proto say explicitly that it's optional or how it behaves when it's missing? event_type=temporalio.api.enums.v1.EventType.Value(event_type_name) ) except ValueError as err: diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index cce7484da..0a6c4ddeb 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -5,7 +5,6 @@ import warnings from functools import wraps from typing import ( - TYPE_CHECKING, Any, Awaitable, Callable, @@ -30,48 +29,13 @@ ) from typing_extensions import overload +from temporalio.client import Client from temporalio.nexus.handler._operation_context import TemporalOperationContext from ._token import ( WorkflowOperationToken as WorkflowOperationToken, ) -if TYPE_CHECKING: - from temporalio.client import ( - Client, - ) - - -async def cancel_operation( - token: str, - client: Client, - **kwargs: Any, -) -> None: - """Cancel a Nexus operation. - - Args: - token: The token of the operation to cancel. - client: The client to use to cancel the operation. - """ - try: - workflow_token = WorkflowOperationToken[Any].decode(token) - except Exception as err: - raise HandlerError( - "Failed to decode operation token as workflow operation token. " - "Canceling non-workflow operations is not supported.", - type=HandlerErrorType.NOT_FOUND, - cause=err, - ) - try: - handle = workflow_token.to_workflow_handle(client) - except Exception as err: - raise HandlerError( - "Failed to construct workflow handle from workflow operation token", - type=HandlerErrorType.NOT_FOUND, - cause=err, - ) - await handle.cancel(**kwargs) - class WorkflowRunOperationHandler( nexusrpc.handler.OperationHandler[InputT, OutputT], @@ -91,20 +55,6 @@ def __init__( async def start( _, ctx: StartOperationContext, input: InputT ) -> StartOperationResultAsync: - # TODO(nexus-prerelease) It must be possible to start "normal" workflows in - # here, and then finish up with a "nexusified" workflow. - # TODO(nexus-prerelease) It should not be possible to construct a Nexus - # token for a non-nexusified workflow. - # TODO(nexus-prerelease) When `start` returns, must the workflow have been - # started? The answer is yes, but that's yes regarding the - # OperationHandler.start() method that is created by the decorator: it's OK - # for the shorthand method to return a lazily evaluated start_workflow; it - # will only ever be used in its transformed form. Note that in a - # `OperationHandler.start` method, a user should be able to create a token - # for a nexusified workflow and return it as a Nexus response: - # - # token = WorkflowOperationToken.from_workflow_handle(wf_handle).encode() - # return StartOperationResultAsync(token) token = await start_method(service, ctx, input) return StartOperationResultAsync(token.encode()) @@ -279,3 +229,34 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( else: [output_type] = args return input_type, output_type + + +async def cancel_operation( + token: str, + client: Client, + **kwargs: Any, +) -> None: + """Cancel a Nexus operation. + + Args: + token: The token of the operation to cancel. + client: The client to use to cancel the operation. + """ + try: + workflow_token = WorkflowOperationToken[Any].decode(token) + except Exception as err: + raise HandlerError( + "Failed to decode operation token as workflow operation token. " + "Canceling non-workflow operations is not supported.", + type=HandlerErrorType.NOT_FOUND, + cause=err, + ) + try: + handle = workflow_token.to_workflow_handle(client) + except Exception as err: + raise HandlerError( + "Failed to construct workflow handle from workflow operation token", + type=HandlerErrorType.NOT_FOUND, + cause=err, + ) + await handle.cancel(**kwargs) diff --git a/temporalio/types.py b/temporalio/types.py index a756d328c..f29d42e1e 100644 --- a/temporalio/types.py +++ b/temporalio/types.py @@ -81,6 +81,7 @@ class MethodAsyncSingleParam( ): """Generic callable type.""" + # TODO(nexus-prerelease): review changes to signatures in this file def __call__( self, __self: ProtocolSelfType, __arg: ProtocolParamType ) -> Awaitable[ProtocolReturnType]: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index cc398cb14..a709ea069 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -3043,19 +3043,14 @@ def cancel(self) -> bool: return self._task.cancel() def _resolve_start_success(self, operation_token: Optional[str]) -> None: - print(f"🟢 _resolve_start_success: operation_id: {operation_token}") # We intentionally let this error if already done self._start_fut.set_result(operation_token) def _resolve_success(self, result: Any) -> None: - print( - f"🟢 _resolve_success: operation_id: {self.operation_token} result: {result}" - ) # We intentionally let this error if already done self._result_fut.set_result(result) def _resolve_failure(self, err: BaseException) -> None: - print(f"🔴 _resolve_failure: operation_id: {self.operation_token} err: {err}") if self._start_fut.done(): # We intentionally let this error if already done self._result_fut.set_exception(err) @@ -3080,7 +3075,6 @@ def _apply_schedule_command(self) -> None: ) if self._input.headers: for key, val in self._input.headers.items(): - print(f"🌈 adding nexus header: {key} = {val}") v.nexus_header[key] = val def _apply_cancel_command( diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 0e7af635b..7fcd7f376 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4386,15 +4386,16 @@ async def execute_child_workflow( return await handle +# TODO(nexus-prerelease): use types from nexusrpc I = TypeVar("I") O = TypeVar("O") S = TypeVar("S") -# TODO(dan): ABC? +# TODO(nexus-prerelease): ABC / inherit from asyncio.Task? class NexusOperationHandle(Generic[O]): def cancel(self) -> bool: - # TODO(dan): docstring + # TODO(nexus-prerelease): docstring """ Call task.cancel() on the asyncio task that is backing this handle. @@ -4409,7 +4410,7 @@ def cancel(self) -> bool: def __await__(self) -> Generator[Any, Any, O]: raise NotImplementedError - # TODO(dan): check SDK-wide philosophy on @property vs nullary accessor methods. + # TODO(nexus-prerelease): check SDK-wide consistency for @property vs nullary accessor methods. @property def operation_token(self) -> Optional[str]: raise NotImplementedError @@ -5172,7 +5173,7 @@ class NexusClient(Generic[S]): def __init__( self, service: Union[ - # TODO(dan): Type[S] is modeling the interface case as well the impl case, but + # TODO(nexus-prerelease): Type[S] is modeling the interface case as well the impl case, but # the typevar S is used below only in the impl case. I think this is OK, but # think about it again before deleting this TODO. Type[S], @@ -5197,8 +5198,8 @@ def __init__( ) self._endpoint = endpoint - # TODO(dan): overloads: no-input, operation name, ret type - # TODO(dan): should it be an error to use a reference to a method on a class other than that supplied? + # TODO(nexus-prerelease): overloads: no-input, ret type + # TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied? async def start_operation( self, operation: Union[ @@ -5222,7 +5223,7 @@ async def start_operation( headers=headers, ) - # TODO(dan): overloads: no-input, operation name, ret type + # TODO(nexus-prerelease): overloads: no-input, ret type async def execute_operation( self, operation: Union[ diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 8188c87d2..9fff66a3d 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -478,7 +478,6 @@ async def test_sync_response( assert isinstance(e.__cause__, NexusOperationError) assert isinstance(e.__cause__.__cause__, NexusHandlerError) # ID of first command - await print_history(caller_wf_handle) assert e.__cause__.scheduled_event_id == 5 assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) assert e.__cause__.service == "ServiceInterface" @@ -508,7 +507,6 @@ async def test_async_response( op_definition_type: OpDefinitionType, caller_reference: CallerReference, ): - print(f"🌈 {'test_async_response':<24}: {request_cancel=} {op_definition_type=}") task_queue = str(uuid.uuid4()) async with Worker( client, @@ -1036,21 +1034,6 @@ async def assert_handler_workflow_has_link_to_caller_workflow( ) -async def print_history(handle: WorkflowHandle): - print("\n\n") - history = await handle.fetch_history() - for event in history.events: - try: - event_type_name = temporalio.api.enums.v1.EventType.Name( - event.event_type - ).replace("EVENT_TYPE_", "") - except ValueError: - # Handle unknown event types - event_type_name = f"Unknown({event.event_type})" - print(f"{event.event_id}. {event_type_name}") - print("\n\n") - - # When request_cancel is True, the NexusOperationHandle in the workflow evolves # through the following states: # start_fut result_fut handle_task w/ fut_waiter (task._must_cancel) From a24be8de5e2c5a92d9dc097ff211061e0aa0adda Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 09:10:45 -0400 Subject: [PATCH 030/183] Respond to upstream: handler factory instead of sync_operation_handler --- .../nexus/handler/_operation_handlers.py | 4 +- ...ynamic_creation_of_user_handler_classes.py | 18 +- tests/nexus/test_handler.py | 308 +++++++++++------- .../test_handler_interface_implementation.py | 16 +- tests/nexus/test_workflow_caller.py | 87 +++-- 5 files changed, 270 insertions(+), 163 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 0a6c4ddeb..2ee8702c9 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -14,6 +14,8 @@ Union, ) +from typing_extensions import overload + import nexusrpc.handler from nexusrpc.handler import ( CancelOperationContext, @@ -27,8 +29,6 @@ OutputT, ServiceHandlerT, ) -from typing_extensions import overload - from temporalio.client import Client from temporalio.nexus.handler._operation_context import TemporalOperationContext diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index c9c24a8f9..dce89c534 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -4,6 +4,7 @@ import httpx import nexusrpc.handler import pytest +from nexusrpc.handler import SyncOperationHandler from temporalio.client import Client from temporalio.worker import Worker @@ -32,19 +33,20 @@ def make_incrementer_user_service_definition_and_service_handler_classes( # # service handler # - async def _increment_op( - self: Any, - ctx: nexusrpc.handler.StartOperationContext, - input: int, - ) -> int: - return input + 1 + def factory(self: Any) -> nexusrpc.handler.OperationHandler[int, int]: + async def _increment_op( + ctx: nexusrpc.handler.StartOperationContext, + input: int, + ) -> int: + return input + 1 + + return SyncOperationHandler(_increment_op) op_handler_factories = { # TODO(nexus-prerelease): check that name=name should be required here. Should the op factory # name not default to the name of the method attribute (i.e. key), as opposed to # the name of the method object (i.e. value.__name__)? - # TODO(nexus-prerelease): type error - name: nexusrpc.handler.sync_operation_handler(_increment_op, name=name) + name: nexusrpc.handler.operation_handler(name=name)(factory) for name in op_names } diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 3df09f103..c7e795910 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -27,13 +27,18 @@ import httpx import nexusrpc -import nexusrpc.handler +import nexusrpc.handler.syncio import pytest from google.protobuf import json_format from nexusrpc.handler import ( CancelOperationContext, StartOperationContext, ) +from nexusrpc.handler._common import ( + FetchOperationInfoContext, + FetchOperationResultContext, + OperationInfo, +) import temporalio.api.failure.v1 import temporalio.nexus @@ -90,7 +95,8 @@ class MyService: workflow_run_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_with_non_async_def: nexusrpc.Operation[Input, Output] - sync_operation_with_non_async_callable_instance: nexusrpc.Operation[Input, Output] + # TODO(nexus-prerelease): fix tests of callable instances + # sync_operation_with_non_async_callable_instance: nexusrpc.Operation[Input, Output] operation_returning_unwrapped_result_at_runtime_error: nexusrpc.Operation[ Input, Output ] @@ -130,75 +136,99 @@ async def run(self, input: Input) -> Output: # The service_handler decorator is applied by the test class MyServiceHandler: - @nexusrpc.handler.sync_operation_handler - async def echo(self, ctx: StartOperationContext, input: Input) -> Output: - assert ctx.headers["test-header-key"] == "test-header-value" - ctx.outbound_links.extend(ctx.inbound_links) - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) + @nexusrpc.handler.operation_handler + def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) - @nexusrpc.handler.sync_operation_handler - async def hang(self, ctx: StartOperationContext, input: Input) -> Output: - await asyncio.Future() - return Output(value="won't reach here") + return nexusrpc.handler.SyncOperationHandler(start) - @nexusrpc.handler.sync_operation_handler - async def non_retryable_application_error( - self, ctx: StartOperationContext, input: Input - ) -> Output: - raise ApplicationError( - "non-retryable application error", - "details arg", - # TODO(nexus-prerelease): what values of `type` should be tested? - type="TestFailureType", - non_retryable=True, - ) + @nexusrpc.handler.operation_handler + def hang(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + await asyncio.Future() + return Output(value="won't reach here") - @nexusrpc.handler.sync_operation_handler - async def retryable_application_error( - self, ctx: StartOperationContext, input: Input - ) -> Output: - raise ApplicationError( - "retryable application error", - "details arg", - type="TestFailureType", - non_retryable=False, - ) + return nexusrpc.handler.SyncOperationHandler(start) - @nexusrpc.handler.sync_operation_handler - async def handler_error_internal( - self, ctx: StartOperationContext, input: Input - ) -> Output: - raise nexusrpc.handler.HandlerError( - message="deliberate internal handler error", - type=nexusrpc.handler.HandlerErrorType.INTERNAL, - retryable=False, - cause=RuntimeError("cause message"), - ) + @nexusrpc.handler.operation_handler + def non_retryable_application_error( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + raise ApplicationError( + "non-retryable application error", + "details arg", + # TODO(nexus-prerelease): what values of `type` should be tested? + type="TestFailureType", + non_retryable=True, + ) - @nexusrpc.handler.sync_operation_handler - async def operation_error_failed( - self, ctx: StartOperationContext, input: Input - ) -> Output: - raise nexusrpc.handler.OperationError( - message="deliberate operation error", - state=nexusrpc.handler.OperationErrorState.FAILED, - ) + return nexusrpc.handler.SyncOperationHandler(start) - @nexusrpc.handler.sync_operation_handler - async def check_operation_timeout_header( - self, ctx: StartOperationContext, input: Input - ) -> Output: - assert "operation-timeout" in ctx.headers - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) + @nexusrpc.handler.operation_handler + def retryable_application_error( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + raise ApplicationError( + "retryable application error", + "details arg", + type="TestFailureType", + non_retryable=False, + ) + + return nexusrpc.handler.SyncOperationHandler(start) + + @nexusrpc.handler.operation_handler + def handler_error_internal( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + raise nexusrpc.handler.HandlerError( + message="deliberate internal handler error", + type=nexusrpc.handler.HandlerErrorType.INTERNAL, + retryable=False, + cause=RuntimeError("cause message"), + ) + + return nexusrpc.handler.SyncOperationHandler(start) + + @nexusrpc.handler.operation_handler + def operation_error_failed( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + raise nexusrpc.handler.OperationError( + message="deliberate operation error", + state=nexusrpc.handler.OperationErrorState.FAILED, + ) + + return nexusrpc.handler.SyncOperationHandler(start) + + @nexusrpc.handler.operation_handler + def check_operation_timeout_header( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + assert "operation-timeout" in ctx.headers + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + return nexusrpc.handler.SyncOperationHandler(start) + + @nexusrpc.handler.operation_handler + def log(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + logger.info("Logging from start method", extra={"input_value": input.value}) + return Output(value=f"logged: {input.value}") - @nexusrpc.handler.sync_operation_handler - async def log(self, ctx: StartOperationContext, input: Input) -> Output: - logger.info("Logging from start method", extra={"input_value": input.value}) - return Output(value=f"logged: {input.value}") + return nexusrpc.handler.SyncOperationHandler(start) @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation( @@ -212,40 +242,52 @@ async def workflow_run_operation( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - @nexusrpc.handler.sync_operation_handler + @nexusrpc.handler.operation_handler def sync_operation_with_non_async_def( - self, ctx: StartOperationContext, input: Input - ) -> Output: - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) - - class sync_operation_with_non_async_callable_instance: - def __call__( - self, - _handler: "MyServiceHandler", - ctx: StartOperationContext, - input: Input, - ) -> Output: + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: return Output( - value=f"from start method on {_handler.__class__.__name__}: {input.value}" + value=f"from start method on {self.__class__.__name__}: {input.value}" ) - _sync_operation_with_non_async_callable_instance = ( - nexusrpc.handler.sync_operation_handler( - name="sync_operation_with_non_async_callable_instance", - )( - sync_operation_with_non_async_callable_instance(), - ) - ) + return nexusrpc.handler.SyncOperationHandler(start) - @nexusrpc.handler.sync_operation_handler - async def sync_operation_without_type_annotations(self, ctx, input): - # The input type from the op definition in the service definition is used to deserialize the input. - return Output( - value=f"from start method on {self.__class__.__name__} without type annotations: {input}" + if False: + # TODO(nexus-prerelease): fix tests of callable instances + def sync_operation_with_non_async_callable_instance( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + class start: + def __call__( + self, + ctx: StartOperationContext, + input: Input, + ) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + return nexusrpc.handler.syncio.SyncOperationHandler(start()) + + _sync_operation_with_non_async_callable_instance = ( + nexusrpc.handler.operation_handler( + name="sync_operation_with_non_async_callable_instance", + )( + sync_operation_with_non_async_callable_instance, + ) ) + @nexusrpc.handler.operation_handler + def sync_operation_without_type_annotations(self): + async def start(ctx, input): + # The input type from the op definition in the service definition is used to deserialize the input. + return Output( + value=f"from start method on {self.__class__.__name__} without type annotations: {input}" + ) + + return nexusrpc.handler.SyncOperationHandler(start) + @temporalio.nexus.handler.workflow_run_operation_handler async def workflow_run_operation_without_type_annotations(self, ctx, input): tctx = TemporalOperationContext.current() @@ -273,7 +315,7 @@ async def workflow_run_op_link_test( ) class OperationHandlerReturningUnwrappedResult( - nexusrpc.handler.SyncOperationHandler[Input, Output] + nexusrpc.handler.OperationHandler[Input, Output] ): async def start( self, @@ -286,23 +328,44 @@ async def start( # or StartOperationResultAsync return Output(value="unwrapped result error") # type: ignore + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: + raise NotImplementedError + + async def fetch_result( + self, ctx: FetchOperationResultContext, token: str + ) -> Output: + raise NotImplementedError + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + raise NotImplementedError + @nexusrpc.handler.operation_handler def operation_returning_unwrapped_result_at_runtime_error( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: return MyServiceHandler.OperationHandlerReturningUnwrappedResult() - @nexusrpc.handler.sync_operation_handler - async def idempotency_check( - self, ctx: nexusrpc.handler.StartOperationContext, input: None - ) -> Output: - return Output(value=f"request_id: {ctx.request_id}") + @nexusrpc.handler.operation_handler + def idempotency_check( + self, + ) -> nexusrpc.handler.OperationHandler[None, Output]: + async def start(ctx: StartOperationContext, input: None) -> Output: + return Output(value=f"request_id: {ctx.request_id}") - @nexusrpc.handler.sync_operation_handler - async def non_serializable_output( - self, ctx: StartOperationContext, input: Input - ) -> NonSerializableOutput: - return NonSerializableOutput() + return nexusrpc.handler.SyncOperationHandler(start) + + @nexusrpc.handler.operation_handler + def non_serializable_output( + self, + ) -> nexusrpc.handler.OperationHandler[Input, NonSerializableOutput]: + async def start( + ctx: StartOperationContext, input: Input + ) -> NonSerializableOutput: + return NonSerializableOutput() + + return nexusrpc.handler.SyncOperationHandler(start) @dataclass @@ -496,6 +559,7 @@ class SyncHandlerHappyPathWithNonAsyncCallableInstance(_TestCase): status_code=200, body_json={"value": "from start method on MyServiceHandler: hello"}, ) + skip = "TODO(nexus-prerelease): fix tests of callable instances" class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): @@ -876,27 +940,33 @@ class EchoService: @nexusrpc.handler.service_handler(service=EchoService) class SyncStartHandler: - @nexusrpc.handler.sync_operation_handler - def echo(self, ctx: StartOperationContext, input: Input) -> Output: - assert ctx.headers["test-header-key"] == "test-header-value" - ctx.outbound_links.extend(ctx.inbound_links) - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) + @nexusrpc.handler.operation_handler + def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + def start(ctx: StartOperationContext, input: Input) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + return nexusrpc.handler.SyncOperationHandler(start) @nexusrpc.handler.service_handler(service=EchoService) class DefaultCancelHandler: - @nexusrpc.handler.sync_operation_handler - async def echo(self, ctx: StartOperationContext, input: Input) -> Output: - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) + @nexusrpc.handler.operation_handler + def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx: StartOperationContext, input: Input) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + return nexusrpc.handler.SyncOperationHandler(start) @nexusrpc.handler.service_handler(service=EchoService) class SyncCancelHandler: - class SyncCancel(nexusrpc.handler.SyncOperationHandler[Input, Output]): + class SyncCancel(nexusrpc.handler.OperationHandler[Input, Output]): async def start( self, ctx: StartOperationContext, @@ -911,6 +981,12 @@ async def start( def cancel(self, ctx: CancelOperationContext, token: str) -> Output: return Output(value="Hello") # type: ignore + def fetch_info(self, ctx: FetchOperationInfoContext) -> OperationInfo: + raise NotImplementedError + + def fetch_result(self, ctx: FetchOperationResultContext) -> Output: + raise NotImplementedError + @nexusrpc.handler.operation_handler def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: return SyncCancelHandler.SyncCancel() @@ -920,7 +996,7 @@ class SyncHandlerNoExecutor(_InstantiationCase): handler = SyncStartHandler executor = False exception = RuntimeError - match = "start must be an `async def`" + match = "is not an `async def` method" class DefaultCancel(_InstantiationCase): diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index d62e0e581..ad2af6177 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -3,10 +3,13 @@ import nexusrpc import nexusrpc.handler import pytest +from nexusrpc.handler import OperationHandler, SyncOperationHandler import temporalio.api.failure.v1 import temporalio.nexus -from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.nexus.handler import ( + WorkflowOperationToken, +) HTTP_PORT = 7243 @@ -23,10 +26,13 @@ class Interface: op: nexusrpc.Operation[None, None] class Impl: - @nexusrpc.handler.sync_operation_handler - async def op( - self, ctx: nexusrpc.handler.StartOperationContext, input: None - ) -> None: ... + @nexusrpc.handler.operation_handler + def op(self) -> OperationHandler[None, None]: + async def start( + ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> None: ... + + return SyncOperationHandler(start) error_message = None diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 9fff66a3d..9a6c30eb8 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -187,18 +187,21 @@ def sync_or_async_operation( ) -> nexusrpc.handler.OperationHandler[OpInput, OpOutput]: return SyncOrAsyncOperation() - @nexusrpc.handler.sync_operation_handler - async def sync_operation( - self, ctx: StartOperationContext, input: OpInput - ) -> OpOutput: - assert isinstance(input.response_type, SyncResponse) - if input.response_type.exception_in_operation_start: - raise RPCError( - "RPCError INVALID_ARGUMENT in Nexus operation", - RPCStatusCode.INVALID_ARGUMENT, - b"", - ) - return OpOutput(value="sync response") + @nexusrpc.handler.operation_handler + def sync_operation( + self, + ) -> nexusrpc.handler.OperationHandler[OpInput, OpOutput]: + async def start(ctx: StartOperationContext, input: OpInput) -> OpOutput: + assert isinstance(input.response_type, SyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + return OpOutput(value="sync response") + + return nexusrpc.handler.SyncOperationHandler(start) @temporalio.nexus.handler.workflow_run_operation_handler async def async_operation( @@ -744,38 +747,58 @@ class ServiceInterfaceWithNameOverride: @nexusrpc.handler.service_handler class ServiceImplInterfaceWithNeitherInterfaceNorNameOverride: - @nexusrpc.handler.sync_operation_handler - async def op( - self, ctx: nexusrpc.handler.StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) + @nexusrpc.handler.operation_handler + def op( + self, + ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + async def start( + ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + return nexusrpc.handler.SyncOperationHandler(start) @nexusrpc.handler.service_handler(service=ServiceInterfaceWithoutNameOverride) class ServiceImplInterfaceWithoutNameOverride: - @nexusrpc.handler.sync_operation_handler - async def op( - self, ctx: nexusrpc.handler.StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) + @nexusrpc.handler.operation_handler + def op( + self, + ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + async def start( + ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + return nexusrpc.handler.SyncOperationHandler(start) @nexusrpc.handler.service_handler(service=ServiceInterfaceWithNameOverride) class ServiceImplInterfaceWithNameOverride: - @nexusrpc.handler.sync_operation_handler - async def op( - self, ctx: nexusrpc.handler.StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) + @nexusrpc.handler.operation_handler + def op( + self, + ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + async def start( + ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + return nexusrpc.handler.SyncOperationHandler(start) @nexusrpc.handler.service_handler(name="service-impl-🌈") class ServiceImplWithNameOverride: - @nexusrpc.handler.sync_operation_handler - async def op( - self, ctx: nexusrpc.handler.StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) + @nexusrpc.handler.operation_handler + def op( + self, + ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + async def start( + ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + return nexusrpc.handler.SyncOperationHandler(start) class NameOverride(IntEnum): From 354962935b238c896f103187c548a30a1f7cec84 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 10:34:55 -0400 Subject: [PATCH 031/183] Switch workflow_run_operation_handler to standard factory --- temporalio/nexus/handler/__init__.py | 7 +- .../nexus/handler/_operation_handlers.py | 139 ++-------------- temporalio/nexus/handler/_util.py | 32 ++++ tests/nexus/test_handler.py | 149 ++++++++++-------- .../test_handler_interface_implementation.py | 11 +- .../test_handler_operation_definitions.py | 47 ++++-- tests/nexus/test_workflow_caller.py | 80 ++++++---- 7 files changed, 215 insertions(+), 250 deletions(-) create mode 100644 temporalio/nexus/handler/_util.py diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index cc400da49..4afdc2c15 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -24,12 +24,7 @@ WorkflowRunOperationHandler as WorkflowRunOperationHandler, ) from ._operation_handlers import cancel_operation as cancel_operation -from ._operation_handlers import ( - workflow_run_operation_handler as workflow_run_operation_handler, -) -from ._token import ( - WorkflowOperationToken as WorkflowOperationToken, -) +from ._token import WorkflowOperationToken as WorkflowOperationToken if TYPE_CHECKING: from temporalio.client import ( diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 2ee8702c9..778a8b9f5 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -1,9 +1,7 @@ from __future__ import annotations -import types import typing import warnings -from functools import wraps from typing import ( Any, Awaitable, @@ -14,8 +12,6 @@ Union, ) -from typing_extensions import overload - import nexusrpc.handler from nexusrpc.handler import ( CancelOperationContext, @@ -29,12 +25,14 @@ OutputT, ServiceHandlerT, ) + from temporalio.client import Client from temporalio.nexus.handler._operation_context import TemporalOperationContext from ._token import ( WorkflowOperationToken as WorkflowOperationToken, ) +from ._util import is_async_callable class WorkflowRunOperationHandler( @@ -43,30 +41,26 @@ class WorkflowRunOperationHandler( ): def __init__( self, - service: ServiceHandlerT, - start_method: Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + start: Callable[ + [StartOperationContext, InputT], Awaitable[WorkflowOperationToken[OutputT]], ], ): - self.service = service - - @wraps(start_method) - async def start( - _, ctx: StartOperationContext, input: InputT - ) -> StartOperationResultAsync: - token = await start_method(service, ctx, input) - return StartOperationResultAsync(token.encode()) - - self.start = types.MethodType(start, self) + if not is_async_callable(start): + raise RuntimeError( + f"{start} is not an `async def` method. " + "WorkflowRunOperationHandler must be initialized with an " + "`async def` start method." + ) + self._start = start + if start.__doc__: + self.start.__func__.__doc__ = start.__doc__ async def start( self, ctx: StartOperationContext, input: InputT ) -> nexusrpc.handler.StartOperationResultAsync: - raise NotImplementedError( - "The start method of a WorkflowRunOperation should be set " - "dynamically in the __init__ method. (Did you forget to call super()?)" - ) + token = await self._start(ctx, input) + return StartOperationResultAsync(token.encode()) async def cancel(self, ctx: CancelOperationContext, token: str) -> None: tctx = TemporalOperationContext.current() @@ -89,109 +83,6 @@ def fetch_result( ) -@overload -def workflow_run_operation_handler( - start_method: Callable[ - [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], - ], -) -> Callable[ - [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] -]: ... - - -@overload -def workflow_run_operation_handler( - *, - name: Optional[str] = None, -) -> Callable[ - [ - Callable[ - [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], - ] - ], - Callable[ - [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] - ], -]: ... - - -def workflow_run_operation_handler( - start_method: Optional[ - Callable[ - [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], - ] - ] = None, - *, - name: Optional[str] = None, -) -> Union[ - Callable[ - [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] - ], - Callable[ - [ - Callable[ - [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], - ] - ], - Callable[ - [ServiceHandlerT], - WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT], - ], - ], -]: - def decorator( - start_method: Callable[ - [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], - ], - ) -> Callable[ - [ServiceHandlerT], WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT] - ]: - def factory( - service: ServiceHandlerT, - ) -> WorkflowRunOperationHandler[InputT, OutputT, ServiceHandlerT]: - # TODO(nexus-prerelease) I was passing output_type here; why? - return WorkflowRunOperationHandler(service, start_method) - - # TODO(nexus-prerelease): handle callable instances: __class__.__name__ as in sync_operation_handler - method_name = getattr(start_method, "__name__", None) - if not method_name and callable(start_method): - method_name = start_method.__class__.__name__ - if not method_name: - raise TypeError( - f"Could not determine operation method name: " - f"expected {start_method} to be a function or callable instance." - ) - - input_type, output_type = ( - _get_workflow_run_start_method_input_and_output_type_annotations( - start_method - ) - ) - - setattr( - factory, - "__nexus_operation__", - nexusrpc.Operation( - name=name or method_name, - method_name=method_name, - input_type=input_type, - output_type=output_type, - ), - ) - - return factory - - if start_method is None: - return decorator - - return decorator(start_method) - - def _get_workflow_run_start_method_input_and_output_type_annotations( start_method: Callable[ [ServiceHandlerT, StartOperationContext, InputT], diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/handler/_util.py new file mode 100644 index 000000000..93c8613aa --- /dev/null +++ b/temporalio/nexus/handler/_util.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import functools +import inspect +from typing import ( + Any, + Awaitable, + Callable, +) + +from typing_extensions import TypeGuard + + +# Copied from https://github.com/modelcontextprotocol/python-sdk +# +# Copyright (c) 2024 Anthropic, PBC. +# +# Modified to use TypeGuard. +# +# This file is licensed under the MIT License. +def is_async_callable(obj: Any) -> TypeGuard[Callable[..., Awaitable[Any]]]: + """ + Return True if `obj` is an async callable. + + Supports partials of async callable class instances. + """ + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index c7e795910..0df1d4811 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -230,17 +230,22 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: return nexusrpc.handler.SyncOperationHandler(start) - @temporalio.nexus.handler.workflow_run_operation_handler - async def workflow_run_operation( - self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.current() - return await tctx.start_workflow( - MyWorkflow.run, - input, - id=str(uuid.uuid4()), - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, - ) + @nexusrpc.handler.operation_handler + def workflow_run_operation( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start( + ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: + tctx = TemporalOperationContext.current() + return await tctx.start_workflow( + MyWorkflow.run, + input, + id=str(uuid.uuid4()), + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) @nexusrpc.handler.operation_handler def sync_operation_with_non_async_def( @@ -288,31 +293,39 @@ async def start(ctx, input): return nexusrpc.handler.SyncOperationHandler(start) - @temporalio.nexus.handler.workflow_run_operation_handler - async def workflow_run_operation_without_type_annotations(self, ctx, input): - tctx = TemporalOperationContext.current() - return await tctx.start_workflow( - WorkflowWithoutTypeAnnotations.run, - input, - id=str(uuid.uuid4()), - ) + @nexusrpc.handler.operation_handler + def workflow_run_operation_without_type_annotations( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx, input): + tctx = TemporalOperationContext.current() + return await tctx.start_workflow( + WorkflowWithoutTypeAnnotations.run, + input, + id=str(uuid.uuid4()), + ) - @temporalio.nexus.handler.workflow_run_operation_handler - async def workflow_run_op_link_test( - self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - assert any( - link.url == "http://inbound-link/" for link in ctx.inbound_links - ), "Inbound link not found" - assert ctx.request_id == "test-request-id-123", "Request ID mismatch" - ctx.outbound_links.extend(ctx.inbound_links) - - tctx = TemporalOperationContext.current() - return await tctx.start_workflow( - MyLinkTestWorkflow.run, - input, - id=str(uuid.uuid4()), - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + + @nexusrpc.handler.operation_handler + def workflow_run_op_link_test( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx, input): + assert any( + link.url == "http://inbound-link/" for link in ctx.inbound_links + ), "Inbound link not found" + assert ctx.request_id == "test-request-id-123", "Request ID mismatch" + ctx.outbound_links.extend(ctx.inbound_links) + + tctx = TemporalOperationContext.current() + return await tctx.start_workflow( + MyLinkTestWorkflow.run, + input, + id=str(uuid.uuid4()), + ) + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) class OperationHandlerReturningUnwrappedResult( nexusrpc.handler.OperationHandler[Input, Output] @@ -1105,37 +1118,43 @@ async def run(self, input: Input) -> Output: @nexusrpc.handler.service_handler class ServiceHandlerForRequestIdTest: - @temporalio.nexus.handler.workflow_run_operation_handler - async def operation_backed_by_a_workflow( - self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.current() - return await tctx.start_workflow( - EchoWorkflow.run, - input, - id=input.value, - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, - ) + @nexusrpc.handler.operation_handler + def operation_backed_by_a_workflow( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx, input) -> WorkflowOperationToken[Output]: + tctx = TemporalOperationContext.current() + return await tctx.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) - @temporalio.nexus.handler.workflow_run_operation_handler - async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( - self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.current() - await tctx.client.start_workflow( - EchoWorkflow.run, - input, - id=input.value, - task_queue=tctx.task_queue, - ) - # This should fail. It will not fail if the Nexus request ID was incorrectly - # propagated to both StartWorkflow requests. - return await tctx.start_workflow( - EchoWorkflow.run, - input, - id=input.value, - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + + @nexusrpc.handler.operation_handler + def operation_that_executes_a_workflow_before_starting_the_backing_workflow( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start(ctx, input) -> WorkflowOperationToken[Output]: + tctx = TemporalOperationContext.current() + await tctx.client.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + task_queue=tctx.task_queue, + ) + # This should fail. It will not fail if the Nexus request ID was incorrectly + # propagated to both StartWorkflow requests. + return await tctx.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index ad2af6177..2e2872d45 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -43,10 +43,13 @@ class Interface: op: nexusrpc.Operation[str, int] class Impl: - @temporalio.nexus.handler.workflow_run_operation_handler - async def op( - self, ctx: nexusrpc.handler.StartOperationContext, input: str - ) -> WorkflowOperationToken[int]: ... + @nexusrpc.handler.operation_handler + def op(self) -> nexusrpc.handler.OperationHandler[str, int]: + async def start( + ctx: nexusrpc.handler.StartOperationContext, input: str + ) -> WorkflowOperationToken[int]: ... + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 85c10a68c..fa65ccb49 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -1,5 +1,5 @@ """ -Test that workflow_run_operation_handler decorator results in operation definitions with the correct name +Test that operation_handler decorator results in operation definitions with the correct name and input/output types. """ @@ -32,15 +32,20 @@ class _TestCase: class NotCalled(_TestCase): @nexusrpc.handler.service_handler class Service: - @temporalio.nexus.handler.workflow_run_operation_handler - async def workflow_run_operation_handler( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... + @nexusrpc.handler.operation_handler + def my_workflow_run_operation_handler( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start( + ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: ... + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) expected_operations = { - "workflow_run_operation_handler": nexusrpc.Operation( - name="workflow_run_operation_handler", - method_name="workflow_run_operation_handler", + "my_workflow_run_operation_handler": nexusrpc.Operation( + name="my_workflow_run_operation_handler", + method_name="my_workflow_run_operation_handler", input_type=Input, output_type=Output, ), @@ -50,10 +55,15 @@ async def workflow_run_operation_handler( class CalledWithoutArgs(_TestCase): @nexusrpc.handler.service_handler class Service: - @temporalio.nexus.handler.workflow_run_operation_handler() - async def workflow_run_operation_handler( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... + @nexusrpc.handler.operation_handler() + def my_workflow_run_operation_handler( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start( + ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: ... + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) expected_operations = NotCalled.expected_operations @@ -61,10 +71,15 @@ async def workflow_run_operation_handler( class CalledWithNameOverride(_TestCase): @nexusrpc.handler.service_handler class Service: - @temporalio.nexus.handler.workflow_run_operation_handler(name="operation-name") - async def workflow_run_operation_with_name_override( - self, ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... + @nexusrpc.handler.operation_handler(name="operation-name") + def workflow_run_operation_with_name_override( + self, + ) -> nexusrpc.handler.OperationHandler[Input, Output]: + async def start( + ctx: nexusrpc.handler.StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: ... + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 9a6c30eb8..3f999afce 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -203,23 +203,28 @@ async def start(ctx: StartOperationContext, input: OpInput) -> OpOutput: return nexusrpc.handler.SyncOperationHandler(start) - @temporalio.nexus.handler.workflow_run_operation_handler - async def async_operation( - self, ctx: StartOperationContext, input: OpInput - ) -> WorkflowOperationToken[HandlerWfOutput]: - assert isinstance(input.response_type, AsyncResponse) - if input.response_type.exception_in_operation_start: - raise RPCError( - "RPCError INVALID_ARGUMENT in Nexus operation", - RPCStatusCode.INVALID_ARGUMENT, - b"", + @nexusrpc.handler.operation_handler + def async_operation( + self, + ) -> nexusrpc.handler.OperationHandler[OpInput, HandlerWfOutput]: + async def start( + ctx: StartOperationContext, input: OpInput + ) -> WorkflowOperationToken[HandlerWfOutput]: + assert isinstance(input.response_type, AsyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + tctx = TemporalOperationContext.current() + return await tctx.start_workflow( + HandlerWorkflow.run, + HandlerWfInput(op_input=input), + id=input.response_type.operation_workflow_id, ) - tctx = TemporalOperationContext.current() - return await tctx.start_workflow( - HandlerWorkflow.run, - HandlerWfInput(op_input=input), - id=input.response_type.operation_workflow_id, - ) + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) # ----------------------------------------------------------------------------- @@ -935,25 +940,30 @@ async def run(self, input: str) -> str: @nexusrpc.handler.service_handler class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: - @temporalio.nexus.handler.workflow_run_operation_handler - async def my_workflow_run_operation( - self, ctx: StartOperationContext, input: None - ) -> WorkflowOperationToken[str]: - tctx = TemporalOperationContext.current() - result_1 = await tctx.client.execute_workflow( - EchoWorkflow.run, - "result-1", - id=str(uuid.uuid4()), - task_queue=tctx.task_queue, - ) - # In case result_1 is incorrectly being delivered to the caller as the operation - # result, give time for that incorrect behavior to occur. - await asyncio.sleep(0.5) - return await tctx.start_workflow( - EchoWorkflow.run, - f"{result_1}-result-2", - id=str(uuid.uuid4()), - ) + @nexusrpc.handler.operation_handler + def my_workflow_run_operation( + self, + ) -> nexusrpc.handler.OperationHandler[None, str]: + async def start( + ctx: StartOperationContext, input: None + ) -> WorkflowOperationToken[str]: + tctx = TemporalOperationContext.current() + result_1 = await tctx.client.execute_workflow( + EchoWorkflow.run, + "result-1", + id=str(uuid.uuid4()), + task_queue=tctx.task_queue, + ) + # In case result_1 is incorrectly being delivered to the caller as the operation + # result, give time for that incorrect behavior to occur. + await asyncio.sleep(0.5) + return await tctx.start_workflow( + EchoWorkflow.run, + f"{result_1}-result-2", + id=str(uuid.uuid4()), + ) + + return temporalio.nexus.handler.WorkflowRunOperationHandler(start) @workflow.defn From fdb3e3701d5c8b5b1c125a83a8063be8c4483f4a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 12:23:14 -0400 Subject: [PATCH 032/183] Do not support passing client to cancel_operation --- temporalio/nexus/handler/_operation_handlers.py | 9 ++++----- tests/nexus/test_workflow_caller.py | 3 +-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 778a8b9f5..264bbacee 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -26,7 +26,6 @@ ServiceHandlerT, ) -from temporalio.client import Client from temporalio.nexus.handler._operation_context import TemporalOperationContext from ._token import ( @@ -63,8 +62,7 @@ async def start( return StartOperationResultAsync(token.encode()) async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - tctx = TemporalOperationContext.current() - await cancel_operation(token, tctx.client) + await cancel_operation(token) def fetch_info( self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str @@ -124,7 +122,6 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( async def cancel_operation( token: str, - client: Client, **kwargs: Any, ) -> None: """Cancel a Nexus operation. @@ -142,8 +139,10 @@ async def cancel_operation( type=HandlerErrorType.NOT_FOUND, cause=err, ) + + tctx = TemporalOperationContext.current() try: - handle = workflow_token.to_workflow_handle(client) + handle = workflow_token.to_workflow_handle(tctx.client) except Exception as err: raise HandlerError( "Failed to construct workflow handle from workflow operation token", diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3f999afce..a588ef38d 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -165,8 +165,7 @@ async def start( raise TypeError async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - tctx = TemporalOperationContext.current() - return await temporalio.nexus.handler.cancel_operation(token, tctx.client) + return await temporalio.nexus.handler.cancel_operation(token) async def fetch_info( self, ctx: FetchOperationInfoContext, token: str From c79bde5bed1238988c8c3a519a617c74c0cc7b4a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 14:25:06 -0400 Subject: [PATCH 033/183] RTU: bridge Rust --- temporalio/bridge/src/worker.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 930acedd3..130389259 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -566,7 +566,7 @@ impl WorkerRef { }) } - fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult> { let worker = self.worker.as_ref().unwrap().clone(); self.runtime.future_into_py(py, async move { let bytes = match worker.poll_nexus_task().await { @@ -574,8 +574,7 @@ impl WorkerRef { Err(PollError::ShutDown) => return Err(PollShutdownError::new_err(())), Err(err) => return Err(PyRuntimeError::new_err(format!("Poll failure: {}", err))), }; - let bytes: &[u8] = &bytes; - Ok(Python::with_gil(|py| bytes.into_py(py))) + Ok(bytes) }) } @@ -613,7 +612,10 @@ impl WorkerRef { }) } - fn complete_nexus_task<'p>(&self, py: Python<'p>, proto: &PyBytes) -> PyResult<&'p PyAny> { + fn complete_nexus_task<'p>(&self, + py: Python<'p>, + proto: &Bound<'_, PyBytes>, +) -> PyResult> { let worker = self.worker.as_ref().unwrap().clone(); let completion = NexusTaskCompletion::decode(proto.as_bytes()) .map_err(|err| PyValueError::new_err(format!("Invalid proto: {}", err)))?; From b91207ea0c0298b8b73de0f41e0ff1433bac0ac9 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 20:22:45 -0400 Subject: [PATCH 034/183] Fix: make all methods `async def` on WorkflowRunOperationHandler --- temporalio/nexus/handler/_operation_handlers.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 264bbacee..9ca275bd8 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -9,7 +9,6 @@ Generic, Optional, Type, - Union, ) import nexusrpc.handler @@ -64,18 +63,16 @@ async def start( async def cancel(self, ctx: CancelOperationContext, token: str) -> None: await cancel_operation(token) - def fetch_info( + async def fetch_info( self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str - ) -> Union[ - nexusrpc.handler.OperationInfo, Awaitable[nexusrpc.handler.OperationInfo] - ]: + ) -> nexusrpc.handler.OperationInfo: raise NotImplementedError( "Temporal Nexus operation handlers do not support fetching operation info." ) - def fetch_result( + async def fetch_result( self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str - ) -> Union[OutputT, Awaitable[OutputT]]: + ) -> OutputT: raise NotImplementedError( "Temporal Nexus operation handlers do not support fetching operation results." ) From 60fcee3101b38a7a80093868dceb92e1a7cc4228 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 20:29:42 -0400 Subject: [PATCH 035/183] Get rid of TypeGuard It was destroying type information in the one place it was used --- temporalio/nexus/handler/_util.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/handler/_util.py index 93c8613aa..1bd7f09fc 100644 --- a/temporalio/nexus/handler/_util.py +++ b/temporalio/nexus/handler/_util.py @@ -4,21 +4,15 @@ import inspect from typing import ( Any, - Awaitable, - Callable, ) -from typing_extensions import TypeGuard - # Copied from https://github.com/modelcontextprotocol/python-sdk # # Copyright (c) 2024 Anthropic, PBC. # -# Modified to use TypeGuard. -# # This file is licensed under the MIT License. -def is_async_callable(obj: Any) -> TypeGuard[Callable[..., Awaitable[Any]]]: +def is_async_callable(obj: Any) -> bool: """ Return True if `obj` is an async callable. From 3715460454e3fd9f4355792335d378d953f25e14 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 21:18:56 -0400 Subject: [PATCH 036/183] Support passing result_type when getting workflow handle from token --- temporalio/nexus/handler/_token.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/temporalio/nexus/handler/_token.py b/temporalio/nexus/handler/_token.py index ecb5d06cf..08e9074a0 100644 --- a/temporalio/nexus/handler/_token.py +++ b/temporalio/nexus/handler/_token.py @@ -3,7 +3,7 @@ import base64 import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal, Optional +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Type from nexusrpc.types import OutputT @@ -25,14 +25,16 @@ class WorkflowOperationToken(Generic[OutputT]): # serialized token; it's only used to reject newer token versions on load. version: Optional[int] = None - def to_workflow_handle(self, client: Client) -> WorkflowHandle[Any, OutputT]: + def to_workflow_handle( + self, client: Client, result_type: Optional[Type[OutputT]] = None + ) -> WorkflowHandle[Any, OutputT]: """Create a :py:class:`temporalio.client.WorkflowHandle` from the token.""" if client.namespace != self.namespace: raise ValueError( f"Client namespace {client.namespace} does not match " f"operation token namespace {self.namespace}" ) - return client.get_workflow_handle(self.workflow_id) + return client.get_workflow_handle(self.workflow_id, result_type=result_type) # TODO(nexus-preview): The return type here should be dictated by the input workflow # handle type. From 2b5debc82b9f00c0984390d89c7c6b97b12dcae1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 21:23:24 -0400 Subject: [PATCH 037/183] Implement fetch_result handler --- .../nexus/handler/_operation_handlers.py | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 9ca275bd8..0ecfc36ba 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -53,6 +53,9 @@ def __init__( self._start = start if start.__doc__: self.start.__func__.__doc__ = start.__doc__ + self._input_type, self._output_type = ( + _get_workflow_run_start_method_input_and_output_type_annotations(start) + ) async def start( self, ctx: StartOperationContext, input: InputT @@ -74,13 +77,35 @@ async def fetch_result( self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str ) -> OutputT: raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching operation results." + "Temporal Nexus operation handlers do not support fetching operation result." ) + # An implementation is provided for future reference: + try: + workflow_token = WorkflowOperationToken[OutputT].decode(token) + except Exception as err: + raise HandlerError( + "Failed to decode operation token as workflow operation token. " + "Fetching result for non-workflow operations is not supported.", + type=HandlerErrorType.NOT_FOUND, + cause=err, + ) + tctx = TemporalOperationContext.current() + try: + handle = workflow_token.to_workflow_handle( + tctx.client, result_type=self._output_type + ) + except Exception as err: + raise HandlerError( + "Failed to construct workflow handle from workflow operation token", + type=HandlerErrorType.NOT_FOUND, + cause=err, + ) + return await handle.result() def _get_workflow_run_start_method_input_and_output_type_annotations( start_method: Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [StartOperationContext, InputT], Awaitable[WorkflowOperationToken[OutputT]], ], ) -> tuple[ @@ -93,7 +118,7 @@ def _get_workflow_run_start_method_input_and_output_type_annotations( :py:class:`WorkflowHandle`. """ input_type, output_type = ( - nexusrpc.handler.get_start_method_input_and_output_types_annotations( + nexusrpc.handler.get_start_method_input_and_output_type_annotations( start_method ) ) From b3ddaf9204c0cc662ca2363819bba731fda936d1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Jun 2025 21:39:27 -0400 Subject: [PATCH 038/183] Cleanup --- .../nexus/handler/_operation_handlers.py | 54 +------ temporalio/nexus/handler/_util.py | 99 ++++++++++++ .../nexus/test_get_input_and_output_types.py | 153 ++++++++++++++++++ 3 files changed, 258 insertions(+), 48 deletions(-) create mode 100644 tests/nexus/test_get_input_and_output_types.py diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 0ecfc36ba..e4d4a9d8f 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -1,14 +1,10 @@ from __future__ import annotations -import typing -import warnings from typing import ( Any, Awaitable, Callable, Generic, - Optional, - Type, ) import nexusrpc.handler @@ -26,11 +22,12 @@ ) from temporalio.nexus.handler._operation_context import TemporalOperationContext +from temporalio.nexus.handler._token import WorkflowOperationToken -from ._token import ( - WorkflowOperationToken as WorkflowOperationToken, +from ._util import ( + get_workflow_run_start_method_input_and_output_type_annotations, + is_async_callable, ) -from ._util import is_async_callable class WorkflowRunOperationHandler( @@ -54,7 +51,7 @@ def __init__( if start.__doc__: self.start.__func__.__doc__ = start.__doc__ self._input_type, self._output_type = ( - _get_workflow_run_start_method_input_and_output_type_annotations(start) + get_workflow_run_start_method_input_and_output_type_annotations(start) ) async def start( @@ -77,7 +74,7 @@ async def fetch_result( self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str ) -> OutputT: raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching operation result." + "Temporal Nexus operation handlers do not support fetching the operation result." ) # An implementation is provided for future reference: try: @@ -103,45 +100,6 @@ async def fetch_result( return await handle.result() -def _get_workflow_run_start_method_input_and_output_type_annotations( - start_method: Callable[ - [StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], - ], -) -> tuple[ - Optional[Type[InputT]], - Optional[Type[OutputT]], -]: - """Return operation input and output types. - - `start_method` must be a type-annotated start method that returns a - :py:class:`WorkflowHandle`. - """ - input_type, output_type = ( - nexusrpc.handler.get_start_method_input_and_output_type_annotations( - start_method - ) - ) - origin_type = typing.get_origin(output_type) - if not origin_type or not issubclass(origin_type, WorkflowOperationToken): - warnings.warn( - f"Expected return type of {start_method.__name__} to be a subclass of WorkflowOperationToken, " - f"but is {output_type}" - ) - output_type = None - - args = typing.get_args(output_type) - if len(args) != 1: - warnings.warn( - f"Expected return type of {start_method.__name__} to have exactly one type parameter, " - f"but has {len(args)}: {args}" - ) - output_type = None - else: - [output_type] = args - return input_type, output_type - - async def cancel_operation( token: str, **kwargs: Any, diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/handler/_util.py index 1bd7f09fc..ddee73ffa 100644 --- a/temporalio/nexus/handler/_util.py +++ b/temporalio/nexus/handler/_util.py @@ -2,10 +2,109 @@ import functools import inspect +import typing +import warnings from typing import ( Any, + Awaitable, + Callable, + Optional, + Type, + Union, ) +from nexusrpc.handler import ( + StartOperationContext, +) +from nexusrpc.types import ( + InputT, + OutputT, +) + +from ._token import ( + WorkflowOperationToken as WorkflowOperationToken, +) + + +def get_workflow_run_start_method_input_and_output_type_annotations( + start: Callable[ + [StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ], +) -> tuple[ + Optional[Type[InputT]], + Optional[Type[OutputT]], +]: + """Return operation input and output types. + + `start` must be a type-annotated start method that returns a + :py:class:`WorkflowHandle`. + """ + input_type, output_type = _get_start_method_input_and_output_type_annotations(start) + origin_type = typing.get_origin(output_type) + if not origin_type or not issubclass(origin_type, WorkflowOperationToken): + warnings.warn( + f"Expected return type of {start.__name__} to be a subclass of WorkflowOperationToken, " + f"but is {output_type}" + ) + output_type = None + + args = typing.get_args(output_type) + if len(args) != 1: + warnings.warn( + f"Expected return type of {start.__name__} to have exactly one type parameter, " + f"but has {len(args)}: {args}" + ) + output_type = None + else: + [output_type] = args + return input_type, output_type + + +def _get_start_method_input_and_output_type_annotations( + start: Callable[ + [StartOperationContext, InputT], + Union[OutputT, Awaitable[OutputT]], + ], +) -> tuple[ + Optional[Type[InputT]], + Optional[Type[OutputT]], +]: + """Return operation input and output types. + + `start` must be a type-annotated start method that returns a synchronous result. + """ + try: + type_annotations = typing.get_type_hints(start) + except TypeError: + # TODO(nexus-preview): stacklevel + warnings.warn( + f"Expected decorated start method {start} to have type annotations" + ) + return None, None + output_type = type_annotations.pop("return", None) + + if len(type_annotations) != 2: + # TODO(nexus-preview): stacklevel + suffix = f": {type_annotations}" if type_annotations else "" + warnings.warn( + f"Expected decorated start method {start} to have exactly 2 " + f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}" + f"{suffix}." + ) + input_type = None + else: + ctx_type, input_type = type_annotations.values() + if not issubclass(ctx_type, StartOperationContext): + # TODO(nexus-preview): stacklevel + warnings.warn( + f"Expected first parameter of {start} to be an instance of " + f"StartOperationContext, but is {ctx_type}." + ) + input_type = None + + return input_type, output_type + # Copied from https://github.com/modelcontextprotocol/python-sdk # diff --git a/tests/nexus/test_get_input_and_output_types.py b/tests/nexus/test_get_input_and_output_types.py new file mode 100644 index 000000000..fcfa0fa8b --- /dev/null +++ b/tests/nexus/test_get_input_and_output_types.py @@ -0,0 +1,153 @@ +import warnings +from typing import ( + Any, + Awaitable, + Type, + Union, + get_args, + get_origin, +) + +import pytest +from nexusrpc.handler import ( + StartOperationContext, +) + +from temporalio.nexus.handler._util import ( + _get_start_method_input_and_output_type_annotations, +) + + +class Input: + pass + + +class Output: + pass + + +class _TestCase: + @staticmethod + def start(ctx: StartOperationContext, i: Input) -> Output: ... + + expected_types: tuple[Any, Any] + + +class SyncMethod(_TestCase): + @staticmethod + def start(ctx: StartOperationContext, i: Input) -> Output: ... + + expected_types = (Input, Output) + + +class AsyncMethod(_TestCase): + @staticmethod + async def start(ctx: StartOperationContext, i: Input) -> Output: ... + + expected_types = (Input, Output) + + +class UnionMethod(_TestCase): + @staticmethod + def start( + ctx: StartOperationContext, i: Input + ) -> Union[Output, Awaitable[Output]]: ... + + expected_types = (Input, Union[Output, Awaitable[Output]]) + + +class MissingInputAnnotationInUnionMethod(_TestCase): + @staticmethod + def start(ctx: StartOperationContext, i) -> Union[Output, Awaitable[Output]]: ... + + expected_types = (None, Union[Output, Awaitable[Output]]) + + +class TooFewParams(_TestCase): + @staticmethod + def start(i: Input) -> Output: ... + + expected_types = (None, Output) + + +class TooManyParams(_TestCase): + @staticmethod + def start(ctx: StartOperationContext, i: Input, extra: int) -> Output: ... + + expected_types = (None, Output) + + +class WrongOptionsType(_TestCase): + @staticmethod + def start(ctx: int, i: Input) -> Output: ... + + expected_types = (None, Output) + + +class NoReturnHint(_TestCase): + @staticmethod + def start(ctx: StartOperationContext, i: Input): ... + + expected_types = (Input, None) + + +class NoInputAnnotation(_TestCase): + @staticmethod + def start(ctx: StartOperationContext, i) -> Output: ... + + expected_types = (None, Output) + + +class NoOptionsAnnotation(_TestCase): + @staticmethod + def start(ctx, i: Input) -> Output: ... + + expected_types = (None, Output) + + +class AllAnnotationsMissing(_TestCase): + @staticmethod + def start(ctx: StartOperationContext, i): ... + + expected_types = (None, None) + + +class ExplicitNoneTypes(_TestCase): + @staticmethod + def start(ctx: StartOperationContext, i: None) -> None: ... + + expected_types = (type(None), type(None)) + + +@pytest.mark.parametrize( + "test_case", + [ + SyncMethod, + AsyncMethod, + UnionMethod, + TooFewParams, + TooManyParams, + WrongOptionsType, + NoReturnHint, + NoInputAnnotation, + NoOptionsAnnotation, + MissingInputAnnotationInUnionMethod, + AllAnnotationsMissing, + ExplicitNoneTypes, + ], +) +def test_get_input_and_output_types(test_case: Type[_TestCase]): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + input_type, output_type = _get_start_method_input_and_output_type_annotations( + test_case.start + ) + expected_input_type, expected_output_type = test_case.expected_types + assert input_type is expected_input_type + + expected_origin = get_origin(expected_output_type) + if expected_origin: # Awaitable and Union cases + assert get_origin(output_type) is expected_origin + assert get_args(output_type) == get_args(expected_output_type) + else: + assert output_type is expected_output_type From e04218f93be98901d7daab379c5ce618865d0b70 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 08:11:47 -0400 Subject: [PATCH 039/183] Tests: clean up type annotation warnings --- tests/nexus/test_handler.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 0df1d4811..744727a1c 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -311,7 +311,9 @@ async def start(ctx, input): def workflow_run_op_link_test( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: - async def start(ctx, input): + async def start( + ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: assert any( link.url == "http://inbound-link/" for link in ctx.inbound_links ), "Inbound link not found" @@ -1122,7 +1124,9 @@ class ServiceHandlerForRequestIdTest: def operation_backed_by_a_workflow( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: - async def start(ctx, input) -> WorkflowOperationToken[Output]: + async def start( + ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: tctx = TemporalOperationContext.current() return await tctx.start_workflow( EchoWorkflow.run, @@ -1137,7 +1141,9 @@ async def start(ctx, input) -> WorkflowOperationToken[Output]: def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: - async def start(ctx, input) -> WorkflowOperationToken[Output]: + async def start( + ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: tctx = TemporalOperationContext.current() await tctx.client.start_workflow( EchoWorkflow.run, From 9e99ca78faee6de6912fd00d4a7328f0b380b52b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 08:12:23 -0400 Subject: [PATCH 040/183] Improve type annotation warnings --- temporalio/nexus/handler/_util.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/handler/_util.py index ddee73ffa..03057e147 100644 --- a/temporalio/nexus/handler/_util.py +++ b/temporalio/nexus/handler/_util.py @@ -42,22 +42,26 @@ def get_workflow_run_start_method_input_and_output_type_annotations( """ input_type, output_type = _get_start_method_input_and_output_type_annotations(start) origin_type = typing.get_origin(output_type) - if not origin_type or not issubclass(origin_type, WorkflowOperationToken): + if not origin_type: + output_type = None + elif not issubclass(origin_type, WorkflowOperationToken): warnings.warn( f"Expected return type of {start.__name__} to be a subclass of WorkflowOperationToken, " f"but is {output_type}" ) output_type = None - args = typing.get_args(output_type) - if len(args) != 1: - warnings.warn( - f"Expected return type of {start.__name__} to have exactly one type parameter, " - f"but has {len(args)}: {args}" - ) - output_type = None - else: - [output_type] = args + if output_type: + args = typing.get_args(output_type) + if len(args) != 1: + suffix = f": {args}" if args else "" + warnings.warn( + f"Expected return type {output_type} of {start.__name__} to have exactly one type parameter, " + f"but has {len(args)}{suffix}." + ) + output_type = None + else: + [output_type] = args return input_type, output_type From 8046b97c2ea4ec54fce8c916eaf08c0604027a5c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 08:20:45 -0400 Subject: [PATCH 041/183] Cleanup --- temporalio/client.py | 22 +++++++++++++--------- temporalio/nexus/__init__.py | 1 - 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 2a65f88e3..6a67e328e 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -5818,17 +5818,22 @@ async def _build_start_workflow_execution_request( self, input: StartWorkflowInput ) -> temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest: req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest() + await self._populate_start_workflow_execution_request(req, input) + # _populate_start_workflow_execution_request is used for both StartWorkflowInput + # and UpdateWithStartStartWorkflowInput. UpdateWithStartStartWorkflowInput does + # not have the following two fields so they are handled here. req.request_eager_execution = input.request_eager_start if input.request_id: req.request_id = input.request_id - await self._populate_start_workflow_execution_request(req, input) - for callback in input.nexus_completion_callbacks: - c = temporalio.api.common.v1.Callback() - c.nexus.url = callback.url - c.nexus.header.update(callback.header) - req.completion_callbacks.append(c) - + req.completion_callbacks.extend( + temporalio.api.common.v1.Callback( + nexus=temporalio.api.common.v1.Callback.Nexus( + url=callback.url, header=callback.header + ) + ) + for callback in input.nexus_completion_callbacks + ) req.links.extend( temporalio.api.common.v1.Link(workflow_event=link) for link in input.workflow_event_links @@ -5879,8 +5884,7 @@ async def _populate_start_workflow_execution_request( if input.task_timeout is not None: req.workflow_task_timeout.FromTimedelta(input.task_timeout) req.identity = self._client.identity - if not req.request_id: - req.request_id = str(uuid.uuid4()) + req.request_id = str(uuid.uuid4()) req.workflow_id_reuse_policy = cast( "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", int(input.id_reuse_policy), diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 571965eb9..e69de29bb 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -1 +0,0 @@ -from . import handler as handler From 4ceba6d0829fa31afc37107f622a3e91fbd5626b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 08:43:06 -0400 Subject: [PATCH 042/183] Import nexus.handler.logger in worker --- temporalio/worker/_nexus.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index cd02d6520..99b273803 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -5,7 +5,6 @@ import asyncio import concurrent.futures import json -import logging from dataclasses import dataclass from typing import ( Any, @@ -31,15 +30,12 @@ import temporalio.common import temporalio.converter import temporalio.nexus -import temporalio.nexus.handler from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import TemporalOperationContext +from temporalio.nexus.handler import TemporalOperationContext, logger from temporalio.service import RPCError, RPCStatusCode from ._interceptor import Interceptor -logger = logging.getLogger(__name__) - class _NexusWorker: def __init__( @@ -124,7 +120,7 @@ async def raise_from_exception_queue() -> NoReturn: # TODO(nexus-prerelease): when do we remove the entry from _running_operations? _task.cancel() else: - temporalio.nexus.handler.logger.warning( + logger.warning( f"Received cancel_task but no running operation exists for " f"task token: {task.task_token}" ) @@ -184,9 +180,7 @@ async def _handle_cancel_operation_task( try: await self._handler.cancel_operation(ctx, request.operation_token) except Exception as err: - temporalio.nexus.handler.logger.exception( - "Failed to execute Nexus cancel operation method" - ) + logger.exception("Failed to execute Nexus cancel operation method") completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( task_token=task_token, error=await self._handler_error_to_proto( @@ -204,9 +198,7 @@ async def _handle_cancel_operation_task( try: await self._bridge_worker().complete_nexus_task(completion) except Exception: - temporalio.nexus.handler.logger.exception( - "Failed to send Nexus task completion" - ) + logger.exception("Failed to send Nexus task completion") async def _handle_start_operation_task( self, @@ -241,16 +233,12 @@ async def _handle_start_operation_task( try: await self._bridge_worker().complete_nexus_task(completion) except Exception: - temporalio.nexus.handler.logger.exception( - "Failed to send Nexus task completion" - ) + logger.exception("Failed to send Nexus task completion") finally: try: del self._running_tasks[task_token] except KeyError: - temporalio.nexus.handler.logger.exception( - "Failed to remove completed Nexus operation" - ) + logger.exception("Failed to remove completed Nexus operation") async def _start_operation( self, From 7bd6bb5fbae774e7872804b552a55573a8c6ca9b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 08:55:05 -0400 Subject: [PATCH 043/183] Do not issue warnings when user is not using type annotations --- temporalio/nexus/handler/_util.py | 4 ++++ tests/nexus/test_handler.py | 39 +++++++++++++++++-------------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/handler/_util.py index 03057e147..09f4c0939 100644 --- a/temporalio/nexus/handler/_util.py +++ b/temporalio/nexus/handler/_util.py @@ -86,6 +86,10 @@ def _get_start_method_input_and_output_type_annotations( f"Expected decorated start method {start} to have type annotations" ) return None, None + + if not type_annotations: + return None, None + output_type = type_annotations.pop("return", None) if len(type_annotations) != 2: diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 744727a1c..91a7c1013 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -860,25 +860,28 @@ async def _test_start_operation( ), ) - decorator = ( - nexusrpc.handler.service_handler(service=MyService) - if with_service_definition - else nexusrpc.handler.service_handler - ) - service_handler = decorator(MyServiceHandler)() - - async with Worker( - env.client, - task_queue=task_queue, - nexus_service_handlers=[service_handler], - nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), - ): - response = await service_client.start_operation( - test_case.operation, - dataclass_as_dict(test_case.input), - test_case.headers, + with pytest.WarningsRecorder() as warnings: + decorator = ( + nexusrpc.handler.service_handler(service=MyService) + if with_service_definition + else nexusrpc.handler.service_handler ) - test_case.check_response(response, with_service_definition) + service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + test_case.operation, + dataclass_as_dict(test_case.input), + test_case.headers, + ) + test_case.check_response(response, with_service_definition) + + assert not any(warnings), [w.message for w in warnings] async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: Any): From f95862749c9a47190bb531e8484f58190b0d63df Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 10:17:35 -0400 Subject: [PATCH 044/183] Remove redundant validation It is done by nexusrpc --- temporalio/worker/_nexus.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 99b273803..c300d2857 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -55,13 +55,6 @@ def __init__( self._bridge_worker = bridge_worker self._client = client self._task_queue = task_queue - - for service in service_handlers: - if isinstance(service, type): - raise TypeError( - f"Expected a service instance, but got a class: {service}. " - "Nexus services must be passed as instances, not classes." - ) self._handler = Handler(service_handlers, executor) self._data_converter = data_converter # TODO(nexus-preview): interceptors From 6fcf72f28de456e4cde786bd5b859ad2e2da95dd Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 10:50:37 -0400 Subject: [PATCH 045/183] Respond to code review comments --- temporalio/worker/_nexus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index c300d2857..684200bff 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -113,8 +113,8 @@ async def raise_from_exception_queue() -> NoReturn: # TODO(nexus-prerelease): when do we remove the entry from _running_operations? _task.cancel() else: - logger.warning( - f"Received cancel_task but no running operation exists for " + logger.debug( + f"Received cancel_task but no running task exists for " f"task token: {task.task_token}" ) else: From 3ed51748741ba1fe4b52fd5f09b57d365cca18c1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 10:50:54 -0400 Subject: [PATCH 046/183] Don't swallow exceptions when encoding failures --- temporalio/worker/_nexus.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 684200bff..cd98b0f00 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -308,14 +308,20 @@ async def _exception_to_failure_proto( self, err: BaseException, ) -> temporalio.api.nexus.v1.Failure: - api_failure = temporalio.api.failure.v1.Failure() - await self._data_converter.encode_failure(err, api_failure) - api_failure = google.protobuf.json_format.MessageToDict(api_failure) - return temporalio.api.nexus.v1.Failure( - message=api_failure.pop("message", ""), - metadata={"type": "temporal.api.failure.v1.Failure"}, - details=json.dumps(api_failure).encode("utf-8"), - ) + try: + api_failure = temporalio.api.failure.v1.Failure() + await self._data_converter.encode_failure(err, api_failure) + api_failure = google.protobuf.json_format.MessageToDict(api_failure) + return temporalio.api.nexus.v1.Failure( + message=api_failure.pop("message", ""), + metadata={"type": "temporal.api.failure.v1.Failure"}, + details=json.dumps(api_failure).encode("utf-8"), + ) + except BaseException as err: + return temporalio.api.nexus.v1.Failure( + message=f"{err.__class__.__name__}: {err}", + metadata={"type": "temporal.api.failure.v1.Failure"}, + ) async def _operation_error_to_proto( self, From 7fcd501454c1cc929afb5ac682c6c0353f4da5df Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 11:05:53 -0400 Subject: [PATCH 047/183] Catch BaseException at top-level in worker --- temporalio/worker/_activity.py | 2 +- temporalio/worker/_nexus.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index c9f71834c..5386daa87 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -177,7 +177,7 @@ async def raise_from_exception_queue() -> NoReturn: except temporalio.bridge.worker.PollShutdownError: exception_task.cancel() return - except Exception as err: + except BaseException as err: exception_task.cancel() raise RuntimeError("Activity worker failed") from err diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index cd98b0f00..53929ef16 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -124,7 +124,7 @@ async def raise_from_exception_queue() -> NoReturn: exception_task.cancel() return - except Exception as err: + except BaseException as err: raise RuntimeError("Nexus worker failed") from err # Only call this if run() raised an error From 04ce78d3c4b403855f4cca4df1cbaa682c97cf33 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 11:06:42 -0400 Subject: [PATCH 048/183] Fail worker on broken executor --- temporalio/worker/_nexus.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 53929ef16..576358611 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -208,13 +208,14 @@ async def _handle_start_operation_task( try: start_response = await self._start_operation(start_request, headers) - # TODO(nexus-prerelease): handle BrokenExecutor by failing the worker except BaseException as err: handler_err = _exception_to_handler_error(err) completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( task_token=task_token, error=await self._handler_error_to_proto(handler_err), ) + if isinstance(err, concurrent.futures.BrokenExecutor): + self._fail_worker_exception_queue.put_nowait(err) else: completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( task_token=task_token, From 7ec57ff5768d38bf464869b765bf90ae78b09c82 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 11:30:16 -0400 Subject: [PATCH 049/183] Revert "Catch BaseException at top-level in worker" This reverts commit 6431f39ec9c43642a7f208c3cc6ea1e5018ec9e3. --- temporalio/worker/_activity.py | 2 +- temporalio/worker/_nexus.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 5386daa87..c9f71834c 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -177,7 +177,7 @@ async def raise_from_exception_queue() -> NoReturn: except temporalio.bridge.worker.PollShutdownError: exception_task.cancel() return - except BaseException as err: + except Exception as err: exception_task.cancel() raise RuntimeError("Activity worker failed") from err diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 576358611..08b9dc7f2 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -124,7 +124,7 @@ async def raise_from_exception_queue() -> NoReturn: exception_task.cancel() return - except BaseException as err: + except Exception as err: raise RuntimeError("Nexus worker failed") from err # Only call this if run() raised an error From 730130762b9bfc4ec5e2c88020b491708c8bcc4a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 11:47:04 -0400 Subject: [PATCH 050/183] Cleanup --- temporalio/nexus/handler/__init__.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 4afdc2c15..8861fa73b 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -1,10 +1,7 @@ -from __future__ import annotations - import logging -from collections.abc import Mapping from typing import ( - TYPE_CHECKING, Any, + Mapping, MutableMapping, Optional, ) @@ -26,14 +23,6 @@ from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowOperationToken as WorkflowOperationToken -if TYPE_CHECKING: - from temporalio.client import ( - Client as Client, - ) - from temporalio.client import ( - WorkflowHandle as WorkflowHandle, - ) - class LoggerAdapter(logging.LoggerAdapter): def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): From 4cbb09a89d9acf06cdaa7478cc95d82159c161eb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 11:47:38 -0400 Subject: [PATCH 051/183] Change context method name: .current() -> .get() --- temporalio/nexus/handler/__init__.py | 2 +- temporalio/nexus/handler/_operation_context.py | 14 +++----------- temporalio/nexus/handler/_operation_handlers.py | 4 ++-- tests/nexus/test_handler.py | 10 +++++----- tests/nexus/test_workflow_caller.py | 6 +++--- 5 files changed, 14 insertions(+), 22 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 8861fa73b..b29527fbc 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -32,7 +32,7 @@ def process( self, msg: Any, kwargs: MutableMapping[str, Any] ) -> tuple[Any, MutableMapping[str, Any]]: extra = dict(self.extra or {}) - if tctx := TemporalOperationContext.current(): + if tctx := TemporalOperationContext.get(): extra["service"] = tctx.nexus_operation_context.service extra["operation"] = tctx.nexus_operation_context.operation extra["task_queue"] = tctx.task_queue diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 54f8a7edd..129ee8c20 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -52,18 +52,10 @@ class TemporalOperationContext: task_queue: str """The task queue of the worker handling this Nexus operation.""" - # TODO(nexus-prerelease): I don't think I like these names. Perhaps .get(), or - # expose the contextvar directly in the public API. + # TODO(nexus-prerelease): Confirm how exactly we want to expose Temporal Nexus operation context @staticmethod - def try_current() -> Optional[TemporalOperationContext]: - return _current_context.get(None) - - @staticmethod - def current() -> TemporalOperationContext: - context = TemporalOperationContext.try_current() - if not context: - raise RuntimeError("Not in Nexus operation context") - return context + def get() -> TemporalOperationContext: + return _current_context.get() @staticmethod def set( diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index e4d4a9d8f..09f267877 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -86,7 +86,7 @@ async def fetch_result( type=HandlerErrorType.NOT_FOUND, cause=err, ) - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() try: handle = workflow_token.to_workflow_handle( tctx.client, result_type=self._output_type @@ -120,7 +120,7 @@ async def cancel_operation( cause=err, ) - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() try: handle = workflow_token.to_workflow_handle(tctx.client) except Exception as err: diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 91a7c1013..b41c30932 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -237,7 +237,7 @@ def workflow_run_operation( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() return await tctx.start_workflow( MyWorkflow.run, input, @@ -298,7 +298,7 @@ def workflow_run_operation_without_type_annotations( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: async def start(ctx, input): - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() return await tctx.start_workflow( WorkflowWithoutTypeAnnotations.run, input, @@ -320,7 +320,7 @@ async def start( assert ctx.request_id == "test-request-id-123", "Request ID mismatch" ctx.outbound_links.extend(ctx.inbound_links) - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() return await tctx.start_workflow( MyLinkTestWorkflow.run, input, @@ -1130,7 +1130,7 @@ def operation_backed_by_a_workflow( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() return await tctx.start_workflow( EchoWorkflow.run, input, @@ -1147,7 +1147,7 @@ def operation_that_executes_a_workflow_before_starting_the_backing_workflow( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() await tctx.client.start_workflow( EchoWorkflow.run, input, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index a588ef38d..3145c98d2 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -154,7 +154,7 @@ async def start( value=OpOutput(value="sync response") ) elif isinstance(input.response_type, AsyncResponse): - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() token = await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -216,7 +216,7 @@ async def start( RPCStatusCode.INVALID_ARGUMENT, b"", ) - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() return await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -946,7 +946,7 @@ def my_workflow_run_operation( async def start( ctx: StartOperationContext, input: None ) -> WorkflowOperationToken[str]: - tctx = TemporalOperationContext.current() + tctx = TemporalOperationContext.get() result_1 = await tctx.client.execute_workflow( EchoWorkflow.run, "result-1", From 1b84f11bd15856aa0366410fe9fc103bd17e9a2c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 11:53:11 -0400 Subject: [PATCH 052/183] Rename: TemporalNexusOperationContext --- temporalio/nexus/handler/__init__.py | 6 ++++-- .../nexus/handler/_operation_context.py | 20 +++++++++---------- .../nexus/handler/_operation_handlers.py | 6 +++--- temporalio/worker/_nexus.py | 10 +++++----- tests/nexus/test_handler.py | 12 +++++------ tests/nexus/test_workflow_caller.py | 11 ++++++---- 6 files changed, 35 insertions(+), 30 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index b29527fbc..7cad5d25e 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -16,7 +16,9 @@ HandlerErrorType as HandlerErrorType, ) -from ._operation_context import TemporalOperationContext as TemporalOperationContext +from ._operation_context import ( + TemporalNexusOperationContext as TemporalNexusOperationContext, +) from ._operation_handlers import ( WorkflowRunOperationHandler as WorkflowRunOperationHandler, ) @@ -32,7 +34,7 @@ def process( self, msg: Any, kwargs: MutableMapping[str, Any] ) -> tuple[Any, MutableMapping[str, Any]]: extra = dict(self.extra or {}) - if tctx := TemporalOperationContext.get(): + if tctx := TemporalNexusOperationContext.get(): extra["service"] = tctx.nexus_operation_context.service extra["operation"] = tctx.nexus_operation_context.operation extra["task_queue"] = tctx.task_queue diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 129ee8c20..080f37505 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -33,13 +33,13 @@ logger = logging.getLogger(__name__) -_current_context: ContextVar[TemporalOperationContext] = ContextVar( - "temporal-operation-context" +temporal_nexus_operation_context: ContextVar[TemporalNexusOperationContext] = ( + ContextVar("temporal-nexus-operation-context") ) @dataclass -class TemporalOperationContext: +class TemporalNexusOperationContext: """ Context for a Nexus operation being handled by a Temporal Nexus Worker. """ @@ -54,18 +54,18 @@ class TemporalOperationContext: # TODO(nexus-prerelease): Confirm how exactly we want to expose Temporal Nexus operation context @staticmethod - def get() -> TemporalOperationContext: - return _current_context.get() + def get() -> TemporalNexusOperationContext: + return temporal_nexus_operation_context.get() @staticmethod def set( - context: TemporalOperationContext, - ) -> contextvars.Token[TemporalOperationContext]: - return _current_context.set(context) + context: TemporalNexusOperationContext, + ) -> contextvars.Token[TemporalNexusOperationContext]: + return temporal_nexus_operation_context.set(context) @staticmethod - def reset(token: contextvars.Token[TemporalOperationContext]) -> None: - _current_context.reset(token) + def reset(token: contextvars.Token[TemporalNexusOperationContext]) -> None: + temporal_nexus_operation_context.reset(token) @property def temporal_start_operation_context( diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 09f267877..56c7f3737 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -21,7 +21,7 @@ ServiceHandlerT, ) -from temporalio.nexus.handler._operation_context import TemporalOperationContext +from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext from temporalio.nexus.handler._token import WorkflowOperationToken from ._util import ( @@ -86,7 +86,7 @@ async def fetch_result( type=HandlerErrorType.NOT_FOUND, cause=err, ) - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() try: handle = workflow_token.to_workflow_handle( tctx.client, result_type=self._output_type @@ -120,7 +120,7 @@ async def cancel_operation( cause=err, ) - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() try: handle = workflow_token.to_workflow_handle(tctx.client) except Exception as err: diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 08b9dc7f2..3cf918793 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -31,7 +31,7 @@ import temporalio.converter import temporalio.nexus from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import TemporalOperationContext, logger +from temporalio.nexus.handler import TemporalNexusOperationContext, logger from temporalio.service import RPCError, RPCStatusCode from ._interceptor import Interceptor @@ -162,8 +162,8 @@ async def _handle_cancel_operation_task( service=request.service, operation=request.operation, ) - TemporalOperationContext.set( - TemporalOperationContext( + TemporalNexusOperationContext.set( + TemporalNexusOperationContext( nexus_operation_context=ctx, client=self._client, task_queue=self._task_queue, @@ -258,8 +258,8 @@ async def _start_operation( ], callback_headers=dict(start_request.callback_header), ) - TemporalOperationContext.set( - TemporalOperationContext( + TemporalNexusOperationContext.set( + TemporalNexusOperationContext( nexus_operation_context=ctx, client=self._client, task_queue=self._task_queue, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index b41c30932..6cfe6741c 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -50,7 +50,7 @@ from temporalio.nexus.handler import ( logger, ) -from temporalio.nexus.handler._operation_context import TemporalOperationContext +from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -237,7 +237,7 @@ def workflow_run_operation( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() return await tctx.start_workflow( MyWorkflow.run, input, @@ -298,7 +298,7 @@ def workflow_run_operation_without_type_annotations( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: async def start(ctx, input): - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() return await tctx.start_workflow( WorkflowWithoutTypeAnnotations.run, input, @@ -320,7 +320,7 @@ async def start( assert ctx.request_id == "test-request-id-123", "Request ID mismatch" ctx.outbound_links.extend(ctx.inbound_links) - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() return await tctx.start_workflow( MyLinkTestWorkflow.run, input, @@ -1130,7 +1130,7 @@ def operation_backed_by_a_workflow( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() return await tctx.start_workflow( EchoWorkflow.run, input, @@ -1147,7 +1147,7 @@ def operation_that_executes_a_workflow_before_starting_the_backing_workflow( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() await tctx.client.start_workflow( EchoWorkflow.run, input, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3145c98d2..274f53780 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -33,7 +33,10 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus.handler import TemporalOperationContext, WorkflowOperationToken +from temporalio.nexus.handler import ( + TemporalNexusOperationContext, + WorkflowOperationToken, +) from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -154,7 +157,7 @@ async def start( value=OpOutput(value="sync response") ) elif isinstance(input.response_type, AsyncResponse): - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() token = await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -216,7 +219,7 @@ async def start( RPCStatusCode.INVALID_ARGUMENT, b"", ) - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() return await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -946,7 +949,7 @@ def my_workflow_run_operation( async def start( ctx: StartOperationContext, input: None ) -> WorkflowOperationToken[str]: - tctx = TemporalOperationContext.get() + tctx = TemporalNexusOperationContext.get() result_1 = await tctx.client.execute_workflow( EchoWorkflow.run, "result-1", From bdfc0197415b115e50cce58c21c8872e853d571b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 12:03:04 -0400 Subject: [PATCH 053/183] Expose contextvar object directly --- temporalio/nexus/handler/__init__.py | 7 ++++-- .../nexus/handler/_operation_context.py | 23 ++++--------------- .../nexus/handler/_operation_handlers.py | 12 ++++++---- temporalio/worker/_nexus.py | 14 +++++++---- tests/nexus/test_handler.py | 14 ++++++----- tests/nexus/test_workflow_caller.py | 8 +++---- 6 files changed, 37 insertions(+), 41 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 7cad5d25e..629c412d3 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -17,7 +17,10 @@ ) from ._operation_context import ( - TemporalNexusOperationContext as TemporalNexusOperationContext, + _TemporalNexusOperationContext as _TemporalNexusOperationContext, +) +from ._operation_context import ( + temporal_operation_context as temporal_operation_context, ) from ._operation_handlers import ( WorkflowRunOperationHandler as WorkflowRunOperationHandler, @@ -34,7 +37,7 @@ def process( self, msg: Any, kwargs: MutableMapping[str, Any] ) -> tuple[Any, MutableMapping[str, Any]]: extra = dict(self.extra or {}) - if tctx := TemporalNexusOperationContext.get(): + if tctx := temporal_operation_context.get(None): extra["service"] = tctx.nexus_operation_context.service extra["operation"] = tctx.nexus_operation_context.operation extra["task_queue"] = tctx.task_queue diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 080f37505..0a869307f 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextvars import logging import re import urllib.parse @@ -32,14 +31,15 @@ logger = logging.getLogger(__name__) +# TODO(nexus-prerelease): Confirm how exactly we want to expose Temporal Nexus operation context -temporal_nexus_operation_context: ContextVar[TemporalNexusOperationContext] = ( - ContextVar("temporal-nexus-operation-context") +temporal_operation_context: ContextVar[_TemporalNexusOperationContext] = ContextVar( + "temporal-operation-context" ) @dataclass -class TemporalNexusOperationContext: +class _TemporalNexusOperationContext: """ Context for a Nexus operation being handled by a Temporal Nexus Worker. """ @@ -52,21 +52,6 @@ class TemporalNexusOperationContext: task_queue: str """The task queue of the worker handling this Nexus operation.""" - # TODO(nexus-prerelease): Confirm how exactly we want to expose Temporal Nexus operation context - @staticmethod - def get() -> TemporalNexusOperationContext: - return temporal_nexus_operation_context.get() - - @staticmethod - def set( - context: TemporalNexusOperationContext, - ) -> contextvars.Token[TemporalNexusOperationContext]: - return temporal_nexus_operation_context.set(context) - - @staticmethod - def reset(token: contextvars.Token[TemporalNexusOperationContext]) -> None: - temporal_nexus_operation_context.reset(token) - @property def temporal_start_operation_context( self, diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 56c7f3737..5453664d0 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -21,7 +21,9 @@ ServiceHandlerT, ) -from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext +from temporalio.nexus.handler._operation_context import ( + temporal_operation_context, +) from temporalio.nexus.handler._token import WorkflowOperationToken from ._util import ( @@ -86,10 +88,10 @@ async def fetch_result( type=HandlerErrorType.NOT_FOUND, cause=err, ) - tctx = TemporalNexusOperationContext.get() + ctx = temporal_operation_context.get() try: handle = workflow_token.to_workflow_handle( - tctx.client, result_type=self._output_type + ctx.client, result_type=self._output_type ) except Exception as err: raise HandlerError( @@ -120,9 +122,9 @@ async def cancel_operation( cause=err, ) - tctx = TemporalNexusOperationContext.get() + ctx = temporal_operation_context.get() try: - handle = workflow_token.to_workflow_handle(tctx.client) + handle = workflow_token.to_workflow_handle(ctx.client) except Exception as err: raise HandlerError( "Failed to construct workflow handle from workflow operation token", diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 3cf918793..5abad641f 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -31,7 +31,11 @@ import temporalio.converter import temporalio.nexus from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import TemporalNexusOperationContext, logger +from temporalio.nexus.handler import ( + _TemporalNexusOperationContext, + logger, + temporal_operation_context, +) from temporalio.service import RPCError, RPCStatusCode from ._interceptor import Interceptor @@ -162,8 +166,8 @@ async def _handle_cancel_operation_task( service=request.service, operation=request.operation, ) - TemporalNexusOperationContext.set( - TemporalNexusOperationContext( + temporal_operation_context.set( + _TemporalNexusOperationContext( nexus_operation_context=ctx, client=self._client, task_queue=self._task_queue, @@ -258,8 +262,8 @@ async def _start_operation( ], callback_headers=dict(start_request.callback_header), ) - TemporalNexusOperationContext.set( - TemporalNexusOperationContext( + temporal_operation_context.set( + _TemporalNexusOperationContext( nexus_operation_context=ctx, client=self._client, task_queue=self._task_queue, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 6cfe6741c..88cfa03e2 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -50,7 +50,9 @@ from temporalio.nexus.handler import ( logger, ) -from temporalio.nexus.handler._operation_context import TemporalNexusOperationContext +from temporalio.nexus.handler._operation_context import ( + temporal_operation_context, +) from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -237,7 +239,7 @@ def workflow_run_operation( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() return await tctx.start_workflow( MyWorkflow.run, input, @@ -298,7 +300,7 @@ def workflow_run_operation_without_type_annotations( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: async def start(ctx, input): - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() return await tctx.start_workflow( WorkflowWithoutTypeAnnotations.run, input, @@ -320,7 +322,7 @@ async def start( assert ctx.request_id == "test-request-id-123", "Request ID mismatch" ctx.outbound_links.extend(ctx.inbound_links) - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() return await tctx.start_workflow( MyLinkTestWorkflow.run, input, @@ -1130,7 +1132,7 @@ def operation_backed_by_a_workflow( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() return await tctx.start_workflow( EchoWorkflow.run, input, @@ -1147,7 +1149,7 @@ def operation_that_executes_a_workflow_before_starting_the_backing_workflow( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() await tctx.client.start_workflow( EchoWorkflow.run, input, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 274f53780..9c09a34a2 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -34,8 +34,8 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus.handler import ( - TemporalNexusOperationContext, WorkflowOperationToken, + temporal_operation_context, ) from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker @@ -157,7 +157,7 @@ async def start( value=OpOutput(value="sync response") ) elif isinstance(input.response_type, AsyncResponse): - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() token = await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -219,7 +219,7 @@ async def start( RPCStatusCode.INVALID_ARGUMENT, b"", ) - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() return await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), @@ -949,7 +949,7 @@ def my_workflow_run_operation( async def start( ctx: StartOperationContext, input: None ) -> WorkflowOperationToken[str]: - tctx = TemporalNexusOperationContext.get() + tctx = temporal_operation_context.get() result_1 = await tctx.client.execute_workflow( EchoWorkflow.run, "result-1", From 2b1decee70a2f45c337d4511ad33676f231326a2 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 15:30:50 -0400 Subject: [PATCH 054/183] Mark methods as private --- temporalio/nexus/handler/_operation_context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index 0a869307f..e4e725f4f 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -53,7 +53,7 @@ class _TemporalNexusOperationContext: """The task queue of the worker handling this Nexus operation.""" @property - def temporal_start_operation_context( + def _temporal_start_operation_context( self, ) -> Optional[_TemporalStartOperationContext]: ctx = self.nexus_operation_context @@ -62,7 +62,7 @@ def temporal_start_operation_context( return _TemporalStartOperationContext(ctx) @property - def temporal_cancel_operation_context( + def _temporal_cancel_operation_context( self, ) -> Optional[_TemporalCancelOperationContext]: ctx = self.nexus_operation_context @@ -71,7 +71,7 @@ def temporal_cancel_operation_context( return _TemporalCancelOperationContext(ctx) # Overload for single-param workflow - # TODO(nexus-prerelease): support other overloads? + # TODO(nexus-prerelease): bring over other overloads async def start_workflow( self, workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], @@ -132,7 +132,7 @@ async def start_workflow( Nexus caller is itself a workflow, this means that the workflow in the caller namespace web UI will contain links to the started workflow, and vice versa. """ - start_operation_context = self.temporal_start_operation_context + start_operation_context = self._temporal_start_operation_context if not start_operation_context: raise RuntimeError( "temporalio.nexus.handler.start_workflow() must be called from " From 086efa5936e98c52c54337c66947e410e89427d9 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 15:36:53 -0400 Subject: [PATCH 055/183] Add run-time type check --- temporalio/nexus/handler/_operation_handlers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 5453664d0..0d71ee298 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -21,6 +21,7 @@ ServiceHandlerT, ) +from temporalio.client import WorkflowHandle from temporalio.nexus.handler._operation_context import ( temporal_operation_context, ) @@ -60,6 +61,18 @@ async def start( self, ctx: StartOperationContext, input: InputT ) -> nexusrpc.handler.StartOperationResultAsync: token = await self._start(ctx, input) + if not isinstance(token, WorkflowOperationToken): + if isinstance(token, WorkflowHandle): + raise RuntimeError( + f"Expected {token} to be a WorkflowOperationToken, but got a WorkflowHandle. " + f"You must use :py:meth:`temporalio.nexus.handler.start_workflow` " + "to start a workflow that will deliver the result of the Nexus operation, " + "not :py:meth:`temporalio.client.Client.start_workflow`." + ) + raise RuntimeError( + f"Expected {token} to be a WorkflowOperationToken, but got {type(token)}. " + "This is a bug in the Nexus SDK. Please report it to the Temporal team." + ) return StartOperationResultAsync(token.encode()) async def cancel(self, ctx: CancelOperationContext, token: str) -> None: From b6dfb9698ddfe4a296c5aea56290102b3cb028a5 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 15:46:31 -0400 Subject: [PATCH 056/183] Make start_workflow a static function --- temporalio/nexus/handler/__init__.py | 1 + .../nexus/handler/_operation_context.py | 126 ---------------- temporalio/nexus/handler/_workflow.py | 137 ++++++++++++++++++ tests/nexus/test_handler.py | 21 +-- tests/nexus/test_workflow_caller.py | 9 +- 5 files changed, 148 insertions(+), 146 deletions(-) create mode 100644 temporalio/nexus/handler/_workflow.py diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 629c412d3..d684bcf6c 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -27,6 +27,7 @@ ) from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowOperationToken as WorkflowOperationToken +from ._workflow import start_workflow as start_workflow class LoggerAdapter(logging.LoggerAdapter): diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/handler/_operation_context.py index e4e725f4f..3f811b608 100644 --- a/temporalio/nexus/handler/_operation_context.py +++ b/temporalio/nexus/handler/_operation_context.py @@ -5,12 +5,9 @@ import urllib.parse from contextvars import ContextVar from dataclasses import dataclass -from datetime import timedelta from typing import ( Any, - Mapping, Optional, - Sequence, Union, ) @@ -21,13 +18,6 @@ import temporalio.api.enums.v1 import temporalio.common from temporalio.client import Client, NexusCompletionCallback, WorkflowHandle -from temporalio.nexus.handler._token import WorkflowOperationToken -from temporalio.types import ( - MethodAsyncSingleParam, - ParamType, - ReturnType, - SelfType, -) logger = logging.getLogger(__name__) @@ -70,122 +60,6 @@ def _temporal_cancel_operation_context( return None return _TemporalCancelOperationContext(ctx) - # Overload for single-param workflow - # TODO(nexus-prerelease): bring over other overloads - async def start_workflow( - self, - workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], - arg: ParamType, - *, - id: str, - task_queue: Optional[str] = None, - execution_timeout: Optional[timedelta] = None, - run_timeout: Optional[timedelta] = None, - task_timeout: Optional[timedelta] = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, - retry_policy: Optional[temporalio.common.RetryPolicy] = None, - cron_schedule: str = "", - memo: Optional[Mapping[str, Any]] = None, - search_attributes: Optional[ - Union[ - temporalio.common.TypedSearchAttributes, - temporalio.common.SearchAttributes, - ] - ] = None, - static_summary: Optional[str] = None, - static_details: Optional[str] = None, - start_delay: Optional[timedelta] = None, - start_signal: Optional[str] = None, - start_signal_args: Sequence[Any] = [], - rpc_metadata: Mapping[str, str] = {}, - rpc_timeout: Optional[timedelta] = None, - request_eager_start: bool = False, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: Optional[temporalio.common.VersioningOverride] = None, - ) -> WorkflowOperationToken[ReturnType]: - """Start a workflow that will deliver the result of the Nexus operation. - - The workflow will be started in the same namespace as the Nexus worker, using - the same client as the worker. If task queue is not specified, the worker's task - queue will be used. - - See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. - - The return value is :py:class:`temporalio.nexus.handler.WorkflowOperationToken`. - Use :py:meth:`temporalio.nexus.handler.WorkflowOperationToken.to_workflow_handle` - to get a :py:class:`temporalio.client.WorkflowHandle` for interacting with the - workflow. - - The workflow will be started as usual, with the following modifications: - - - On workflow completion, Temporal server will deliver the workflow result to - the Nexus operation caller, using the callback from the Nexus operation start - request. - - - The request ID from the Nexus operation start request will be used as the - request ID for the start workflow request. - - - Inbound links to the caller that were submitted in the Nexus start operation - request will be attached to the started workflow and, outbound links to the - started workflow will be added to the Nexus start operation response. If the - Nexus caller is itself a workflow, this means that the workflow in the caller - namespace web UI will contain links to the started workflow, and vice versa. - """ - start_operation_context = self._temporal_start_operation_context - if not start_operation_context: - raise RuntimeError( - "temporalio.nexus.handler.start_workflow() must be called from " - "within a Nexus start operation context" - ) - - # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: - # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { - # internalOptions.onConflictOptions = { - # attachLinks: true, - # attachCompletionCallbacks: true, - # attachRequestId: true, - # }; - # } - - # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, - # but these are deliberately not exposed in overloads, hence the type-check - # violation. - wf_handle = await self.client.start_workflow( # type: ignore - workflow=workflow, - arg=arg, - id=id, - task_queue=task_queue or self.task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - start_signal=start_signal, - start_signal_args=start_signal_args, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - request_eager_start=request_eager_start, - priority=priority, - versioning_override=versioning_override, - nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), - workflow_event_links=start_operation_context.get_workflow_event_links(), - request_id=start_operation_context.nexus_operation_context.request_id, - ) - - start_operation_context.add_outbound_links(wf_handle) - - return WorkflowOperationToken[ReturnType]._unsafe_from_workflow_handle( - wf_handle - ) - @dataclass class _TemporalStartOperationContext: diff --git a/temporalio/nexus/handler/_workflow.py b/temporalio/nexus/handler/_workflow.py new file mode 100644 index 000000000..f2da5a27e --- /dev/null +++ b/temporalio/nexus/handler/_workflow.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import ( + Any, + Mapping, + Optional, + Sequence, + Union, +) + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.common +from temporalio.nexus.handler._operation_context import temporal_operation_context +from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.types import ( + MethodAsyncSingleParam, + ParamType, + ReturnType, + SelfType, +) + + +# Overload for single-param workflow +# TODO(nexus-prerelease): bring over other overloads +async def start_workflow( + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, +) -> WorkflowOperationToken[ReturnType]: + """Start a workflow that will deliver the result of the Nexus operation. + + The workflow will be started in the same namespace as the Nexus worker, using + the same client as the worker. If task queue is not specified, the worker's task + queue will be used. + + See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. + + The return value is :py:class:`temporalio.nexus.handler.WorkflowOperationToken`. + Use :py:meth:`temporalio.nexus.handler.WorkflowOperationToken.to_workflow_handle` + to get a :py:class:`temporalio.client.WorkflowHandle` for interacting with the + workflow. + + The workflow will be started as usual, with the following modifications: + + - On workflow completion, Temporal server will deliver the workflow result to + the Nexus operation caller, using the callback from the Nexus operation start + request. + + - The request ID from the Nexus operation start request will be used as the + request ID for the start workflow request. + + - Inbound links to the caller that were submitted in the Nexus start operation + request will be attached to the started workflow and, outbound links to the + started workflow will be added to the Nexus start operation response. If the + Nexus caller is itself a workflow, this means that the workflow in the caller + namespace web UI will contain links to the started workflow, and vice versa. + """ + ctx = temporal_operation_context.get() + start_operation_context = ctx._temporal_start_operation_context + if not start_operation_context: + raise RuntimeError( + "temporalio.nexus.handler.start_workflow() must be called from " + "within a Nexus start operation context" + ) + + # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: + # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { + # internalOptions.onConflictOptions = { + # attachLinks: true, + # attachCompletionCallbacks: true, + # attachRequestId: true, + # }; + # } + + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. + wf_handle = await ctx.client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + id=id, + task_queue=task_queue or ctx.task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), + workflow_event_links=start_operation_context.get_workflow_event_links(), + request_id=start_operation_context.nexus_operation_context.request_id, + ) + + start_operation_context.add_outbound_links(wf_handle) + + return WorkflowOperationToken[ReturnType]._unsafe_from_workflow_handle(wf_handle) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 88cfa03e2..5f9d1ed56 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -47,12 +47,7 @@ from temporalio.common import WorkflowIDReusePolicy from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import ( - logger, -) -from temporalio.nexus.handler._operation_context import ( - temporal_operation_context, -) +from temporalio.nexus.handler import logger, start_workflow, temporal_operation_context from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -239,8 +234,7 @@ def workflow_run_operation( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = temporal_operation_context.get() - return await tctx.start_workflow( + return await start_workflow( MyWorkflow.run, input, id=str(uuid.uuid4()), @@ -300,8 +294,7 @@ def workflow_run_operation_without_type_annotations( self, ) -> nexusrpc.handler.OperationHandler[Input, Output]: async def start(ctx, input): - tctx = temporal_operation_context.get() - return await tctx.start_workflow( + return await start_workflow( WorkflowWithoutTypeAnnotations.run, input, id=str(uuid.uuid4()), @@ -322,8 +315,7 @@ async def start( assert ctx.request_id == "test-request-id-123", "Request ID mismatch" ctx.outbound_links.extend(ctx.inbound_links) - tctx = temporal_operation_context.get() - return await tctx.start_workflow( + return await start_workflow( MyLinkTestWorkflow.run, input, id=str(uuid.uuid4()), @@ -1132,8 +1124,7 @@ def operation_backed_by_a_workflow( async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: - tctx = temporal_operation_context.get() - return await tctx.start_workflow( + return await start_workflow( EchoWorkflow.run, input, id=input.value, @@ -1158,7 +1149,7 @@ async def start( ) # This should fail. It will not fail if the Nexus request ID was incorrectly # propagated to both StartWorkflow requests. - return await tctx.start_workflow( + return await start_workflow( EchoWorkflow.run, input, id=input.value, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 9c09a34a2..09c1e4b95 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -35,6 +35,7 @@ from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus.handler import ( WorkflowOperationToken, + start_workflow, temporal_operation_context, ) from temporalio.service import RPCError, RPCStatusCode @@ -157,8 +158,7 @@ async def start( value=OpOutput(value="sync response") ) elif isinstance(input.response_type, AsyncResponse): - tctx = temporal_operation_context.get() - token = await tctx.start_workflow( + token = await start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, @@ -219,8 +219,7 @@ async def start( RPCStatusCode.INVALID_ARGUMENT, b"", ) - tctx = temporal_operation_context.get() - return await tctx.start_workflow( + return await start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, @@ -959,7 +958,7 @@ async def start( # In case result_1 is incorrectly being delivered to the caller as the operation # result, give time for that incorrect behavior to occur. await asyncio.sleep(0.5) - return await tctx.start_workflow( + return await start_workflow( EchoWorkflow.run, f"{result_1}-result-2", id=str(uuid.uuid4()), From 93e077597b79c685ea934ac6979eed765dec7285 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 15:54:42 -0400 Subject: [PATCH 057/183] Remove accidental exports --- temporalio/nexus/handler/__init__.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index d684bcf6c..995531b64 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -6,16 +6,6 @@ Optional, ) -from nexusrpc.handler import ( - CancelOperationContext as CancelOperationContext, -) -from nexusrpc.handler import ( - HandlerError as HandlerError, -) -from nexusrpc.handler import ( - HandlerErrorType as HandlerErrorType, -) - from ._operation_context import ( _TemporalNexusOperationContext as _TemporalNexusOperationContext, ) From 29344ad2f981e2edc65b197fa72d81b0fb91db7b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 16:07:45 -0400 Subject: [PATCH 058/183] Docstrings --- .../nexus/handler/_operation_handlers.py | 31 +++++++++++++++++++ temporalio/nexus/handler/_token.py | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 0d71ee298..60168629c 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -37,6 +37,32 @@ class WorkflowRunOperationHandler( nexusrpc.handler.OperationHandler[InputT, OutputT], Generic[InputT, OutputT, ServiceHandlerT], ): + """ + Operation handler for Nexus operations that start a workflow. + + Use this class to create an operation handler that starts a workflow by passing your + ``start`` method to the constructor. Your ``start`` method must use + :py:func:`temporalio.nexus.handler.start_workflow` to start the workflow. + + Example: + + .. code-block:: python + + @service_handler(service=MyNexusService) class MyNexusServiceHandler: + @operation_handler def my_workflow_run_operation( + self, + ) -> OperationHandler[MyInput, MyOutput]: + async def start( + ctx: StartOperationContext, input: MyInput + ) -> WorkflowOperationToken[MyOutput]: + return await start_workflow( + WorkflowStartedByNexusOperation.run, input, + id=str(uuid.uuid4()), + ) + + return WorkflowRunOperationHandler(start) + """ + def __init__( self, start: Callable[ @@ -60,6 +86,10 @@ def __init__( async def start( self, ctx: StartOperationContext, input: InputT ) -> nexusrpc.handler.StartOperationResultAsync: + """ + Start the operation, by starting a workflow and completing asynchronously. + """ + token = await self._start(ctx, input) if not isinstance(token, WorkflowOperationToken): if isinstance(token, WorkflowHandle): @@ -76,6 +106,7 @@ async def start( return StartOperationResultAsync(token.encode()) async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + """Cancel the operation, by cancelling the workflow.""" await cancel_operation(token) async def fetch_info( diff --git a/temporalio/nexus/handler/_token.py b/temporalio/nexus/handler/_token.py index 08e9074a0..487e4f18e 100644 --- a/temporalio/nexus/handler/_token.py +++ b/temporalio/nexus/handler/_token.py @@ -16,7 +16,7 @@ @dataclass(frozen=True) class WorkflowOperationToken(Generic[OutputT]): - """Represents the structured data of a Nexus workflow operation token.""" + """A Nexus operation token for an operation backed by a workflow.""" namespace: str workflow_id: str From 60e466867943a88627b0eb790c50dd23417d529c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 18:03:08 -0400 Subject: [PATCH 059/183] Comment, cleanup --- temporalio/nexus/handler/_operation_handlers.py | 3 +++ tests/nexus/test_handler.py | 2 +- tests/nexus/test_handler_interface_implementation.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 60168629c..72b92d54f 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -63,6 +63,9 @@ async def start( return WorkflowRunOperationHandler(start) """ + # TODO(nexus-prerelease): I think we want this to be optional, so that the class can + # be used by subclassing, as well as by injecting the start method in the + # constructor. def __init__( self, start: Callable[ diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 5f9d1ed56..3fdd0508f 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -41,7 +41,7 @@ ) import temporalio.api.failure.v1 -import temporalio.nexus +import temporalio.nexus.handler from temporalio import workflow from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 2e2872d45..a4575f351 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -6,7 +6,7 @@ from nexusrpc.handler import OperationHandler, SyncOperationHandler import temporalio.api.failure.v1 -import temporalio.nexus +import temporalio.nexus.handler from temporalio.nexus.handler import ( WorkflowOperationToken, ) From e79f222fcbd656aa1e62ec0cce3d5f7a31d91437 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 18:14:42 -0400 Subject: [PATCH 060/183] TODO --- tests/nexus/test_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 3fdd0508f..17ae707dd 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -961,6 +961,7 @@ def start(ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) + # TODO(nexus-prerelease) why is this test passing? start is not `async def` return nexusrpc.handler.SyncOperationHandler(start) From c7b0170a5b3b8dfae481e11cd123d5b9bb6c68bc Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 18:50:32 -0400 Subject: [PATCH 061/183] TODOs --- tests/nexus/test_handler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 17ae707dd..67a80c044 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -71,11 +71,10 @@ class NonSerializableOutput: callable: Callable[[], Any] = lambda: None -# TODO: type check nexus implementation under mypy - +# TODO(nexus-prelease): Test attaching multiple callers to the same operation. +# TODO(nexus-preview): type check nexus implementation under mypy # TODO(nexus-prerelease): test dynamic creation of a service from unsugared definition # TODO(nexus-prerelease): test malformed inbound_links and outbound_links - # TODO(nexus-prerelease): test good error message on forgetting to add decorators etc From d731ac26f7390d58f9f69bc44b7233154eafd7ae Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 20:09:12 -0400 Subject: [PATCH 062/183] Get rid of spurious type parameters --- temporalio/nexus/handler/_operation_handlers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 72b92d54f..bcd6a5365 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -4,7 +4,6 @@ Any, Awaitable, Callable, - Generic, ) import nexusrpc.handler @@ -18,7 +17,6 @@ from nexusrpc.types import ( InputT, OutputT, - ServiceHandlerT, ) from temporalio.client import WorkflowHandle @@ -35,7 +33,6 @@ class WorkflowRunOperationHandler( nexusrpc.handler.OperationHandler[InputT, OutputT], - Generic[InputT, OutputT, ServiceHandlerT], ): """ Operation handler for Nexus operations that start a workflow. From 8755353078fe215364962aab84372567761b9ecb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 22:31:56 -0400 Subject: [PATCH 063/183] Add worker logging --- temporalio/worker/_nexus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 5abad641f..705473f01 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -213,6 +213,7 @@ async def _handle_start_operation_task( try: start_response = await self._start_operation(start_request, headers) except BaseException as err: + logger.exception("Failed to execute Nexus start operation method") handler_err = _exception_to_handler_error(err) completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( task_token=task_token, From a1a3df6ec4667969c58cc6fb15bb397c7b3defe0 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 20:09:53 -0400 Subject: [PATCH 064/183] Type-level enforcement of the two ways to use WorkflowRunOperationHandler --- .../nexus/handler/_operation_handlers.py | 103 ++++++++++++------ tests/helpers/nexus.py | 13 +++ tests/nexus/test_handler.py | 34 +++--- .../test_handler_interface_implementation.py | 4 +- .../test_handler_operation_definitions.py | 9 +- tests/nexus/test_workflow_caller.py | 7 +- tests/nexus/test_workflow_run_operation.py | 74 +++++++++++++ 7 files changed, 185 insertions(+), 59 deletions(-) create mode 100644 tests/nexus/test_workflow_run_operation.py diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index bcd6a5365..1673cbc45 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -1,9 +1,11 @@ from __future__ import annotations +from abc import ABC, abstractmethod from typing import ( Any, Awaitable, Callable, + Optional, ) import nexusrpc.handler @@ -33,6 +35,7 @@ class WorkflowRunOperationHandler( nexusrpc.handler.OperationHandler[InputT, OutputT], + ABC, ): """ Operation handler for Nexus operations that start a workflow. @@ -45,8 +48,10 @@ class WorkflowRunOperationHandler( .. code-block:: python - @service_handler(service=MyNexusService) class MyNexusServiceHandler: - @operation_handler def my_workflow_run_operation( + @service_handler(service=MyNexusService) + class MyNexusServiceHandler: + @operation_handler + def my_workflow_run_operation( self, ) -> OperationHandler[MyInput, MyOutput]: async def start( @@ -57,53 +62,52 @@ async def start( id=str(uuid.uuid4()), ) - return WorkflowRunOperationHandler(start) + return WorkflowRunOperationHandler.from_start_workflow(start) """ - # TODO(nexus-prerelease): I think we want this to be optional, so that the class can - # be used by subclassing, as well as by injecting the start method in the - # constructor. def __init__( self, - start: Callable[ + start: Optional[ + Callable[ + [StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ] + ] = None, + ) -> None: + if start is not None: + if not is_async_callable(start): + raise RuntimeError( + f"{start} is not an `async def` method. " + "WorkflowRunOperationHandler must be initialized with an " + "`async def` start method." + ) + self._start = start + if start.__doc__: + self.start.__func__.__doc__ = start.__doc__ + self._input_type, self._output_type = ( + get_workflow_run_start_method_input_and_output_type_annotations(start) + ) + else: + self._start = self._input_type = self._output_type = None + + @classmethod + def from_start_workflow( + cls, + start_workflow: Callable[ [StartOperationContext, InputT], Awaitable[WorkflowOperationToken[OutputT]], ], - ): - if not is_async_callable(start): - raise RuntimeError( - f"{start} is not an `async def` method. " - "WorkflowRunOperationHandler must be initialized with an " - "`async def` start method." - ) - self._start = start - if start.__doc__: - self.start.__func__.__doc__ = start.__doc__ - self._input_type, self._output_type = ( - get_workflow_run_start_method_input_and_output_type_annotations(start) - ) + ) -> WorkflowRunOperationHandler[InputT, OutputT]: + return _WorkflowRunOperationHandler(start_workflow) + @abstractmethod async def start( self, ctx: StartOperationContext, input: InputT ) -> nexusrpc.handler.StartOperationResultAsync: """ Start the operation, by starting a workflow and completing asynchronously. """ - - token = await self._start(ctx, input) - if not isinstance(token, WorkflowOperationToken): - if isinstance(token, WorkflowHandle): - raise RuntimeError( - f"Expected {token} to be a WorkflowOperationToken, but got a WorkflowHandle. " - f"You must use :py:meth:`temporalio.nexus.handler.start_workflow` " - "to start a workflow that will deliver the result of the Nexus operation, " - "not :py:meth:`temporalio.client.Client.start_workflow`." - ) - raise RuntimeError( - f"Expected {token} to be a WorkflowOperationToken, but got {type(token)}. " - "This is a bug in the Nexus SDK. Please report it to the Temporal team." - ) - return StartOperationResultAsync(token.encode()) + ... async def cancel(self, ctx: CancelOperationContext, token: str) -> None: """Cancel the operation, by cancelling the workflow.""" @@ -146,6 +150,35 @@ async def fetch_result( return await handle.result() +class _WorkflowRunOperationHandler(WorkflowRunOperationHandler[InputT, OutputT]): + async def start( + self, ctx: StartOperationContext, input: InputT + ) -> nexusrpc.handler.StartOperationResultAsync: + """ + Start the operation, by starting a workflow and completing asynchronously. + """ + + if self._start is None: + raise RuntimeError( + "Do not use _WorkflowRunOperationHandler directly. " + "Use WorkflowRunOperationHandler.from_start_workflow instead." + ) + + token = await self._start(ctx, input) + if not isinstance(token, WorkflowOperationToken): + if isinstance(token, WorkflowHandle): + raise RuntimeError( + f"Expected {token} to be a WorkflowOperationToken, but got a WorkflowHandle. " + f"You must use :py:meth:`temporalio.nexus.handler.start_workflow` " + "to start a workflow that will deliver the result of the Nexus operation, " + "not :py:meth:`temporalio.client.Client.start_workflow`." + ) + raise RuntimeError( + f"Expected {token} to be a WorkflowOperationToken, but got {type(token)}. " + ) + return StartOperationResultAsync(token.encode()) + + async def cancel_operation( token: str, **kwargs: Any, diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index 88aadc7c5..f0f2f3410 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass from typing import Any, Mapping, Optional @@ -93,3 +94,15 @@ async def cancel_operation( # Token can also be sent as "Nexus-Operation-Token" header params={"token": token}, ) + + +def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: + """ + Return a shallow dict of the dataclass's fields. + + dataclasses.as_dict goes too far (attempts to pickle values) + """ + return { + field.name: getattr(dataclass, field.name) + for field in dataclasses.fields(dataclass) + } diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 67a80c044..bdae01b5b 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -51,7 +51,7 @@ from temporalio.nexus.handler._token import WorkflowOperationToken from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import ServiceClient, create_nexus_endpoint +from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict HTTP_PORT = 7243 @@ -240,7 +240,9 @@ async def start( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) @nexusrpc.handler.operation_handler def sync_operation_with_non_async_def( @@ -299,7 +301,9 @@ async def start(ctx, input): id=str(uuid.uuid4()), ) - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) @nexusrpc.handler.operation_handler def workflow_run_op_link_test( @@ -320,7 +324,9 @@ async def start( id=str(uuid.uuid4()), ) - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) class OperationHandlerReturningUnwrappedResult( nexusrpc.handler.OperationHandler[Input, Output] @@ -925,18 +931,6 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A assert getattr(record, "operation", None) == operation_name -def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: - """ - Return a shallow dict of the dataclass's fields. - - dataclasses.as_dict goes too far (attempts to pickle values) - """ - return { - field.name: getattr(dataclass, field.name) - for field in dataclasses.fields(dataclass) - } - - class _InstantiationCase: executor: bool handler: Callable[[], Any] @@ -1131,7 +1125,9 @@ async def start( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) @nexusrpc.handler.operation_handler def operation_that_executes_a_workflow_before_starting_the_backing_workflow( @@ -1156,7 +1152,9 @@ async def start( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index a4575f351..12d04be01 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -49,7 +49,9 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: str ) -> WorkflowOperationToken[int]: ... - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index fa65ccb49..739290cee 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -10,6 +10,7 @@ import pytest import temporalio.nexus.handler +from temporalio.nexus.handler._operation_handlers import WorkflowRunOperationHandler from temporalio.nexus.handler._token import WorkflowOperationToken @@ -40,7 +41,9 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: ... - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) expected_operations = { "my_workflow_run_operation_handler": nexusrpc.Operation( @@ -63,7 +66,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: ... - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return WorkflowRunOperationHandler.from_start_workflow(start) expected_operations = NotCalled.expected_operations @@ -79,7 +82,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: ... - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return WorkflowRunOperationHandler.from_start_workflow(start) expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 09c1e4b95..61f35d1c4 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -38,6 +38,7 @@ start_workflow, temporal_operation_context, ) +from temporalio.nexus.handler._operation_handlers import WorkflowRunOperationHandler from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -225,7 +226,7 @@ async def start( id=input.response_type.operation_workflow_id, ) - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return WorkflowRunOperationHandler.from_start_workflow(start) # ----------------------------------------------------------------------------- @@ -964,7 +965,9 @@ async def start( id=str(uuid.uuid4()), ) - return temporalio.nexus.handler.WorkflowRunOperationHandler(start) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( + start + ) @workflow.defn diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py new file mode 100644 index 000000000..2448a278f --- /dev/null +++ b/tests/nexus/test_workflow_run_operation.py @@ -0,0 +1,74 @@ +import uuid +from dataclasses import dataclass + +from nexusrpc.handler import ( + OperationHandler, + StartOperationContext, + StartOperationResultAsync, + operation_handler, + service_handler, +) + +from temporalio import workflow +from temporalio.nexus.handler import WorkflowRunOperationHandler, start_workflow +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict + +HTTP_PORT = 7243 + + +@dataclass +class Input: + value: str + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: str) -> str: + return input + + +class MyOperation(WorkflowRunOperationHandler): + async def start( + self, ctx: StartOperationContext, input: Input + ) -> StartOperationResultAsync: + token = await start_workflow( + EchoWorkflow.run, + input.value, + id=str(uuid.uuid4()), + ) + return StartOperationResultAsync(token.encode()) + + +@service_handler +class MyService: + @operation_handler + def op(self) -> OperationHandler[Input, str]: + return MyOperation() + + +async def test_workflow_run_operation_via_subclassing(env: WorkflowEnvironment): + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyService.__name__, + ) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[MyService()], + ): + resp = await service_client.start_operation( + "op", + dataclass_as_dict(Input(value="test")), + ) + assert resp.status_code == 201 + + +def server_address(env: WorkflowEnvironment) -> str: + http_port = getattr(env, "_http_port", 7243) + return f"http://127.0.0.1:{http_port}" From 1db7ff039ef512c14807614cf0599a162b91f486 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Jun 2025 23:45:32 -0400 Subject: [PATCH 065/183] Respond to upstream: SyncOperation.from_callable --- .../nexus/handler/_operation_handlers.py | 2 +- ...ynamic_creation_of_user_handler_classes.py | 2 +- tests/nexus/test_handler.py | 47 ++++++++++--------- tests/nexus/test_handler_async_operation.py | 2 +- .../test_handler_interface_implementation.py | 9 ++-- .../test_handler_operation_definitions.py | 11 ++--- tests/nexus/test_workflow_caller.py | 15 +++--- 7 files changed, 44 insertions(+), 44 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 1673cbc45..d122c37e6 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -115,7 +115,7 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: async def fetch_info( self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str - ) -> nexusrpc.handler.OperationInfo: + ) -> nexusrpc.OperationInfo: raise NotImplementedError( "Temporal Nexus operation handlers do not support fetching operation info." ) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index dce89c534..8734a0fe8 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -40,7 +40,7 @@ async def _increment_op( ) -> int: return input + 1 - return SyncOperationHandler(_increment_op) + return SyncOperationHandler.from_callable(_increment_op) op_handler_factories = { # TODO(nexus-prerelease): check that name=name should be required here. Should the op factory diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index bdae01b5b..98f72cae6 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -30,14 +30,13 @@ import nexusrpc.handler.syncio import pytest from google.protobuf import json_format +from nexusrpc import OperationInfo from nexusrpc.handler import ( CancelOperationContext, - StartOperationContext, -) -from nexusrpc.handler._common import ( FetchOperationInfoContext, FetchOperationResultContext, - OperationInfo, + StartOperationContext, + SyncOperationHandler, ) import temporalio.api.failure.v1 @@ -47,8 +46,12 @@ from temporalio.common import WorkflowIDReusePolicy from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import logger, start_workflow, temporal_operation_context -from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.nexus.handler import ( + WorkflowOperationToken, + logger, + start_workflow, + temporal_operation_context, +) from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict @@ -141,7 +144,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def hang(self) -> nexusrpc.handler.OperationHandler[Input, Output]: @@ -149,7 +152,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: await asyncio.Future() return Output(value="won't reach here") - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def non_retryable_application_error( @@ -164,7 +167,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: non_retryable=True, ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def retryable_application_error( @@ -178,7 +181,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: non_retryable=False, ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def handler_error_internal( @@ -192,7 +195,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: cause=RuntimeError("cause message"), ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def operation_error_failed( @@ -204,7 +207,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: state=nexusrpc.handler.OperationErrorState.FAILED, ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def check_operation_timeout_header( @@ -216,7 +219,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def log(self) -> nexusrpc.handler.OperationHandler[Input, Output]: @@ -224,7 +227,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: logger.info("Logging from start method", extra={"input_value": input.value}) return Output(value=f"logged: {input.value}") - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def workflow_run_operation( @@ -253,7 +256,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) if False: # TODO(nexus-prerelease): fix tests of callable instances @@ -270,7 +273,7 @@ def __call__( value=f"from start method on {self.__class__.__name__}: {input.value}" ) - return nexusrpc.handler.syncio.SyncOperationHandler(start()) + return SyncOperationHandler.from_callable(start()) _sync_operation_with_non_async_callable_instance = ( nexusrpc.handler.operation_handler( @@ -288,7 +291,7 @@ async def start(ctx, input): value=f"from start method on {self.__class__.__name__} without type annotations: {input}" ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def workflow_run_operation_without_type_annotations( @@ -368,7 +371,7 @@ def idempotency_check( async def start(ctx: StartOperationContext, input: None) -> Output: return Output(value=f"request_id: {ctx.request_id}") - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def non_serializable_output( @@ -379,7 +382,7 @@ async def start( ) -> NonSerializableOutput: return NonSerializableOutput() - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @dataclass @@ -955,7 +958,7 @@ def start(ctx: StartOperationContext, input: Input) -> Output: ) # TODO(nexus-prerelease) why is this test passing? start is not `async def` - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.service_handler(service=EchoService) @@ -967,7 +970,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.service_handler(service=EchoService) @@ -1015,7 +1018,7 @@ class SyncCancel(_InstantiationCase): handler = SyncCancelHandler executor = False exception = RuntimeError - match = "cancel must be an `async def`" + match = "cancel method must be an `async def`" @pytest.mark.parametrize( diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py index bfe850cbb..19d4f0ae1 100644 --- a/tests/nexus/test_handler_async_operation.py +++ b/tests/nexus/test_handler_async_operation.py @@ -15,12 +15,12 @@ import nexusrpc import nexusrpc.handler import pytest +from nexusrpc import OperationInfo from nexusrpc.handler import ( CancelOperationContext, FetchOperationInfoContext, FetchOperationResultContext, OperationHandler, - OperationInfo, StartOperationContext, StartOperationResultAsync, ) diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 12d04be01..331ab25cb 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -5,10 +5,9 @@ import pytest from nexusrpc.handler import OperationHandler, SyncOperationHandler -import temporalio.api.failure.v1 -import temporalio.nexus.handler from temporalio.nexus.handler import ( WorkflowOperationToken, + WorkflowRunOperationHandler, ) HTTP_PORT = 7243 @@ -32,7 +31,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: None ) -> None: ... - return SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) error_message = None @@ -49,9 +48,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: str ) -> WorkflowOperationToken[int]: ... - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return WorkflowRunOperationHandler.from_start_workflow(start) error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 739290cee..4c1245a66 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -9,9 +9,10 @@ import nexusrpc.handler import pytest -import temporalio.nexus.handler -from temporalio.nexus.handler._operation_handlers import WorkflowRunOperationHandler -from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.nexus.handler import ( + WorkflowOperationToken, + WorkflowRunOperationHandler, +) @dataclass @@ -41,9 +42,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: ... - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return WorkflowRunOperationHandler.from_start_workflow(start) expected_operations = { "my_workflow_run_operation_handler": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 61f35d1c4..efed3ea60 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -11,6 +11,7 @@ CancelOperationContext, FetchOperationInfoContext, StartOperationContext, + SyncOperationHandler, ) import temporalio.api @@ -35,10 +36,10 @@ from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus.handler import ( WorkflowOperationToken, + WorkflowRunOperationHandler, start_workflow, temporal_operation_context, ) -from temporalio.nexus.handler._operation_handlers import WorkflowRunOperationHandler from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -173,7 +174,7 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: async def fetch_info( self, ctx: FetchOperationInfoContext, token: str - ) -> nexusrpc.handler.OperationInfo: + ) -> nexusrpc.OperationInfo: raise NotImplementedError async def fetch_result( @@ -204,7 +205,7 @@ async def start(ctx: StartOperationContext, input: OpInput) -> OpOutput: ) return OpOutput(value="sync response") - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def async_operation( @@ -763,7 +764,7 @@ async def start( ) -> ServiceClassNameOutput: return ServiceClassNameOutput(self.__class__.__name__) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.service_handler(service=ServiceInterfaceWithoutNameOverride) @@ -777,7 +778,7 @@ async def start( ) -> ServiceClassNameOutput: return ServiceClassNameOutput(self.__class__.__name__) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.service_handler(service=ServiceInterfaceWithNameOverride) @@ -791,7 +792,7 @@ async def start( ) -> ServiceClassNameOutput: return ServiceClassNameOutput(self.__class__.__name__) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) @nexusrpc.handler.service_handler(name="service-impl-🌈") @@ -805,7 +806,7 @@ async def start( ) -> ServiceClassNameOutput: return ServiceClassNameOutput(self.__class__.__name__) - return nexusrpc.handler.SyncOperationHandler(start) + return SyncOperationHandler.from_callable(start) class NameOverride(IntEnum): From 07c6d39f1973791407ed42505c89ee0d0d7ae214 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 08:57:58 -0400 Subject: [PATCH 066/183] -> WorkflowRunOperation.from_callable() --- .../nexus/handler/_operation_handlers.py | 2 +- tests/nexus/test_handler.py | 20 +++++-------------- .../test_handler_interface_implementation.py | 2 +- .../test_handler_operation_definitions.py | 6 +++--- tests/nexus/test_workflow_caller.py | 6 ++---- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index d122c37e6..e71ad091a 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -91,7 +91,7 @@ def __init__( self._start = self._input_type = self._output_type = None @classmethod - def from_start_workflow( + def from_callable( cls, start_workflow: Callable[ [StartOperationContext, InputT], diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 98f72cae6..756e8636d 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -243,9 +243,7 @@ async def start( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def sync_operation_with_non_async_def( @@ -304,9 +302,7 @@ async def start(ctx, input): id=str(uuid.uuid4()), ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def workflow_run_op_link_test( @@ -327,9 +323,7 @@ async def start( id=str(uuid.uuid4()), ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) class OperationHandlerReturningUnwrappedResult( nexusrpc.handler.OperationHandler[Input, Output] @@ -1128,9 +1122,7 @@ async def start( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) @nexusrpc.handler.operation_handler def operation_that_executes_a_workflow_before_starting_the_backing_workflow( @@ -1155,9 +1147,7 @@ async def start( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 331ab25cb..0a0399e16 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -48,7 +48,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: str ) -> WorkflowOperationToken[int]: ... - return WorkflowRunOperationHandler.from_start_workflow(start) + return WorkflowRunOperationHandler.from_callable(start) error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 4c1245a66..da3781846 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -42,7 +42,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: ... - return WorkflowRunOperationHandler.from_start_workflow(start) + return WorkflowRunOperationHandler.from_callable(start) expected_operations = { "my_workflow_run_operation_handler": nexusrpc.Operation( @@ -65,7 +65,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: ... - return WorkflowRunOperationHandler.from_start_workflow(start) + return WorkflowRunOperationHandler.from_callable(start) expected_operations = NotCalled.expected_operations @@ -81,7 +81,7 @@ async def start( ctx: nexusrpc.handler.StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: ... - return WorkflowRunOperationHandler.from_start_workflow(start) + return WorkflowRunOperationHandler.from_callable(start) expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index efed3ea60..aeb2dcb92 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -227,7 +227,7 @@ async def start( id=input.response_type.operation_workflow_id, ) - return WorkflowRunOperationHandler.from_start_workflow(start) + return WorkflowRunOperationHandler.from_callable(start) # ----------------------------------------------------------------------------- @@ -966,9 +966,7 @@ async def start( id=str(uuid.uuid4()), ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_start_workflow( - start - ) + return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) @workflow.defn From 2616755c0e9cedfe2c9feec139590e2024e65411 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 08:59:42 -0400 Subject: [PATCH 067/183] TODO --- temporalio/nexus/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index e69de29bb..1b868c610 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -0,0 +1,3 @@ +# TODO(nexus-prerelease) WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' +# 2025-06-25T12:58:05.749589Z WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' +# 2025-06-25T12:58:05.763052Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } From 27d7e4157c6a0c91e269595fae28c9d9dd89fd06 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 11:06:03 -0400 Subject: [PATCH 068/183] Parameterize workflow_run_operation tests --- tests/nexus/test_workflow_run_operation.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 2448a278f..bde35db86 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -1,6 +1,8 @@ import uuid from dataclasses import dataclass +from typing import Any, Type +import pytest from nexusrpc.handler import ( OperationHandler, StartOperationContext, @@ -43,24 +45,33 @@ async def start( @service_handler -class MyService: +class SubclassingHappyPath: @operation_handler def op(self) -> OperationHandler[Input, str]: return MyOperation() -async def test_workflow_run_operation_via_subclassing(env: WorkflowEnvironment): +@pytest.mark.parametrize( + "service_handler_cls", + [ + SubclassingHappyPath, + ], +) +async def test_workflow_run_operation( + env: WorkflowEnvironment, + service_handler_cls: Type[Any], +): task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( server_address=server_address(env), endpoint=endpoint, - service=MyService.__name__, + service=service_handler_cls.__name__, ) async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[MyService()], + nexus_service_handlers=[service_handler_cls()], ): resp = await service_client.start_operation( "op", From c0cf503f5f1dce984225bbad8badd5eff9a2e697 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 11:18:26 -0400 Subject: [PATCH 069/183] Failing test case --- tests/nexus/test_workflow_run_operation.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index bde35db86..b75bd2407 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -3,6 +3,7 @@ from typing import Any, Type import pytest +from nexusrpc import Operation, service from nexusrpc.handler import ( OperationHandler, StartOperationContext, @@ -51,10 +52,31 @@ def op(self) -> OperationHandler[Input, str]: return MyOperation() +@service +class Service: + op: Operation[Input, str] + + +@service_handler +class SubclassingNoInputOutputTypeAnnotationsWithoutServiceDefinition: + @operation_handler + def op(self) -> OperationHandler: + return MyOperation() + + +@service_handler(service=Service) +class SubclassingNoInputOutputTypeAnnotationsWithServiceDefinition: + @operation_handler + def op(self) -> OperationHandler[Input, str]: + return MyOperation() + + @pytest.mark.parametrize( "service_handler_cls", [ SubclassingHappyPath, + SubclassingNoInputOutputTypeAnnotationsWithoutServiceDefinition, + SubclassingNoInputOutputTypeAnnotationsWithServiceDefinition, ], ) async def test_workflow_run_operation( From efb9df51df76393ab730700f23a8077b9736de72 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 17:33:26 -0400 Subject: [PATCH 070/183] Test: clean up imports --- tests/nexus/test_handler.py | 139 ++++++++++++++-------------- tests/nexus/test_workflow_caller.py | 66 ++++++------- 2 files changed, 106 insertions(+), 99 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 756e8636d..9fff3f947 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -27,7 +27,6 @@ import httpx import nexusrpc -import nexusrpc.handler.syncio import pytest from google.protobuf import json_format from nexusrpc import OperationInfo @@ -35,8 +34,16 @@ CancelOperationContext, FetchOperationInfoContext, FetchOperationResultContext, + HandlerError, + HandlerErrorType, + OperationError, + OperationErrorState, + OperationHandler, StartOperationContext, SyncOperationHandler, + operation_handler, + service_handler, + sync_operation_handler, ) import temporalio.api.failure.v1 @@ -135,8 +142,8 @@ async def run(self, input: Input) -> Output: # The service_handler decorator is applied by the test class MyServiceHandler: - @nexusrpc.handler.operation_handler - def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + @operation_handler + def echo(self) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) @@ -146,18 +153,18 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler - def hang(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + @operation_handler + def hang(self) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: await asyncio.Future() return Output(value="won't reach here") return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def non_retryable_application_error( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: raise ApplicationError( "non-retryable application error", @@ -169,10 +176,10 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def retryable_application_error( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: raise ApplicationError( "retryable application error", @@ -183,36 +190,36 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def handler_error_internal( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: - raise nexusrpc.handler.HandlerError( + raise HandlerError( message="deliberate internal handler error", - type=nexusrpc.handler.HandlerErrorType.INTERNAL, + type=HandlerErrorType.INTERNAL, retryable=False, cause=RuntimeError("cause message"), ) return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def operation_error_failed( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: - raise nexusrpc.handler.OperationError( + raise OperationError( message="deliberate operation error", - state=nexusrpc.handler.OperationErrorState.FAILED, + state=OperationErrorState.FAILED, ) return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def check_operation_timeout_header( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: assert "operation-timeout" in ctx.headers return Output( @@ -221,18 +228,18 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler - def log(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + @operation_handler + def log(self) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: logger.info("Logging from start method", extra={"input_value": input.value}) return Output(value=f"logged: {input.value}") return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def workflow_run_operation( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: @@ -245,10 +252,10 @@ async def start( return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def sync_operation_with_non_async_def( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: return Output( value=f"from start method on {self.__class__.__name__}: {input.value}" @@ -260,7 +267,7 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: # TODO(nexus-prerelease): fix tests of callable instances def sync_operation_with_non_async_callable_instance( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: class start: def __call__( self, @@ -273,15 +280,13 @@ def __call__( return SyncOperationHandler.from_callable(start()) - _sync_operation_with_non_async_callable_instance = ( - nexusrpc.handler.operation_handler( - name="sync_operation_with_non_async_callable_instance", - )( - sync_operation_with_non_async_callable_instance, - ) + _sync_operation_with_non_async_callable_instance = operation_handler( + name="sync_operation_with_non_async_callable_instance", + )( + sync_operation_with_non_async_callable_instance, ) - @nexusrpc.handler.operation_handler + @operation_handler def sync_operation_without_type_annotations(self): async def start(ctx, input): # The input type from the op definition in the service definition is used to deserialize the input. @@ -291,10 +296,10 @@ async def start(ctx, input): return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def workflow_run_operation_without_type_annotations( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start(ctx, input): return await start_workflow( WorkflowWithoutTypeAnnotations.run, @@ -304,10 +309,10 @@ async def start(ctx, input): return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def workflow_run_op_link_test( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: @@ -325,9 +330,7 @@ async def start( return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) - class OperationHandlerReturningUnwrappedResult( - nexusrpc.handler.OperationHandler[Input, Output] - ): + class OperationHandlerReturningUnwrappedResult(OperationHandler[Input, Output]): async def start( self, ctx: StartOperationContext, @@ -352,25 +355,25 @@ async def fetch_result( async def cancel(self, ctx: CancelOperationContext, token: str) -> None: raise NotImplementedError - @nexusrpc.handler.operation_handler + @operation_handler def operation_returning_unwrapped_result_at_runtime_error( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: return MyServiceHandler.OperationHandlerReturningUnwrappedResult() - @nexusrpc.handler.operation_handler + @operation_handler def idempotency_check( self, - ) -> nexusrpc.handler.OperationHandler[None, Output]: + ) -> OperationHandler[None, Output]: async def start(ctx: StartOperationContext, input: None) -> Output: return Output(value=f"request_id: {ctx.request_id}") return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def non_serializable_output( self, - ) -> nexusrpc.handler.OperationHandler[Input, NonSerializableOutput]: + ) -> OperationHandler[Input, NonSerializableOutput]: async def start( ctx: StartOperationContext, input: Input ) -> NonSerializableOutput: @@ -648,7 +651,7 @@ class OperationHandlerReturningUnwrappedResultError(_FailureTestCase): retryable_header=False, failure_message=( "Operation start method must return either " - "nexusrpc.handler.StartOperationResultSync or nexusrpc.handler.StartOperationResultAsync." + "StartOperationResultSync or StartOperationResultAsync." ), ) @@ -739,7 +742,7 @@ class HandlerErrorInternal(_FailureTestCase): ) -class OperationError(_FailureTestCase): +class OperationErrorFailed(_FailureTestCase): operation = "operation_error_failed" expected = UnsuccessfulResponse( status_code=424, @@ -828,7 +831,7 @@ async def test_start_operation_protocol_level_failures( [ NonRetryableApplicationError, RetryableApplicationError, - OperationError, + OperationErrorFailed, ], ) async def test_start_operation_operation_failures( @@ -858,9 +861,9 @@ async def _test_start_operation( with pytest.WarningsRecorder() as warnings: decorator = ( - nexusrpc.handler.service_handler(service=MyService) + service_handler(service=MyService) if with_service_definition - else nexusrpc.handler.service_handler + else service_handler ) service_handler = decorator(MyServiceHandler)() @@ -940,10 +943,10 @@ class EchoService: echo: nexusrpc.Operation[Input, Output] -@nexusrpc.handler.service_handler(service=EchoService) +@service_handler(service=EchoService) class SyncStartHandler: - @nexusrpc.handler.operation_handler - def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + @operation_handler + def echo(self) -> OperationHandler[Input, Output]: def start(ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) @@ -955,10 +958,10 @@ def start(ctx: StartOperationContext, input: Input) -> Output: return SyncOperationHandler.from_callable(start) -@nexusrpc.handler.service_handler(service=EchoService) +@service_handler(service=EchoService) class DefaultCancelHandler: - @nexusrpc.handler.operation_handler - def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + @operation_handler + def echo(self) -> OperationHandler[Input, Output]: async def start(ctx: StartOperationContext, input: Input) -> Output: return Output( value=f"from start method on {self.__class__.__name__}: {input.value}" @@ -967,9 +970,9 @@ async def start(ctx: StartOperationContext, input: Input) -> Output: return SyncOperationHandler.from_callable(start) -@nexusrpc.handler.service_handler(service=EchoService) +@service_handler(service=EchoService) class SyncCancelHandler: - class SyncCancel(nexusrpc.handler.OperationHandler[Input, Output]): + class SyncCancel(OperationHandler[Input, Output]): async def start( self, ctx: StartOperationContext, @@ -990,8 +993,8 @@ def fetch_info(self, ctx: FetchOperationInfoContext) -> OperationInfo: def fetch_result(self, ctx: FetchOperationResultContext) -> Output: raise NotImplementedError - @nexusrpc.handler.operation_handler - def echo(self) -> nexusrpc.handler.OperationHandler[Input, Output]: + @operation_handler + def echo(self) -> OperationHandler[Input, Output]: return SyncCancelHandler.SyncCancel() @@ -1053,7 +1056,7 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): service=MyService.__name__, ) - decorator = nexusrpc.handler.service_handler(service=MyService) + decorator = service_handler(service=MyService) service_handler = decorator(MyServiceHandler)() async with Worker( @@ -1082,7 +1085,7 @@ async def test_request_id_is_received_by_sync_operation_handler( service=MyService.__name__, ) - decorator = nexusrpc.handler.service_handler(service=MyService) + decorator = service_handler(service=MyService) service_handler = decorator(MyServiceHandler)() async with Worker( @@ -1106,12 +1109,12 @@ async def run(self, input: Input) -> Output: return Output(value=input.value) -@nexusrpc.handler.service_handler +@service_handler class ServiceHandlerForRequestIdTest: - @nexusrpc.handler.operation_handler + @operation_handler def operation_backed_by_a_workflow( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: @@ -1124,10 +1127,10 @@ async def start( return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: + ) -> OperationHandler[Input, Output]: async def start( ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index aeb2dcb92..32dfd2372 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -10,8 +10,14 @@ from nexusrpc.handler import ( CancelOperationContext, FetchOperationInfoContext, + FetchOperationResultContext, + OperationHandler, StartOperationContext, + StartOperationResultAsync, + StartOperationResultSync, SyncOperationHandler, + operation_handler, + service_handler, ) import temporalio.api @@ -141,12 +147,12 @@ async def run( # TODO: make types pass pyright strict mode -class SyncOrAsyncOperation(nexusrpc.handler.OperationHandler[OpInput, OpOutput]): +class SyncOrAsyncOperation(OperationHandler[OpInput, OpOutput]): async def start( self, ctx: StartOperationContext, input: OpInput ) -> Union[ - nexusrpc.handler.StartOperationResultSync[OpOutput], - nexusrpc.handler.StartOperationResultAsync, + StartOperationResultSync[OpOutput], + StartOperationResultAsync, ]: if input.response_type.exception_in_operation_start: # TODO(dan): don't think RPCError should be used here @@ -156,16 +162,14 @@ async def start( b"", ) if isinstance(input.response_type, SyncResponse): - return nexusrpc.handler.StartOperationResultSync( - value=OpOutput(value="sync response") - ) + return StartOperationResultSync(value=OpOutput(value="sync response")) elif isinstance(input.response_type, AsyncResponse): token = await start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, ) - return nexusrpc.handler.StartOperationResultAsync(token.encode()) + return StartOperationResultAsync(token.encode()) else: raise TypeError @@ -178,23 +182,23 @@ async def fetch_info( raise NotImplementedError async def fetch_result( - self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str + self, ctx: FetchOperationResultContext, token: str ) -> OpOutput: raise NotImplementedError -@nexusrpc.handler.service_handler(service=ServiceInterface) +@service_handler(service=ServiceInterface) class ServiceImpl: - @nexusrpc.handler.operation_handler + @operation_handler def sync_or_async_operation( self, - ) -> nexusrpc.handler.OperationHandler[OpInput, OpOutput]: + ) -> OperationHandler[OpInput, OpOutput]: return SyncOrAsyncOperation() - @nexusrpc.handler.operation_handler + @operation_handler def sync_operation( self, - ) -> nexusrpc.handler.OperationHandler[OpInput, OpOutput]: + ) -> OperationHandler[OpInput, OpOutput]: async def start(ctx: StartOperationContext, input: OpInput) -> OpOutput: assert isinstance(input.response_type, SyncResponse) if input.response_type.exception_in_operation_start: @@ -207,10 +211,10 @@ async def start(ctx: StartOperationContext, input: OpInput) -> OpOutput: return SyncOperationHandler.from_callable(start) - @nexusrpc.handler.operation_handler + @operation_handler def async_operation( self, - ) -> nexusrpc.handler.OperationHandler[OpInput, HandlerWfOutput]: + ) -> OperationHandler[OpInput, HandlerWfOutput]: async def start( ctx: StartOperationContext, input: OpInput ) -> WorkflowOperationToken[HandlerWfOutput]: @@ -309,7 +313,7 @@ def _get_operation( op_input: OpInput, ) -> Union[ nexusrpc.Operation[OpInput, OpOutput], - Callable[[Any], nexusrpc.handler.OperationHandler[OpInput, OpOutput]], + Callable[[Any], OperationHandler[OpInput, OpOutput]], ]: return { ( @@ -753,12 +757,12 @@ class ServiceInterfaceWithNameOverride: op: nexusrpc.Operation[None, ServiceClassNameOutput] -@nexusrpc.handler.service_handler +@service_handler class ServiceImplInterfaceWithNeitherInterfaceNorNameOverride: - @nexusrpc.handler.operation_handler + @operation_handler def op( self, - ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + ) -> OperationHandler[None, ServiceClassNameOutput]: async def start( ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: @@ -767,12 +771,12 @@ async def start( return SyncOperationHandler.from_callable(start) -@nexusrpc.handler.service_handler(service=ServiceInterfaceWithoutNameOverride) +@service_handler(service=ServiceInterfaceWithoutNameOverride) class ServiceImplInterfaceWithoutNameOverride: - @nexusrpc.handler.operation_handler + @operation_handler def op( self, - ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + ) -> OperationHandler[None, ServiceClassNameOutput]: async def start( ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: @@ -781,12 +785,12 @@ async def start( return SyncOperationHandler.from_callable(start) -@nexusrpc.handler.service_handler(service=ServiceInterfaceWithNameOverride) +@service_handler(service=ServiceInterfaceWithNameOverride) class ServiceImplInterfaceWithNameOverride: - @nexusrpc.handler.operation_handler + @operation_handler def op( self, - ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + ) -> OperationHandler[None, ServiceClassNameOutput]: async def start( ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: @@ -795,12 +799,12 @@ async def start( return SyncOperationHandler.from_callable(start) -@nexusrpc.handler.service_handler(name="service-impl-🌈") +@service_handler(name="service-impl-🌈") class ServiceImplWithNameOverride: - @nexusrpc.handler.operation_handler + @operation_handler def op( self, - ) -> nexusrpc.handler.OperationHandler[None, ServiceClassNameOutput]: + ) -> OperationHandler[None, ServiceClassNameOutput]: async def start( ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: @@ -941,12 +945,12 @@ async def run(self, input: str) -> str: return input -@nexusrpc.handler.service_handler +@service_handler class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: - @nexusrpc.handler.operation_handler + @operation_handler def my_workflow_run_operation( self, - ) -> nexusrpc.handler.OperationHandler[None, str]: + ) -> OperationHandler[None, str]: async def start( ctx: StartOperationContext, input: None ) -> WorkflowOperationToken[str]: From 400260d8e3a3bb0e44e91a2f41e9be4ac9d7beec Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 18:28:17 -0400 Subject: [PATCH 071/183] Respond to upstream: sync_operation_handler --- ...ynamic_creation_of_user_handler_classes.py | 33 +- tests/nexus/test_handler.py | 282 ++++++++---------- .../test_handler_interface_implementation.py | 11 +- tests/nexus/test_workflow_caller.py | 89 ++---- 4 files changed, 168 insertions(+), 247 deletions(-) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 8734a0fe8..39d0b8f72 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -1,10 +1,10 @@ import uuid -from typing import Any import httpx import nexusrpc.handler import pytest -from nexusrpc.handler import SyncOperationHandler +from nexusrpc.handler import sync_operation_handler +from nexusrpc.handler._util import get_operation_factory from temporalio.client import Client from temporalio.worker import Worker @@ -33,22 +33,19 @@ def make_incrementer_user_service_definition_and_service_handler_classes( # # service handler # - def factory(self: Any) -> nexusrpc.handler.OperationHandler[int, int]: - async def _increment_op( - ctx: nexusrpc.handler.StartOperationContext, - input: int, - ) -> int: - return input + 1 - - return SyncOperationHandler.from_callable(_increment_op) - - op_handler_factories = { - # TODO(nexus-prerelease): check that name=name should be required here. Should the op factory - # name not default to the name of the method attribute (i.e. key), as opposed to - # the name of the method object (i.e. value.__name__)? - name: nexusrpc.handler.operation_handler(name=name)(factory) - for name in op_names - } + @sync_operation_handler + async def _increment_op( + self, + ctx: nexusrpc.handler.StartOperationContext, + input: int, + ) -> int: + return input + 1 + + op_handler_factories = {} + for name in op_names: + op_handler_factory, _ = get_operation_factory(_increment_op) + assert op_handler_factory + op_handler_factories[name] = op_handler_factory handler_cls = nexusrpc.handler.service_handler(service=service_cls)( type("ServiceImpl", (), op_handler_factories) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 9fff3f947..177d0422e 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -40,7 +40,6 @@ OperationErrorState, OperationHandler, StartOperationContext, - SyncOperationHandler, operation_handler, service_handler, sync_operation_handler, @@ -142,126 +141,94 @@ async def run(self, input: Input) -> Output: # The service_handler decorator is applied by the test class MyServiceHandler: - @operation_handler - def echo(self) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - assert ctx.headers["test-header-key"] == "test-header-value" - ctx.outbound_links.extend(ctx.inbound_links) - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) - - return SyncOperationHandler.from_callable(start) - - @operation_handler - def hang(self) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - await asyncio.Future() - return Output(value="won't reach here") - - return SyncOperationHandler.from_callable(start) - - @operation_handler - def non_retryable_application_error( - self, - ) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - raise ApplicationError( - "non-retryable application error", - "details arg", - # TODO(nexus-prerelease): what values of `type` should be tested? - type="TestFailureType", - non_retryable=True, - ) - - return SyncOperationHandler.from_callable(start) - - @operation_handler - def retryable_application_error( - self, - ) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - raise ApplicationError( - "retryable application error", - "details arg", - type="TestFailureType", - non_retryable=False, - ) - - return SyncOperationHandler.from_callable(start) - - @operation_handler - def handler_error_internal( - self, - ) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - raise HandlerError( - message="deliberate internal handler error", - type=HandlerErrorType.INTERNAL, - retryable=False, - cause=RuntimeError("cause message"), - ) - - return SyncOperationHandler.from_callable(start) - - @operation_handler - def operation_error_failed( - self, - ) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - raise OperationError( - message="deliberate operation error", - state=OperationErrorState.FAILED, - ) - - return SyncOperationHandler.from_callable(start) - - @operation_handler - def check_operation_timeout_header( - self, - ) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - assert "operation-timeout" in ctx.headers - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) + @sync_operation_handler + async def echo(self, ctx: StartOperationContext, input: Input) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def hang(self, ctx: StartOperationContext, input: Input) -> Output: + await asyncio.Future() + return Output(value="won't reach here") + + @sync_operation_handler + async def non_retryable_application_error( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise ApplicationError( + "non-retryable application error", + "details arg", + # TODO(nexus-prerelease): what values of `type` should be tested? + type="TestFailureType", + non_retryable=True, + ) - @operation_handler - def log(self) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - logger.info("Logging from start method", extra={"input_value": input.value}) - return Output(value=f"logged: {input.value}") + @sync_operation_handler + async def retryable_application_error( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise ApplicationError( + "retryable application error", + "details arg", + type="TestFailureType", + non_retryable=False, + ) - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def handler_error_internal( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise HandlerError( + message="deliberate internal handler error", + type=HandlerErrorType.INTERNAL, + retryable=False, + cause=RuntimeError("cause message"), + ) - @operation_handler - def workflow_run_operation( - self, - ) -> OperationHandler[Input, Output]: - async def start( - ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - return await start_workflow( - MyWorkflow.run, - input, - id=str(uuid.uuid4()), - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, - ) + @sync_operation_handler + async def operation_error_failed( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise OperationError( + message="deliberate operation error", + state=OperationErrorState.FAILED, + ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) + @sync_operation_handler + async def check_operation_timeout_header( + self, ctx: StartOperationContext, input: Input + ) -> Output: + assert "operation-timeout" in ctx.headers + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) - @operation_handler - def sync_operation_with_non_async_def( - self, - ) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) + @sync_operation_handler + async def log(self, ctx: StartOperationContext, input: Input) -> Output: + logger.info("Logging from start method", extra={"input_value": input.value}) + return Output(value=f"logged: {input.value}") + + @sync_operation_handler + async def workflow_run_operation( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: + return await start_workflow( + MyWorkflow.run, + input, + id=str(uuid.uuid4()), + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def sync_operation_with_non_async_def( + self, ctx: StartOperationContext, input: Input + ) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) if False: # TODO(nexus-prerelease): fix tests of callable instances @@ -278,7 +245,7 @@ def __call__( value=f"from start method on {self.__class__.__name__}: {input.value}" ) - return SyncOperationHandler.from_callable(start()) + return sync_operation_handler(start()) _sync_operation_with_non_async_callable_instance = operation_handler( name="sync_operation_with_non_async_callable_instance", @@ -286,15 +253,14 @@ def __call__( sync_operation_with_non_async_callable_instance, ) - @operation_handler - def sync_operation_without_type_annotations(self): - async def start(ctx, input): - # The input type from the op definition in the service definition is used to deserialize the input. - return Output( - value=f"from start method on {self.__class__.__name__} without type annotations: {input}" - ) - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def sync_operation_without_type_annotations( + self, ctx: StartOperationContext, input: Input + ) -> Output: + # The input type from the op definition in the service definition is used to deserialize the input. + return Output( + value=f"from start method on {self.__class__.__name__} without type annotations: {input}" + ) @operation_handler def workflow_run_operation_without_type_annotations( @@ -361,25 +327,17 @@ def operation_returning_unwrapped_result_at_runtime_error( ) -> OperationHandler[Input, Output]: return MyServiceHandler.OperationHandlerReturningUnwrappedResult() - @operation_handler - def idempotency_check( - self, - ) -> OperationHandler[None, Output]: - async def start(ctx: StartOperationContext, input: None) -> Output: - return Output(value=f"request_id: {ctx.request_id}") + @sync_operation_handler + async def idempotency_check( + self, ctx: StartOperationContext, input: None + ) -> Output: + return Output(value=f"request_id: {ctx.request_id}") - return SyncOperationHandler.from_callable(start) - - @operation_handler - def non_serializable_output( - self, - ) -> OperationHandler[Input, NonSerializableOutput]: - async def start( - ctx: StartOperationContext, input: Input - ) -> NonSerializableOutput: - return NonSerializableOutput() - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def non_serializable_output( + self, ctx: StartOperationContext, input: Input + ) -> NonSerializableOutput: + return NonSerializableOutput() @dataclass @@ -945,29 +903,23 @@ class EchoService: @service_handler(service=EchoService) class SyncStartHandler: - @operation_handler - def echo(self) -> OperationHandler[Input, Output]: - def start(ctx: StartOperationContext, input: Input) -> Output: - assert ctx.headers["test-header-key"] == "test-header-value" - ctx.outbound_links.extend(ctx.inbound_links) - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) - - # TODO(nexus-prerelease) why is this test passing? start is not `async def` - return SyncOperationHandler.from_callable(start) + # TODO(nexus-prerelease): why is this test passing? start is not `async def` + @sync_operation_handler + def echo(self, ctx: StartOperationContext, input: Input) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) @service_handler(service=EchoService) class DefaultCancelHandler: - @operation_handler - def echo(self) -> OperationHandler[Input, Output]: - async def start(ctx: StartOperationContext, input: Input) -> Output: - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def echo(self, ctx: StartOperationContext, input: Input) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) @service_handler(service=EchoService) @@ -1057,12 +1009,12 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): ) decorator = service_handler(service=MyService) - service_handler = decorator(MyServiceHandler)() + user_service_handler = decorator(MyServiceHandler)() async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[service_handler], + nexus_service_handlers=[user_service_handler], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): cancel_response = await service_client.cancel_operation( @@ -1086,12 +1038,12 @@ async def test_request_id_is_received_by_sync_operation_handler( ) decorator = service_handler(service=MyService) - service_handler = decorator(MyServiceHandler)() + user_service_handler = decorator(MyServiceHandler)() async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[service_handler], + nexus_service_handlers=[user_service_handler], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): request_id = str(uuid.uuid4()) diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 0a0399e16..92802dc88 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -3,7 +3,7 @@ import nexusrpc import nexusrpc.handler import pytest -from nexusrpc.handler import OperationHandler, SyncOperationHandler +from nexusrpc.handler import StartOperationContext, sync_operation_handler from temporalio.nexus.handler import ( WorkflowOperationToken, @@ -25,13 +25,8 @@ class Interface: op: nexusrpc.Operation[None, None] class Impl: - @nexusrpc.handler.operation_handler - def op(self) -> OperationHandler[None, None]: - async def start( - ctx: nexusrpc.handler.StartOperationContext, input: None - ) -> None: ... - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def op(self, ctx: StartOperationContext, input: None) -> None: ... error_message = None diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 32dfd2372..11bd9b3b4 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -15,9 +15,9 @@ StartOperationContext, StartOperationResultAsync, StartOperationResultSync, - SyncOperationHandler, operation_handler, service_handler, + sync_operation_handler, ) import temporalio.api @@ -195,21 +195,18 @@ def sync_or_async_operation( ) -> OperationHandler[OpInput, OpOutput]: return SyncOrAsyncOperation() - @operation_handler - def sync_operation( - self, - ) -> OperationHandler[OpInput, OpOutput]: - async def start(ctx: StartOperationContext, input: OpInput) -> OpOutput: - assert isinstance(input.response_type, SyncResponse) - if input.response_type.exception_in_operation_start: - raise RPCError( - "RPCError INVALID_ARGUMENT in Nexus operation", - RPCStatusCode.INVALID_ARGUMENT, - b"", - ) - return OpOutput(value="sync response") - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def sync_operation( + self, ctx: StartOperationContext, input: OpInput + ) -> OpOutput: + assert isinstance(input.response_type, SyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + return OpOutput(value="sync response") @operation_handler def async_operation( @@ -759,58 +756,38 @@ class ServiceInterfaceWithNameOverride: @service_handler class ServiceImplInterfaceWithNeitherInterfaceNorNameOverride: - @operation_handler - def op( - self, - ) -> OperationHandler[None, ServiceClassNameOutput]: - async def start( - ctx: StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) @service_handler(service=ServiceInterfaceWithoutNameOverride) class ServiceImplInterfaceWithoutNameOverride: - @operation_handler - def op( - self, - ) -> OperationHandler[None, ServiceClassNameOutput]: - async def start( - ctx: StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) @service_handler(service=ServiceInterfaceWithNameOverride) class ServiceImplInterfaceWithNameOverride: - @operation_handler - def op( - self, - ) -> OperationHandler[None, ServiceClassNameOutput]: - async def start( - ctx: StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) @service_handler(name="service-impl-🌈") class ServiceImplWithNameOverride: - @operation_handler - def op( - self, - ) -> OperationHandler[None, ServiceClassNameOutput]: - async def start( - ctx: StartOperationContext, input: None - ) -> ServiceClassNameOutput: - return ServiceClassNameOutput(self.__class__.__name__) - - return SyncOperationHandler.from_callable(start) + @sync_operation_handler + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) class NameOverride(IntEnum): From 7355554e152e464bc2223f8bb275ed4d75a6a983 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 19:24:46 -0400 Subject: [PATCH 072/183] New workflow_run_operation_handler --- temporalio/nexus/handler/__init__.py | 6 +- temporalio/nexus/handler/_decorators.py | 128 +++++++++++++++ .../nexus/handler/_operation_handlers.py | 109 +++++-------- temporalio/nexus/handler/_util.py | 56 +------ temporalio/worker/_interceptor.py | 8 +- temporalio/worker/_workflow_instance.py | 5 + temporalio/workflow.py | 17 ++ .../nexus/test_get_input_and_output_types.py | 153 ------------------ tests/nexus/test_handler.py | 130 +++++++-------- .../test_handler_interface_implementation.py | 13 +- .../test_handler_operation_definitions.py | 46 ++---- tests/nexus/test_workflow_caller.py | 81 +++++----- tests/nexus/test_workflow_run_operation.py | 9 +- 13 files changed, 324 insertions(+), 437 deletions(-) create mode 100644 temporalio/nexus/handler/_decorators.py delete mode 100644 tests/nexus/test_get_input_and_output_types.py diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py index 995531b64..86584ae0a 100644 --- a/temporalio/nexus/handler/__init__.py +++ b/temporalio/nexus/handler/__init__.py @@ -6,15 +6,15 @@ Optional, ) +from ._decorators import ( + workflow_run_operation_handler as workflow_run_operation_handler, +) from ._operation_context import ( _TemporalNexusOperationContext as _TemporalNexusOperationContext, ) from ._operation_context import ( temporal_operation_context as temporal_operation_context, ) -from ._operation_handlers import ( - WorkflowRunOperationHandler as WorkflowRunOperationHandler, -) from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowOperationToken as WorkflowOperationToken from ._workflow import start_workflow as start_workflow diff --git a/temporalio/nexus/handler/_decorators.py b/temporalio/nexus/handler/_decorators.py new file mode 100644 index 000000000..ff0e6d599 --- /dev/null +++ b/temporalio/nexus/handler/_decorators.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import ( + Awaitable, + Callable, + Optional, + Union, + overload, +) + +import nexusrpc +from nexusrpc.handler import ( + OperationHandler, + StartOperationContext, +) +from nexusrpc.types import InputT, OutputT, ServiceHandlerT + +from temporalio.nexus.handler._operation_handlers import ( + WorkflowRunOperationHandler, +) +from temporalio.nexus.handler._token import ( + WorkflowOperationToken, +) +from temporalio.nexus.handler._util import ( + get_workflow_run_start_method_input_and_output_type_annotations, +) + + +@overload +def workflow_run_operation_handler( + start: Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ], +) -> Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], +]: ... + + +@overload +def workflow_run_operation_handler( + *, + name: Optional[str] = None, +) -> Callable[ + [ + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ] + ], + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ], +]: ... + + +def workflow_run_operation_handler( + start: Optional[ + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ] + ] = None, + *, + name: Optional[str] = None, +) -> Union[ + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ], + Callable[ + [ + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ] + ], + Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ], + ], +]: + """ + Decorator marking a method as the start method for a workflow-backed operation. + """ + + def decorator( + start: Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ], + ) -> Callable[ + [ServiceHandlerT, StartOperationContext, InputT], + Awaitable[WorkflowOperationToken[OutputT]], + ]: + ( + input_type, + output_type, + ) = get_workflow_run_start_method_input_and_output_type_annotations(start) + + def operation_handler_factory( + self: ServiceHandlerT, + ) -> OperationHandler[InputT, OutputT]: + async def _start( + ctx: StartOperationContext, input: InputT + ) -> WorkflowOperationToken[OutputT]: + return await start(self, ctx, input) + + _start.__doc__ = start.__doc__ + return WorkflowRunOperationHandler(_start, input_type, output_type) + + operation_handler_factory.__nexus_operation__ = nexusrpc.Operation( + name=name or start.__name__, + method_name=start.__name__, + input_type=input_type, + output_type=output_type, + ) + + start.__nexus_operation_factory__ = operation_handler_factory + return start + + if start is None: + return decorator + + return decorator(start) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index e71ad091a..179e681f9 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -1,18 +1,21 @@ from __future__ import annotations -from abc import ABC, abstractmethod from typing import ( Any, Awaitable, Callable, Optional, + Type, ) -import nexusrpc.handler +from nexusrpc import OperationInfo from nexusrpc.handler import ( CancelOperationContext, + FetchOperationInfoContext, + FetchOperationResultContext, HandlerError, HandlerErrorType, + OperationHandler, StartOperationContext, StartOperationResultAsync, ) @@ -28,15 +31,11 @@ from temporalio.nexus.handler._token import WorkflowOperationToken from ._util import ( - get_workflow_run_start_method_input_and_output_type_annotations, is_async_callable, ) -class WorkflowRunOperationHandler( - nexusrpc.handler.OperationHandler[InputT, OutputT], - ABC, -): +class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]): """ Operation handler for Nexus operations that start a workflow. @@ -67,61 +66,58 @@ async def start( def __init__( self, - start: Optional[ - Callable[ - [StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], - ] - ] = None, - ) -> None: - if start is not None: - if not is_async_callable(start): - raise RuntimeError( - f"{start} is not an `async def` method. " - "WorkflowRunOperationHandler must be initialized with an " - "`async def` start method." - ) - self._start = start - if start.__doc__: - self.start.__func__.__doc__ = start.__doc__ - self._input_type, self._output_type = ( - get_workflow_run_start_method_input_and_output_type_annotations(start) - ) - else: - self._start = self._input_type = self._output_type = None - - @classmethod - def from_callable( - cls, - start_workflow: Callable[ + start: Callable[ [StartOperationContext, InputT], Awaitable[WorkflowOperationToken[OutputT]], ], - ) -> WorkflowRunOperationHandler[InputT, OutputT]: - return _WorkflowRunOperationHandler(start_workflow) + input_type: Optional[Type[InputT]], + output_type: Optional[Type[OutputT]], + ) -> None: + if not is_async_callable(start): + raise RuntimeError( + f"{start} is not an `async def` method. " + "WorkflowRunOperationHandler must be initialized with an " + "`async def` start method." + ) + self._start = start + if start.__doc__: + self.start.__func__.__doc__ = start.__doc__ + self._input_type = input_type + self._output_type = output_type - @abstractmethod async def start( self, ctx: StartOperationContext, input: InputT - ) -> nexusrpc.handler.StartOperationResultAsync: + ) -> StartOperationResultAsync: """ Start the operation, by starting a workflow and completing asynchronously. """ - ... + token = await self._start(ctx, input) + if not isinstance(token, WorkflowOperationToken): + if isinstance(token, WorkflowHandle): + raise RuntimeError( + f"Expected {token} to be a WorkflowOperationToken, but got a WorkflowHandle. " + f"You must use :py:meth:`temporalio.nexus.handler.start_workflow` " + "to start a workflow that will deliver the result of the Nexus operation, " + "not :py:meth:`temporalio.client.Client.start_workflow`." + ) + raise RuntimeError( + f"Expected {token} to be a WorkflowOperationToken, but got {type(token)}. " + ) + return StartOperationResultAsync(token.encode()) async def cancel(self, ctx: CancelOperationContext, token: str) -> None: """Cancel the operation, by cancelling the workflow.""" await cancel_operation(token) async def fetch_info( - self, ctx: nexusrpc.handler.FetchOperationInfoContext, token: str - ) -> nexusrpc.OperationInfo: + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: raise NotImplementedError( "Temporal Nexus operation handlers do not support fetching operation info." ) async def fetch_result( - self, ctx: nexusrpc.handler.FetchOperationResultContext, token: str + self, ctx: FetchOperationResultContext, token: str ) -> OutputT: raise NotImplementedError( "Temporal Nexus operation handlers do not support fetching the operation result." @@ -150,35 +146,6 @@ async def fetch_result( return await handle.result() -class _WorkflowRunOperationHandler(WorkflowRunOperationHandler[InputT, OutputT]): - async def start( - self, ctx: StartOperationContext, input: InputT - ) -> nexusrpc.handler.StartOperationResultAsync: - """ - Start the operation, by starting a workflow and completing asynchronously. - """ - - if self._start is None: - raise RuntimeError( - "Do not use _WorkflowRunOperationHandler directly. " - "Use WorkflowRunOperationHandler.from_start_workflow instead." - ) - - token = await self._start(ctx, input) - if not isinstance(token, WorkflowOperationToken): - if isinstance(token, WorkflowHandle): - raise RuntimeError( - f"Expected {token} to be a WorkflowOperationToken, but got a WorkflowHandle. " - f"You must use :py:meth:`temporalio.nexus.handler.start_workflow` " - "to start a workflow that will deliver the result of the Nexus operation, " - "not :py:meth:`temporalio.client.Client.start_workflow`." - ) - raise RuntimeError( - f"Expected {token} to be a WorkflowOperationToken, but got {type(token)}. " - ) - return StartOperationResultAsync(token.encode()) - - async def cancel_operation( token: str, **kwargs: Any, diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/handler/_util.py index 09f4c0939..dce02eab9 100644 --- a/temporalio/nexus/handler/_util.py +++ b/temporalio/nexus/handler/_util.py @@ -10,15 +10,16 @@ Callable, Optional, Type, - Union, ) from nexusrpc.handler import ( StartOperationContext, + get_start_method_input_and_output_type_annotations, ) from nexusrpc.types import ( InputT, OutputT, + ServiceHandlerT, ) from ._token import ( @@ -28,7 +29,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( start: Callable[ - [StartOperationContext, InputT], + [ServiceHandlerT, StartOperationContext, InputT], Awaitable[WorkflowOperationToken[OutputT]], ], ) -> tuple[ @@ -40,7 +41,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( `start` must be a type-annotated start method that returns a :py:class:`WorkflowHandle`. """ - input_type, output_type = _get_start_method_input_and_output_type_annotations(start) + input_type, output_type = get_start_method_input_and_output_type_annotations(start) origin_type = typing.get_origin(output_type) if not origin_type: output_type = None @@ -65,55 +66,6 @@ def get_workflow_run_start_method_input_and_output_type_annotations( return input_type, output_type -def _get_start_method_input_and_output_type_annotations( - start: Callable[ - [StartOperationContext, InputT], - Union[OutputT, Awaitable[OutputT]], - ], -) -> tuple[ - Optional[Type[InputT]], - Optional[Type[OutputT]], -]: - """Return operation input and output types. - - `start` must be a type-annotated start method that returns a synchronous result. - """ - try: - type_annotations = typing.get_type_hints(start) - except TypeError: - # TODO(nexus-preview): stacklevel - warnings.warn( - f"Expected decorated start method {start} to have type annotations" - ) - return None, None - - if not type_annotations: - return None, None - - output_type = type_annotations.pop("return", None) - - if len(type_annotations) != 2: - # TODO(nexus-preview): stacklevel - suffix = f": {type_annotations}" if type_annotations else "" - warnings.warn( - f"Expected decorated start method {start} to have exactly 2 " - f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}" - f"{suffix}." - ) - input_type = None - else: - ctx_type, input_type = type_annotations.values() - if not issubclass(ctx_type, StartOperationContext): - # TODO(nexus-preview): stacklevel - warnings.warn( - f"Expected first parameter of {start} to be an instance of " - f"StartOperationContext, but is {ctx_type}." - ) - input_type = None - - return input_type, output_type - - # Copied from https://github.com/modelcontextprotocol/python-sdk # # Copyright (c) 2024 Anthropic, PBC. diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 6f6965093..f6d7672a9 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -27,6 +27,7 @@ import temporalio.activity import temporalio.api.common.v1 import temporalio.common +import temporalio.nexus.handler import temporalio.workflow from temporalio.workflow import VersioningIntent @@ -299,6 +300,10 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): operation: Union[ nexusrpc.Operation[InputT, OutputT], Callable[[Any], nexusrpc.handler.OperationHandler[InputT, OutputT]], + Callable[ + [Any, nexusrpc.handler.StartOperationContext, InputT], + temporalio.nexus.handler.WorkflowOperationToken[OutputT], + ], str, ] input: InputT @@ -309,6 +314,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): _operation_name: str = field(init=False, repr=False) _input_type: Optional[Type[InputT]] = field(init=False, repr=False) + # TODO(nexus-prerelease): update this logic to handle service impl start methods def __post_init__(self) -> None: if isinstance(self.operation, str): self._operation_name = self.operation @@ -318,7 +324,7 @@ def __post_init__(self) -> None: self._input_type = self.operation.input_type self.output_type = self.operation.output_type elif isinstance(self.operation, Callable): - op = getattr(self.operation, "__nexus_operation__", None) + _, op = nexusrpc.handler.get_operation_factory(self.operation) if isinstance(op, nexusrpc.Operation): self._operation_name = op.name self._input_type = op.input_type diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index a709ea069..5f1e1db5f 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -59,6 +59,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.handler import temporalio.workflow from temporalio.service import __version__ @@ -1500,6 +1501,10 @@ async def workflow_start_nexus_operation( operation: Union[ nexusrpc.Operation[I, O], Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + Callable[ + [Any, nexusrpc.handler.StartOperationContext, I], + temporalio.nexus.handler.WorkflowOperationToken[O], + ], str, ], input: Any, diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 7fcd7f376..84aa25e09 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -56,6 +56,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.handler import temporalio.workflow from .types import ( @@ -856,6 +857,10 @@ async def workflow_start_nexus_operation( operation: Union[ nexusrpc.Operation[I, O], Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + temporalio.nexus.handler.WorkflowOperationToken[O], + ], str, ], input: Any, @@ -4422,6 +4427,10 @@ async def start_nexus_operation( operation: Union[ nexusrpc.Operation[I, O], Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], + Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + temporalio.nexus.handler.WorkflowOperationToken[O], + ], str, ], input: Any, @@ -5205,6 +5214,10 @@ async def start_operation( operation: Union[ nexusrpc.Operation[I, O], Callable[[S], nexusrpc.handler.OperationHandler[I, O]], + Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + temporalio.nexus.handler.WorkflowOperationToken[O], + ], str, ], input: I, @@ -5229,6 +5242,10 @@ async def execute_operation( operation: Union[ nexusrpc.Operation[I, O], Callable[[S], nexusrpc.handler.OperationHandler[I, O]], + Callable[ + [S, nexusrpc.handler.StartOperationContext, I], + temporalio.nexus.handler.WorkflowOperationToken[O], + ], str, ], input: I, diff --git a/tests/nexus/test_get_input_and_output_types.py b/tests/nexus/test_get_input_and_output_types.py deleted file mode 100644 index fcfa0fa8b..000000000 --- a/tests/nexus/test_get_input_and_output_types.py +++ /dev/null @@ -1,153 +0,0 @@ -import warnings -from typing import ( - Any, - Awaitable, - Type, - Union, - get_args, - get_origin, -) - -import pytest -from nexusrpc.handler import ( - StartOperationContext, -) - -from temporalio.nexus.handler._util import ( - _get_start_method_input_and_output_type_annotations, -) - - -class Input: - pass - - -class Output: - pass - - -class _TestCase: - @staticmethod - def start(ctx: StartOperationContext, i: Input) -> Output: ... - - expected_types: tuple[Any, Any] - - -class SyncMethod(_TestCase): - @staticmethod - def start(ctx: StartOperationContext, i: Input) -> Output: ... - - expected_types = (Input, Output) - - -class AsyncMethod(_TestCase): - @staticmethod - async def start(ctx: StartOperationContext, i: Input) -> Output: ... - - expected_types = (Input, Output) - - -class UnionMethod(_TestCase): - @staticmethod - def start( - ctx: StartOperationContext, i: Input - ) -> Union[Output, Awaitable[Output]]: ... - - expected_types = (Input, Union[Output, Awaitable[Output]]) - - -class MissingInputAnnotationInUnionMethod(_TestCase): - @staticmethod - def start(ctx: StartOperationContext, i) -> Union[Output, Awaitable[Output]]: ... - - expected_types = (None, Union[Output, Awaitable[Output]]) - - -class TooFewParams(_TestCase): - @staticmethod - def start(i: Input) -> Output: ... - - expected_types = (None, Output) - - -class TooManyParams(_TestCase): - @staticmethod - def start(ctx: StartOperationContext, i: Input, extra: int) -> Output: ... - - expected_types = (None, Output) - - -class WrongOptionsType(_TestCase): - @staticmethod - def start(ctx: int, i: Input) -> Output: ... - - expected_types = (None, Output) - - -class NoReturnHint(_TestCase): - @staticmethod - def start(ctx: StartOperationContext, i: Input): ... - - expected_types = (Input, None) - - -class NoInputAnnotation(_TestCase): - @staticmethod - def start(ctx: StartOperationContext, i) -> Output: ... - - expected_types = (None, Output) - - -class NoOptionsAnnotation(_TestCase): - @staticmethod - def start(ctx, i: Input) -> Output: ... - - expected_types = (None, Output) - - -class AllAnnotationsMissing(_TestCase): - @staticmethod - def start(ctx: StartOperationContext, i): ... - - expected_types = (None, None) - - -class ExplicitNoneTypes(_TestCase): - @staticmethod - def start(ctx: StartOperationContext, i: None) -> None: ... - - expected_types = (type(None), type(None)) - - -@pytest.mark.parametrize( - "test_case", - [ - SyncMethod, - AsyncMethod, - UnionMethod, - TooFewParams, - TooManyParams, - WrongOptionsType, - NoReturnHint, - NoInputAnnotation, - NoOptionsAnnotation, - MissingInputAnnotationInUnionMethod, - AllAnnotationsMissing, - ExplicitNoneTypes, - ], -) -def test_get_input_and_output_types(test_case: Type[_TestCase]): - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - input_type, output_type = _get_start_method_input_and_output_type_annotations( - test_case.start - ) - expected_input_type, expected_output_type = test_case.expected_types - assert input_type is expected_input_type - - expected_origin = get_origin(expected_output_type) - if expected_origin: # Awaitable and Union cases - assert get_origin(output_type) is expected_origin - assert get_args(output_type) == get_args(expected_output_type) - else: - assert output_type is expected_output_type diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 177d0422e..e0740c039 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -57,6 +57,7 @@ logger, start_workflow, temporal_operation_context, + workflow_run_operation_handler, ) from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -211,7 +212,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: logger.info("Logging from start method", extra={"input_value": input.value}) return Output(value=f"logged: {input.value}") - @sync_operation_handler + @workflow_run_operation_handler async def workflow_run_operation( self, ctx: StartOperationContext, input: Input ) -> WorkflowOperationToken[Output]: @@ -262,39 +263,29 @@ async def sync_operation_without_type_annotations( value=f"from start method on {self.__class__.__name__} without type annotations: {input}" ) - @operation_handler - def workflow_run_operation_without_type_annotations( - self, - ) -> OperationHandler[Input, Output]: - async def start(ctx, input): - return await start_workflow( - WorkflowWithoutTypeAnnotations.run, - input, - id=str(uuid.uuid4()), - ) - - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) + @workflow_run_operation_handler + async def workflow_run_operation_without_type_annotations(self, ctx, input): + return await start_workflow( + WorkflowWithoutTypeAnnotations.run, + input, + id=str(uuid.uuid4()), + ) - @operation_handler - def workflow_run_op_link_test( - self, - ) -> OperationHandler[Input, Output]: - async def start( - ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - assert any( - link.url == "http://inbound-link/" for link in ctx.inbound_links - ), "Inbound link not found" - assert ctx.request_id == "test-request-id-123", "Request ID mismatch" - ctx.outbound_links.extend(ctx.inbound_links) - - return await start_workflow( - MyLinkTestWorkflow.run, - input, - id=str(uuid.uuid4()), - ) + @workflow_run_operation_handler + async def workflow_run_op_link_test( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: + assert any( + link.url == "http://inbound-link/" for link in ctx.inbound_links + ), "Inbound link not found" + assert ctx.request_id == "test-request-id-123", "Request ID mismatch" + ctx.outbound_links.extend(ctx.inbound_links) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) + return await start_workflow( + MyLinkTestWorkflow.run, + input, + id=str(uuid.uuid4()), + ) class OperationHandlerReturningUnwrappedResult(OperationHandler[Input, Output]): async def start( @@ -609,7 +600,8 @@ class OperationHandlerReturningUnwrappedResultError(_FailureTestCase): retryable_header=False, failure_message=( "Operation start method must return either " - "StartOperationResultSync or StartOperationResultAsync." + "nexusrpc.handler.StartOperationResultSync or " + "nexusrpc.handler.StartOperationResultAsync." ), ) @@ -823,12 +815,12 @@ async def _test_start_operation( if with_service_definition else service_handler ) - service_handler = decorator(MyServiceHandler)() + user_service_handler = decorator(MyServiceHandler)() async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[service_handler], + nexus_service_handlers=[user_service_handler], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): response = await service_client.start_operation( @@ -1063,46 +1055,36 @@ async def run(self, input: Input) -> Output: @service_handler class ServiceHandlerForRequestIdTest: - @operation_handler - def operation_backed_by_a_workflow( - self, - ) -> OperationHandler[Input, Output]: - async def start( - ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - return await start_workflow( - EchoWorkflow.run, - input, - id=input.value, - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, - ) - - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) - - @operation_handler - def operation_that_executes_a_workflow_before_starting_the_backing_workflow( - self, - ) -> OperationHandler[Input, Output]: - async def start( - ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - tctx = temporal_operation_context.get() - await tctx.client.start_workflow( - EchoWorkflow.run, - input, - id=input.value, - task_queue=tctx.task_queue, - ) - # This should fail. It will not fail if the Nexus request ID was incorrectly - # propagated to both StartWorkflow requests. - return await start_workflow( - EchoWorkflow.run, - input, - id=input.value, - id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, - ) + @workflow_run_operation_handler + async def operation_backed_by_a_workflow( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: + return await start_workflow( + EchoWorkflow.run, + input, + id=input.value, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) + @workflow_run_operation_handler + async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: + tctx = temporal_operation_context.get() + await tctx.client.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + task_queue=tctx.task_queue, + ) + # This should fail. It will not fail if the Nexus request ID was incorrectly + # propagated to both StartWorkflow requests. + return await start_workflow( + EchoWorkflow.run, + input, + id=input.value, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 92802dc88..6e00d4838 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -7,7 +7,7 @@ from temporalio.nexus.handler import ( WorkflowOperationToken, - WorkflowRunOperationHandler, + workflow_run_operation_handler, ) HTTP_PORT = 7243 @@ -37,13 +37,10 @@ class Interface: op: nexusrpc.Operation[str, int] class Impl: - @nexusrpc.handler.operation_handler - def op(self) -> nexusrpc.handler.OperationHandler[str, int]: - async def start( - ctx: nexusrpc.handler.StartOperationContext, input: str - ) -> WorkflowOperationToken[int]: ... - - return WorkflowRunOperationHandler.from_callable(start) + @workflow_run_operation_handler + async def op( + self, ctx: StartOperationContext, input: str + ) -> WorkflowOperationToken[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index da3781846..7e29b4680 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -8,10 +8,11 @@ import nexusrpc.handler import pytest +from nexusrpc.handler import StartOperationContext from temporalio.nexus.handler import ( WorkflowOperationToken, - WorkflowRunOperationHandler, + workflow_run_operation_handler, ) @@ -34,15 +35,10 @@ class _TestCase: class NotCalled(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexusrpc.handler.operation_handler - def my_workflow_run_operation_handler( - self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: - async def start( - ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... - - return WorkflowRunOperationHandler.from_callable(start) + @workflow_run_operation_handler + async def my_workflow_run_operation_handler( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: ... expected_operations = { "my_workflow_run_operation_handler": nexusrpc.Operation( @@ -57,15 +53,10 @@ async def start( class CalledWithoutArgs(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexusrpc.handler.operation_handler() - def my_workflow_run_operation_handler( - self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: - async def start( - ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... - - return WorkflowRunOperationHandler.from_callable(start) + @workflow_run_operation_handler + async def my_workflow_run_operation_handler( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: ... expected_operations = NotCalled.expected_operations @@ -73,15 +64,10 @@ async def start( class CalledWithNameOverride(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexusrpc.handler.operation_handler(name="operation-name") - def workflow_run_operation_with_name_override( - self, - ) -> nexusrpc.handler.OperationHandler[Input, Output]: - async def start( - ctx: nexusrpc.handler.StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... - - return WorkflowRunOperationHandler.from_callable(start) + @workflow_run_operation_handler(name="operation-name") + async def workflow_run_operation_with_name_override( + self, ctx: StartOperationContext, input: Input + ) -> WorkflowOperationToken[Output]: ... expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( @@ -111,7 +97,9 @@ async def test_collected_operation_names( assert isinstance(service, nexusrpc.ServiceDefinition) assert service.name == "Service" for method_name, expected_op in test_case.expected_operations.items(): - actual_op = getattr(test_case.Service, method_name).__nexus_operation__ + _, actual_op = nexusrpc.handler.get_operation_factory( + getattr(test_case.Service, method_name) + ) assert isinstance(actual_op, nexusrpc.Operation) assert actual_op.name == expected_op.name assert actual_op.input_type == expected_op.input_type diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 11bd9b3b4..d23b6080b 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -42,9 +42,9 @@ from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus.handler import ( WorkflowOperationToken, - WorkflowRunOperationHandler, start_workflow, temporal_operation_context, + workflow_run_operation_handler, ) from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker @@ -208,27 +208,22 @@ async def sync_operation( ) return OpOutput(value="sync response") - @operation_handler - def async_operation( - self, - ) -> OperationHandler[OpInput, HandlerWfOutput]: - async def start( - ctx: StartOperationContext, input: OpInput - ) -> WorkflowOperationToken[HandlerWfOutput]: - assert isinstance(input.response_type, AsyncResponse) - if input.response_type.exception_in_operation_start: - raise RPCError( - "RPCError INVALID_ARGUMENT in Nexus operation", - RPCStatusCode.INVALID_ARGUMENT, - b"", - ) - return await start_workflow( - HandlerWorkflow.run, - HandlerWfInput(op_input=input), - id=input.response_type.operation_workflow_id, + @workflow_run_operation_handler + async def async_operation( + self, ctx: StartOperationContext, input: OpInput + ) -> WorkflowOperationToken[HandlerWfOutput]: + assert isinstance(input.response_type, AsyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", ) - - return WorkflowRunOperationHandler.from_callable(start) + return await start_workflow( + HandlerWorkflow.run, + HandlerWfInput(op_input=input), + id=input.response_type.operation_workflow_id, + ) # ----------------------------------------------------------------------------- @@ -924,30 +919,25 @@ async def run(self, input: str) -> str: @service_handler class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: - @operation_handler - def my_workflow_run_operation( - self, - ) -> OperationHandler[None, str]: - async def start( - ctx: StartOperationContext, input: None - ) -> WorkflowOperationToken[str]: - tctx = temporal_operation_context.get() - result_1 = await tctx.client.execute_workflow( - EchoWorkflow.run, - "result-1", - id=str(uuid.uuid4()), - task_queue=tctx.task_queue, - ) - # In case result_1 is incorrectly being delivered to the caller as the operation - # result, give time for that incorrect behavior to occur. - await asyncio.sleep(0.5) - return await start_workflow( - EchoWorkflow.run, - f"{result_1}-result-2", - id=str(uuid.uuid4()), - ) - - return temporalio.nexus.handler.WorkflowRunOperationHandler.from_callable(start) + @workflow_run_operation_handler + async def my_workflow_run_operation( + self, ctx: StartOperationContext, input: None + ) -> WorkflowOperationToken[str]: + tctx = temporal_operation_context.get() + result_1 = await tctx.client.execute_workflow( + EchoWorkflow.run, + "result-1", + id=str(uuid.uuid4()), + task_queue=tctx.task_queue, + ) + # In case result_1 is incorrectly being delivered to the caller as the operation + # result, give time for that incorrect behavior to occur. + await asyncio.sleep(0.5) + return await start_workflow( + EchoWorkflow.run, + f"{result_1}-result-2", + id=str(uuid.uuid4()), + ) @workflow.defn @@ -958,6 +948,7 @@ async def run(self, input: str, task_queue: str) -> str: service=ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow, endpoint=make_nexus_endpoint_name(task_queue), ) + # TODO(nexus-prerelease): update StartNexusOperationInput.__post_init__ return await nexus_client.execute_operation( ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow.my_workflow_run_operation, None, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index b75bd2407..f6409aae2 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -13,7 +13,8 @@ ) from temporalio import workflow -from temporalio.nexus.handler import WorkflowRunOperationHandler, start_workflow +from temporalio.nexus.handler import start_workflow +from temporalio.nexus.handler._operation_handlers import WorkflowRunOperationHandler from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict @@ -33,7 +34,13 @@ async def run(self, input: str) -> str: return input +# TODO(nexus-prerelease): this test dates from a point at which we were encouraging +# subclassing WorkflowRunOperationHandler as part of the public API. Leaving it in for +# now. class MyOperation(WorkflowRunOperationHandler): + def __init__(self): + pass + async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: From ec316902a3291ed7f43608fe7c2e5349e2b509a0 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:18:34 -0400 Subject: [PATCH 073/183] Delete reference to obsolete __nexus_service_metadata__ --- temporalio/workflow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 84aa25e09..64624cd76 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5197,8 +5197,6 @@ def __init__( self._service_name = service elif service_defn := getattr(service, "__nexus_service__", None): self._service_name = service_defn.name - elif service_metadata := getattr(service, "__nexus_service_metadata__", None): - self._service_name = service_metadata.name else: raise ValueError( f"`service` may be a name (str), or a class decorated with either " From 7f9c1440729895e4ca2568ea97baaf31595de404 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:20:42 -0400 Subject: [PATCH 074/183] TODO --- temporalio/nexus/handler/_decorators.py | 1 + temporalio/workflow.py | 1 + 2 files changed, 2 insertions(+) diff --git a/temporalio/nexus/handler/_decorators.py b/temporalio/nexus/handler/_decorators.py index ff0e6d599..5d75d0d9f 100644 --- a/temporalio/nexus/handler/_decorators.py +++ b/temporalio/nexus/handler/_decorators.py @@ -112,6 +112,7 @@ async def _start( _start.__doc__ = start.__doc__ return WorkflowRunOperationHandler(_start, input_type, output_type) + # TODO(preview): make double-underscore attrs private to nexusrpc and expose getters/setters operation_handler_factory.__nexus_operation__ = nexusrpc.Operation( name=name or start.__name__, method_name=start.__name__, diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 64624cd76..cceb24d66 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5195,6 +5195,7 @@ def __init__( # class. if isinstance(service, str): self._service_name = service + # TODO(preview): make double-underscore attrs private to nexusrpc and expose getters/setters elif service_defn := getattr(service, "__nexus_service__", None): self._service_name = service_defn.name else: From 602d4120b9876788f2bf09a9b907d132dcd8fa21 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:24:00 -0400 Subject: [PATCH 075/183] Use get_callable_name utility --- temporalio/nexus/handler/_decorators.py | 6 ++++-- temporalio/nexus/handler/_util.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/temporalio/nexus/handler/_decorators.py b/temporalio/nexus/handler/_decorators.py index 5d75d0d9f..461a6f3d9 100644 --- a/temporalio/nexus/handler/_decorators.py +++ b/temporalio/nexus/handler/_decorators.py @@ -22,6 +22,7 @@ WorkflowOperationToken, ) from temporalio.nexus.handler._util import ( + get_callable_name, get_workflow_run_start_method_input_and_output_type_annotations, ) @@ -112,10 +113,11 @@ async def _start( _start.__doc__ = start.__doc__ return WorkflowRunOperationHandler(_start, input_type, output_type) + method_name = get_callable_name(start) # TODO(preview): make double-underscore attrs private to nexusrpc and expose getters/setters operation_handler_factory.__nexus_operation__ = nexusrpc.Operation( - name=name or start.__name__, - method_name=start.__name__, + name=name or method_name, + method_name=method_name, input_type=input_type, output_type=output_type, ) diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/handler/_util.py index dce02eab9..b7adc9e2f 100644 --- a/temporalio/nexus/handler/_util.py +++ b/temporalio/nexus/handler/_util.py @@ -66,6 +66,18 @@ def get_workflow_run_start_method_input_and_output_type_annotations( return input_type, output_type +def get_callable_name(fn: Callable[..., Any]) -> str: + method_name = getattr(fn, "__name__", None) + if not method_name and callable(fn) and hasattr(fn, "__call__"): + method_name = fn.__class__.__name__ + if not method_name: + raise TypeError( + f"Could not determine callable name: " + f"expected {fn} to be a function or callable instance." + ) + return method_name + + # Copied from https://github.com/modelcontextprotocol/python-sdk # # Copyright (c) 2024 Anthropic, PBC. From ec1f05ad89c4223ed211d2c9e014eaa4c293f217 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:29:42 -0400 Subject: [PATCH 076/183] Fix test: 'not an async def` message changed --- tests/nexus/test_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index e0740c039..e3b3cc0de 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -946,7 +946,7 @@ class SyncHandlerNoExecutor(_InstantiationCase): handler = SyncStartHandler executor = False exception = RuntimeError - match = "is not an `async def` method" + match = "must be an `async def`" class DefaultCancel(_InstantiationCase): From c6a8e32bee010ca722d87f07ae540b62373917e8 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:36:21 -0400 Subject: [PATCH 077/183] Refactor --- temporalio/nexus/handler/_operation_handlers.py | 2 +- temporalio/nexus/handler/_token.py | 15 +++++++-------- temporalio/nexus/handler/_workflow.py | 4 +++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/handler/_operation_handlers.py index 179e681f9..ee899e3b0 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/handler/_operation_handlers.py @@ -168,7 +168,7 @@ async def cancel_operation( ctx = temporal_operation_context.get() try: - handle = workflow_token.to_workflow_handle(ctx.client) + handle = workflow_token._to_client_workflow_handle(ctx.client) except Exception as err: raise HandlerError( "Failed to construct workflow handle from workflow operation token", diff --git a/temporalio/nexus/handler/_token.py b/temporalio/nexus/handler/_token.py index 487e4f18e..fb7b1852a 100644 --- a/temporalio/nexus/handler/_token.py +++ b/temporalio/nexus/handler/_token.py @@ -3,12 +3,11 @@ import base64 import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Type +from typing import Any, Generic, Literal, Optional, Type from nexusrpc.types import OutputT -if TYPE_CHECKING: - from temporalio.client import Client, WorkflowHandle +from temporalio import client OPERATION_TOKEN_TYPE_WORKFLOW = 1 OperationTokenType = Literal[1] @@ -25,9 +24,9 @@ class WorkflowOperationToken(Generic[OutputT]): # serialized token; it's only used to reject newer token versions on load. version: Optional[int] = None - def to_workflow_handle( - self, client: Client, result_type: Optional[Type[OutputT]] = None - ) -> WorkflowHandle[Any, OutputT]: + def _to_client_workflow_handle( + self, client: client.Client, result_type: Optional[Type[OutputT]] = None + ) -> client.WorkflowHandle[Any, OutputT]: """Create a :py:class:`temporalio.client.WorkflowHandle` from the token.""" if client.namespace != self.namespace: raise ValueError( @@ -39,8 +38,8 @@ def to_workflow_handle( # TODO(nexus-preview): The return type here should be dictated by the input workflow # handle type. @classmethod - def _unsafe_from_workflow_handle( - cls, workflow_handle: WorkflowHandle[Any, OutputT] + def _unsafe_from_client_workflow_handle( + cls, workflow_handle: client.WorkflowHandle[Any, OutputT] ) -> WorkflowOperationToken[OutputT]: """Create a :py:class:`WorkflowOperationToken` from a workflow handle. diff --git a/temporalio/nexus/handler/_workflow.py b/temporalio/nexus/handler/_workflow.py index f2da5a27e..c70276bc1 100644 --- a/temporalio/nexus/handler/_workflow.py +++ b/temporalio/nexus/handler/_workflow.py @@ -134,4 +134,6 @@ async def start_workflow( start_operation_context.add_outbound_links(wf_handle) - return WorkflowOperationToken[ReturnType]._unsafe_from_workflow_handle(wf_handle) + return WorkflowOperationToken[ReturnType]._unsafe_from_client_workflow_handle( + wf_handle + ) From 63d19b22c72b82a5ce21702672e56c453dc49c4b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:43:35 -0400 Subject: [PATCH 078/183] Reorganize: temporalio.nexus.handler -> temporalo.nexus --- temporalio/nexus/__init__.py | 41 +++++++++++++++++++ temporalio/nexus/{handler => }/_decorators.py | 6 +-- .../nexus/{handler => }/_operation_context.py | 0 .../{handler => }/_operation_handlers.py | 8 ++-- temporalio/nexus/{handler => }/_token.py | 0 temporalio/nexus/{handler => }/_util.py | 0 temporalio/nexus/{handler => }/_workflow.py | 10 ++--- temporalio/nexus/handler/__init__.py | 40 ------------------ temporalio/worker/_interceptor.py | 4 +- temporalio/worker/_nexus.py | 2 +- temporalio/worker/_workflow_instance.py | 4 +- temporalio/workflow.py | 10 ++--- tests/nexus/test_handler.py | 5 +-- .../test_handler_interface_implementation.py | 2 +- .../test_handler_operation_definitions.py | 2 +- tests/nexus/test_workflow_caller.py | 6 +-- tests/nexus/test_workflow_run_operation.py | 4 +- 17 files changed, 72 insertions(+), 72 deletions(-) rename temporalio/nexus/{handler => }/_decorators.py (95%) rename temporalio/nexus/{handler => }/_operation_context.py (100%) rename temporalio/nexus/{handler => }/_operation_handlers.py (95%) rename temporalio/nexus/{handler => }/_token.py (100%) rename temporalio/nexus/{handler => }/_util.py (100%) rename temporalio/nexus/{handler => }/_workflow.py (92%) delete mode 100644 temporalio/nexus/handler/__init__.py diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 1b868c610..e2079da26 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -1,3 +1,44 @@ +import logging +from typing import ( + Any, + Mapping, + MutableMapping, + Optional, +) + +from ._decorators import ( + workflow_run_operation_handler as workflow_run_operation_handler, +) +from ._operation_context import ( + _TemporalNexusOperationContext as _TemporalNexusOperationContext, +) +from ._operation_context import ( + temporal_operation_context as temporal_operation_context, +) +from ._operation_handlers import cancel_operation as cancel_operation +from ._token import WorkflowOperationToken as WorkflowOperationToken +from ._workflow import start_workflow as start_workflow + + +class LoggerAdapter(logging.LoggerAdapter): + def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): + super().__init__(logger, extra or {}) + + def process( + self, msg: Any, kwargs: MutableMapping[str, Any] + ) -> tuple[Any, MutableMapping[str, Any]]: + extra = dict(self.extra or {}) + if tctx := temporal_operation_context.get(None): + extra["service"] = tctx.nexus_operation_context.service + extra["operation"] = tctx.nexus_operation_context.operation + extra["task_queue"] = tctx.task_queue + kwargs["extra"] = extra | kwargs.get("extra", {}) + return msg, kwargs + + +logger = LoggerAdapter(logging.getLogger(__name__), None) +"""Logger that emits additional data describing the current Nexus operation.""" + # TODO(nexus-prerelease) WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' # 2025-06-25T12:58:05.749589Z WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' # 2025-06-25T12:58:05.763052Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } diff --git a/temporalio/nexus/handler/_decorators.py b/temporalio/nexus/_decorators.py similarity index 95% rename from temporalio/nexus/handler/_decorators.py rename to temporalio/nexus/_decorators.py index 461a6f3d9..a07bf1150 100644 --- a/temporalio/nexus/handler/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -15,13 +15,13 @@ ) from nexusrpc.types import InputT, OutputT, ServiceHandlerT -from temporalio.nexus.handler._operation_handlers import ( +from temporalio.nexus._operation_handlers import ( WorkflowRunOperationHandler, ) -from temporalio.nexus.handler._token import ( +from temporalio.nexus._token import ( WorkflowOperationToken, ) -from temporalio.nexus.handler._util import ( +from temporalio.nexus._util import ( get_callable_name, get_workflow_run_start_method_input_and_output_type_annotations, ) diff --git a/temporalio/nexus/handler/_operation_context.py b/temporalio/nexus/_operation_context.py similarity index 100% rename from temporalio/nexus/handler/_operation_context.py rename to temporalio/nexus/_operation_context.py diff --git a/temporalio/nexus/handler/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py similarity index 95% rename from temporalio/nexus/handler/_operation_handlers.py rename to temporalio/nexus/_operation_handlers.py index ee899e3b0..afda3ebb4 100644 --- a/temporalio/nexus/handler/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -25,10 +25,10 @@ ) from temporalio.client import WorkflowHandle -from temporalio.nexus.handler._operation_context import ( +from temporalio.nexus._operation_context import ( temporal_operation_context, ) -from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.nexus._token import WorkflowOperationToken from ._util import ( is_async_callable, @@ -41,7 +41,7 @@ class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]): Use this class to create an operation handler that starts a workflow by passing your ``start`` method to the constructor. Your ``start`` method must use - :py:func:`temporalio.nexus.handler.start_workflow` to start the workflow. + :py:func:`temporalio.nexus.start_workflow` to start the workflow. Example: @@ -96,7 +96,7 @@ async def start( if isinstance(token, WorkflowHandle): raise RuntimeError( f"Expected {token} to be a WorkflowOperationToken, but got a WorkflowHandle. " - f"You must use :py:meth:`temporalio.nexus.handler.start_workflow` " + f"You must use :py:meth:`temporalio.nexus.start_workflow` " "to start a workflow that will deliver the result of the Nexus operation, " "not :py:meth:`temporalio.client.Client.start_workflow`." ) diff --git a/temporalio/nexus/handler/_token.py b/temporalio/nexus/_token.py similarity index 100% rename from temporalio/nexus/handler/_token.py rename to temporalio/nexus/_token.py diff --git a/temporalio/nexus/handler/_util.py b/temporalio/nexus/_util.py similarity index 100% rename from temporalio/nexus/handler/_util.py rename to temporalio/nexus/_util.py diff --git a/temporalio/nexus/handler/_workflow.py b/temporalio/nexus/_workflow.py similarity index 92% rename from temporalio/nexus/handler/_workflow.py rename to temporalio/nexus/_workflow.py index c70276bc1..d022cecec 100644 --- a/temporalio/nexus/handler/_workflow.py +++ b/temporalio/nexus/_workflow.py @@ -12,8 +12,8 @@ import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.common -from temporalio.nexus.handler._operation_context import temporal_operation_context -from temporalio.nexus.handler._token import WorkflowOperationToken +from temporalio.nexus._operation_context import temporal_operation_context +from temporalio.nexus._token import WorkflowOperationToken from temporalio.types import ( MethodAsyncSingleParam, ParamType, @@ -63,8 +63,8 @@ async def start_workflow( See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. - The return value is :py:class:`temporalio.nexus.handler.WorkflowOperationToken`. - Use :py:meth:`temporalio.nexus.handler.WorkflowOperationToken.to_workflow_handle` + The return value is :py:class:`temporalio.nexus.WorkflowOperationToken`. + Use :py:meth:`temporalio.nexus.WorkflowOperationToken.to_workflow_handle` to get a :py:class:`temporalio.client.WorkflowHandle` for interacting with the workflow. @@ -87,7 +87,7 @@ async def start_workflow( start_operation_context = ctx._temporal_start_operation_context if not start_operation_context: raise RuntimeError( - "temporalio.nexus.handler.start_workflow() must be called from " + "temporalio.nexus.start_workflow() must be called from " "within a Nexus start operation context" ) diff --git a/temporalio/nexus/handler/__init__.py b/temporalio/nexus/handler/__init__.py deleted file mode 100644 index 86584ae0a..000000000 --- a/temporalio/nexus/handler/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging -from typing import ( - Any, - Mapping, - MutableMapping, - Optional, -) - -from ._decorators import ( - workflow_run_operation_handler as workflow_run_operation_handler, -) -from ._operation_context import ( - _TemporalNexusOperationContext as _TemporalNexusOperationContext, -) -from ._operation_context import ( - temporal_operation_context as temporal_operation_context, -) -from ._operation_handlers import cancel_operation as cancel_operation -from ._token import WorkflowOperationToken as WorkflowOperationToken -from ._workflow import start_workflow as start_workflow - - -class LoggerAdapter(logging.LoggerAdapter): - def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): - super().__init__(logger, extra or {}) - - def process( - self, msg: Any, kwargs: MutableMapping[str, Any] - ) -> tuple[Any, MutableMapping[str, Any]]: - extra = dict(self.extra or {}) - if tctx := temporal_operation_context.get(None): - extra["service"] = tctx.nexus_operation_context.service - extra["operation"] = tctx.nexus_operation_context.operation - extra["task_queue"] = tctx.task_queue - kwargs["extra"] = extra | kwargs.get("extra", {}) - return msg, kwargs - - -logger = LoggerAdapter(logging.getLogger(__name__), None) -"""Logger that emits additional data describing the current Nexus operation.""" diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index f6d7672a9..db624be00 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -27,7 +27,7 @@ import temporalio.activity import temporalio.api.common.v1 import temporalio.common -import temporalio.nexus.handler +import temporalio.nexus import temporalio.workflow from temporalio.workflow import VersioningIntent @@ -302,7 +302,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): Callable[[Any], nexusrpc.handler.OperationHandler[InputT, OutputT]], Callable[ [Any, nexusrpc.handler.StartOperationContext, InputT], - temporalio.nexus.handler.WorkflowOperationToken[OutputT], + temporalio.nexus.WorkflowOperationToken[OutputT], ], str, ] diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 705473f01..4f1ff4df4 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -31,7 +31,7 @@ import temporalio.converter import temporalio.nexus from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import ( +from temporalio.nexus import ( _TemporalNexusOperationContext, logger, temporal_operation_context, diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 5f1e1db5f..df627a532 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -59,7 +59,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions -import temporalio.nexus.handler +import temporalio.nexus import temporalio.workflow from temporalio.service import __version__ @@ -1503,7 +1503,7 @@ async def workflow_start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [Any, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.handler.WorkflowOperationToken[O], + temporalio.nexus.WorkflowOperationToken[O], ], str, ], diff --git a/temporalio/workflow.py b/temporalio/workflow.py index cceb24d66..23c0657e6 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -56,7 +56,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions -import temporalio.nexus.handler +import temporalio.nexus import temporalio.workflow from .types import ( @@ -859,7 +859,7 @@ async def workflow_start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.handler.WorkflowOperationToken[O], + temporalio.nexus.WorkflowOperationToken[O], ], str, ], @@ -4429,7 +4429,7 @@ async def start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.handler.WorkflowOperationToken[O], + temporalio.nexus.WorkflowOperationToken[O], ], str, ], @@ -5215,7 +5215,7 @@ async def start_operation( Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.handler.WorkflowOperationToken[O], + temporalio.nexus.WorkflowOperationToken[O], ], str, ], @@ -5243,7 +5243,7 @@ async def execute_operation( Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.handler.WorkflowOperationToken[O], + temporalio.nexus.WorkflowOperationToken[O], ], str, ], diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index e3b3cc0de..bbbe1bc02 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -46,13 +46,12 @@ ) import temporalio.api.failure.v1 -import temporalio.nexus.handler from temporalio import workflow from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError -from temporalio.nexus.handler import ( +from temporalio.nexus import ( WorkflowOperationToken, logger, start_workflow, @@ -869,7 +868,7 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A ( record for record in caplog.records - if record.name == "temporalio.nexus.handler" + if record.name == "temporalio.nexus" and record.getMessage() == "Logging from start method" ), None, diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 6e00d4838..1f849f79b 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -5,7 +5,7 @@ import pytest from nexusrpc.handler import StartOperationContext, sync_operation_handler -from temporalio.nexus.handler import ( +from temporalio.nexus import ( WorkflowOperationToken, workflow_run_operation_handler, ) diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 7e29b4680..4cee0f80a 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -10,7 +10,7 @@ import pytest from nexusrpc.handler import StartOperationContext -from temporalio.nexus.handler import ( +from temporalio.nexus import ( WorkflowOperationToken, workflow_run_operation_handler, ) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index d23b6080b..38aacadb1 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -29,7 +29,6 @@ import temporalio.api.operatorservice import temporalio.api.operatorservice.v1 import temporalio.nexus -import temporalio.nexus.handler from temporalio import workflow from temporalio.client import ( Client, @@ -40,8 +39,9 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus.handler import ( +from temporalio.nexus import ( WorkflowOperationToken, + cancel_operation, start_workflow, temporal_operation_context, workflow_run_operation_handler, @@ -174,7 +174,7 @@ async def start( raise TypeError async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - return await temporalio.nexus.handler.cancel_operation(token) + return await cancel_operation(token) async def fetch_info( self, ctx: FetchOperationInfoContext, token: str diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index f6409aae2..f98673fa7 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -13,8 +13,8 @@ ) from temporalio import workflow -from temporalio.nexus.handler import start_workflow -from temporalio.nexus.handler._operation_handlers import WorkflowRunOperationHandler +from temporalio.nexus import start_workflow +from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict From b0c118057ec3d1e34ed604840a715ae83664fe38 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:49:21 -0400 Subject: [PATCH 079/183] Fix signatures of start_method on workflow caller side --- temporalio/worker/_interceptor.py | 2 +- temporalio/worker/_workflow_instance.py | 2 +- temporalio/workflow.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index db624be00..6e703033d 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -302,7 +302,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): Callable[[Any], nexusrpc.handler.OperationHandler[InputT, OutputT]], Callable[ [Any, nexusrpc.handler.StartOperationContext, InputT], - temporalio.nexus.WorkflowOperationToken[OutputT], + Awaitable[temporalio.nexus.WorkflowOperationToken[OutputT]], ], str, ] diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index df627a532..98fcbadaf 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1503,7 +1503,7 @@ async def workflow_start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [Any, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.WorkflowOperationToken[O], + Awaitable[temporalio.nexus.WorkflowOperationToken[O]], ], str, ], diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 23c0657e6..fe7752c9f 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -859,7 +859,7 @@ async def workflow_start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.WorkflowOperationToken[O], + Awaitable[temporalio.nexus.WorkflowOperationToken[O]], ], str, ], @@ -4429,7 +4429,7 @@ async def start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.WorkflowOperationToken[O], + Awaitable[temporalio.nexus.WorkflowOperationToken[O]], ], str, ], @@ -5215,7 +5215,7 @@ async def start_operation( Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.WorkflowOperationToken[O], + Awaitable[temporalio.nexus.WorkflowOperationToken[O]], ], str, ], @@ -5243,7 +5243,7 @@ async def execute_operation( Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - temporalio.nexus.WorkflowOperationToken[O], + Awaitable[temporalio.nexus.WorkflowOperationToken[O]], ], str, ], From 9ab2d19f2c964f8ca95e98104cc3a360aadb422a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 22:54:37 -0400 Subject: [PATCH 080/183] `from temporalio import nexus` everywhere --- tests/nexus/test_handler.py | 43 ++++++++----------- .../test_handler_interface_implementation.py | 9 ++-- .../test_handler_operation_definitions.py | 17 +++----- tests/nexus/test_workflow_caller.py | 25 ++++------- tests/nexus/test_workflow_run_operation.py | 5 +-- 5 files changed, 40 insertions(+), 59 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index bbbe1bc02..65731fbb7 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -46,18 +46,11 @@ ) import temporalio.api.failure.v1 -from temporalio import workflow +from temporalio import nexus, workflow from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError -from temporalio.nexus import ( - WorkflowOperationToken, - logger, - start_workflow, - temporal_operation_context, - workflow_run_operation_handler, -) from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict @@ -208,14 +201,16 @@ async def check_operation_timeout_header( @sync_operation_handler async def log(self, ctx: StartOperationContext, input: Input) -> Output: - logger.info("Logging from start method", extra={"input_value": input.value}) + nexus.logger.info( + "Logging from start method", extra={"input_value": input.value} + ) return Output(value=f"logged: {input.value}") - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def workflow_run_operation( self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - return await start_workflow( + ) -> nexus.WorkflowOperationToken[Output]: + return await nexus.start_workflow( MyWorkflow.run, input, id=str(uuid.uuid4()), @@ -262,25 +257,25 @@ async def sync_operation_without_type_annotations( value=f"from start method on {self.__class__.__name__} without type annotations: {input}" ) - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def workflow_run_operation_without_type_annotations(self, ctx, input): - return await start_workflow( + return await nexus.start_workflow( WorkflowWithoutTypeAnnotations.run, input, id=str(uuid.uuid4()), ) - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def workflow_run_op_link_test( self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: + ) -> nexus.WorkflowOperationToken[Output]: assert any( link.url == "http://inbound-link/" for link in ctx.inbound_links ), "Inbound link not found" assert ctx.request_id == "test-request-id-123", "Request ID mismatch" ctx.outbound_links.extend(ctx.inbound_links) - return await start_workflow( + return await nexus.start_workflow( MyLinkTestWorkflow.run, input, id=str(uuid.uuid4()), @@ -1054,22 +1049,22 @@ async def run(self, input: Input) -> Output: @service_handler class ServiceHandlerForRequestIdTest: - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def operation_backed_by_a_workflow( self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - return await start_workflow( + ) -> nexus.WorkflowOperationToken[Output]: + return await nexus.start_workflow( EchoWorkflow.run, input, id=input.value, id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: - tctx = temporal_operation_context.get() + ) -> nexus.WorkflowOperationToken[Output]: + tctx = nexus.temporal_operation_context.get() await tctx.client.start_workflow( EchoWorkflow.run, input, @@ -1078,7 +1073,7 @@ async def operation_that_executes_a_workflow_before_starting_the_backing_workflo ) # This should fail. It will not fail if the Nexus request ID was incorrectly # propagated to both StartWorkflow requests. - return await start_workflow( + return await nexus.start_workflow( EchoWorkflow.run, input, id=input.value, diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 1f849f79b..b1e37dda9 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -5,10 +5,7 @@ import pytest from nexusrpc.handler import StartOperationContext, sync_operation_handler -from temporalio.nexus import ( - WorkflowOperationToken, - workflow_run_operation_handler, -) +from temporalio import nexus HTTP_PORT = 7243 @@ -37,10 +34,10 @@ class Interface: op: nexusrpc.Operation[str, int] class Impl: - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def op( self, ctx: StartOperationContext, input: str - ) -> WorkflowOperationToken[int]: ... + ) -> nexus.WorkflowOperationToken[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 4cee0f80a..61912988e 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -10,10 +10,7 @@ import pytest from nexusrpc.handler import StartOperationContext -from temporalio.nexus import ( - WorkflowOperationToken, - workflow_run_operation_handler, -) +from temporalio import nexus @dataclass @@ -35,10 +32,10 @@ class _TestCase: class NotCalled(_TestCase): @nexusrpc.handler.service_handler class Service: - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... + ) -> nexus.WorkflowOperationToken[Output]: ... expected_operations = { "my_workflow_run_operation_handler": nexusrpc.Operation( @@ -53,10 +50,10 @@ async def my_workflow_run_operation_handler( class CalledWithoutArgs(_TestCase): @nexusrpc.handler.service_handler class Service: - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... + ) -> nexus.WorkflowOperationToken[Output]: ... expected_operations = NotCalled.expected_operations @@ -64,10 +61,10 @@ async def my_workflow_run_operation_handler( class CalledWithNameOverride(_TestCase): @nexusrpc.handler.service_handler class Service: - @workflow_run_operation_handler(name="operation-name") + @nexus.workflow_run_operation_handler(name="operation-name") async def workflow_run_operation_with_name_override( self, ctx: StartOperationContext, input: Input - ) -> WorkflowOperationToken[Output]: ... + ) -> nexus.WorkflowOperationToken[Output]: ... expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 38aacadb1..18e540d28 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -29,7 +29,7 @@ import temporalio.api.operatorservice import temporalio.api.operatorservice.v1 import temporalio.nexus -from temporalio import workflow +from temporalio import nexus, workflow from temporalio.client import ( Client, WithStartWorkflowOperation, @@ -39,13 +39,6 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus import ( - WorkflowOperationToken, - cancel_operation, - start_workflow, - temporal_operation_context, - workflow_run_operation_handler, -) from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -208,10 +201,10 @@ async def sync_operation( ) return OpOutput(value="sync response") - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def async_operation( self, ctx: StartOperationContext, input: OpInput - ) -> WorkflowOperationToken[HandlerWfOutput]: + ) -> nexus.WorkflowOperationToken[HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: raise RPCError( @@ -580,9 +573,9 @@ async def test_async_response( if op_definition_type == OpDefinitionType.SHORTHAND else "sync_or_async_operation" ) - assert WorkflowOperationToken.decode( + assert nexus.WorkflowOperationToken.decode( e.__cause__.operation_token - ) == WorkflowOperationToken( + ) == nexus.WorkflowOperationToken( namespace=handler_wf_handle._client.namespace, workflow_id=handler_wf_handle.id, ) @@ -919,11 +912,11 @@ async def run(self, input: str) -> str: @service_handler class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: - @workflow_run_operation_handler + @nexus.workflow_run_operation_handler async def my_workflow_run_operation( self, ctx: StartOperationContext, input: None - ) -> WorkflowOperationToken[str]: - tctx = temporal_operation_context.get() + ) -> nexus.WorkflowOperationToken[str]: + tctx = nexus.temporal_operation_context.get() result_1 = await tctx.client.execute_workflow( EchoWorkflow.run, "result-1", @@ -933,7 +926,7 @@ async def my_workflow_run_operation( # In case result_1 is incorrectly being delivered to the caller as the operation # result, give time for that incorrect behavior to occur. await asyncio.sleep(0.5) - return await start_workflow( + return await nexus.start_workflow( EchoWorkflow.run, f"{result_1}-result-2", id=str(uuid.uuid4()), diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index f98673fa7..1fcd3b976 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -12,8 +12,7 @@ service_handler, ) -from temporalio import workflow -from temporalio.nexus import start_workflow +from temporalio import nexus, workflow from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -44,7 +43,7 @@ def __init__(self): async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: - token = await start_workflow( + token = await nexus.start_workflow( EchoWorkflow.run, input.value, id=str(uuid.uuid4()), From 19968d7910996eab67e2fb56eac0c17719f7021a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 23:00:56 -0400 Subject: [PATCH 081/183] Qualify client.WorkflowHandle in temporalio.nexus --- temporalio/nexus/_operation_context.py | 12 ++++++------ temporalio/nexus/_operation_handlers.py | 10 +++++----- temporalio/nexus/_token.py | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 3f811b608..52d17e643 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -17,7 +17,7 @@ import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.common -from temporalio.client import Client, NexusCompletionCallback, WorkflowHandle +from temporalio import client logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class _TemporalNexusOperationContext: nexus_operation_context: Union[StartOperationContext, CancelOperationContext] - client: Client + client: client.Client """The Temporal client in use by the worker handling this Nexus operation.""" task_queue: str @@ -67,7 +67,7 @@ class _TemporalStartOperationContext: def get_completion_callbacks( self, - ) -> list[NexusCompletionCallback]: + ) -> list[client.NexusCompletionCallback]: ctx = self.nexus_operation_context return ( [ @@ -76,7 +76,7 @@ def get_completion_callbacks( # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links # (for backwards compatibility). PR reference in Go SDK: # https://github.com/temporalio/sdk-go/pull/1945 - NexusCompletionCallback( + client.NexusCompletionCallback( url=ctx.callback_url, header=ctx.callback_headers, ) @@ -94,7 +94,7 @@ def get_workflow_event_links( event_links.append(link) return event_links - def add_outbound_links(self, workflow_handle: WorkflowHandle[Any, Any]): + def add_outbound_links(self, workflow_handle: client.WorkflowHandle[Any, Any]): try: link = _workflow_event_to_nexus_link( _workflow_handle_to_workflow_execution_started_event_link( @@ -124,7 +124,7 @@ class _TemporalCancelOperationContext: def _workflow_handle_to_workflow_execution_started_event_link( - handle: WorkflowHandle[Any, Any], + handle: client.WorkflowHandle[Any, Any], ) -> temporalio.api.common.v1.Link.WorkflowEvent: if handle.first_execution_run_id is None: raise ValueError( diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index afda3ebb4..3bc9c8a03 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -24,7 +24,7 @@ OutputT, ) -from temporalio.client import WorkflowHandle +from temporalio import client from temporalio.nexus._operation_context import ( temporal_operation_context, ) @@ -93,12 +93,12 @@ async def start( """ token = await self._start(ctx, input) if not isinstance(token, WorkflowOperationToken): - if isinstance(token, WorkflowHandle): + if isinstance(token, client.WorkflowHandle): raise RuntimeError( - f"Expected {token} to be a WorkflowOperationToken, but got a WorkflowHandle. " - f"You must use :py:meth:`temporalio.nexus.start_workflow` " + f"Expected {token} to be a WorkflowOperationToken, but got a client.WorkflowHandle. " + f"You must use temporalio.nexus.start_workflow " "to start a workflow that will deliver the result of the Nexus operation, " - "not :py:meth:`temporalio.client.Client.start_workflow`." + "not client.Client.start_workflow." ) raise RuntimeError( f"Expected {token} to be a WorkflowOperationToken, but got {type(token)}. " diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index fb7b1852a..fde9fa5cf 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -44,7 +44,7 @@ def _unsafe_from_client_workflow_handle( """Create a :py:class:`WorkflowOperationToken` from a workflow handle. This is a private method not intended to be used by users. It does not check - that the supplied WorkflowHandle references a workflow that has been + that the supplied client.WorkflowHandle references a workflow that has been instrumented to supply the result of a Nexus operation. """ return cls( From fb3238bff8e78136656d61485c94009fdf70300d Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 23:06:01 -0400 Subject: [PATCH 082/183] Fixup: no coverage for these --- tests/nexus/test_workflow_caller.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 18e540d28..8bdb9c8bd 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -28,7 +28,6 @@ import temporalio.api.nexus.v1 import temporalio.api.operatorservice import temporalio.api.operatorservice.v1 -import temporalio.nexus from temporalio import nexus, workflow from temporalio.client import ( Client, @@ -157,7 +156,7 @@ async def start( if isinstance(input.response_type, SyncResponse): return StartOperationResultSync(value=OpOutput(value="sync response")) elif isinstance(input.response_type, AsyncResponse): - token = await start_workflow( + token = await nexus.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, @@ -167,7 +166,7 @@ async def start( raise TypeError async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - return await cancel_operation(token) + return await nexus.cancel_operation(token) async def fetch_info( self, ctx: FetchOperationInfoContext, token: str @@ -212,7 +211,7 @@ async def async_operation( RPCStatusCode.INVALID_ARGUMENT, b"", ) - return await start_workflow( + return await nexus.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, From 0cc4359b9ce933a5c7d2ec1713c32e65720fb999 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 23:05:39 -0400 Subject: [PATCH 083/183] Rename: nexus.WorkflowHandle --- temporalio/nexus/__init__.py | 2 +- temporalio/nexus/_decorators.py | 24 +++++++++---------- temporalio/nexus/_operation_handlers.py | 10 ++++---- temporalio/nexus/_token.py | 10 ++++---- temporalio/nexus/_util.py | 8 +++---- temporalio/nexus/_workflow.py | 8 +++---- temporalio/worker/_interceptor.py | 2 +- temporalio/worker/_workflow_instance.py | 2 +- tests/nexus/test_handler.py | 8 +++---- .../test_handler_interface_implementation.py | 2 +- .../test_handler_operation_definitions.py | 6 ++--- tests/nexus/test_workflow_caller.py | 8 +++---- 12 files changed, 44 insertions(+), 46 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index e2079da26..0bdf56e5c 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -16,7 +16,7 @@ temporal_operation_context as temporal_operation_context, ) from ._operation_handlers import cancel_operation as cancel_operation -from ._token import WorkflowOperationToken as WorkflowOperationToken +from ._token import WorkflowHandle as WorkflowHandle from ._workflow import start_workflow as start_workflow diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index a07bf1150..6f274495c 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -19,7 +19,7 @@ WorkflowRunOperationHandler, ) from temporalio.nexus._token import ( - WorkflowOperationToken, + WorkflowHandle, ) from temporalio.nexus._util import ( get_callable_name, @@ -31,11 +31,11 @@ def workflow_run_operation_handler( start: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ]: ... @@ -47,12 +47,12 @@ def workflow_run_operation_handler( [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ], ]: ... @@ -61,7 +61,7 @@ def workflow_run_operation_handler( start: Optional[ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ] ] = None, *, @@ -69,18 +69,18 @@ def workflow_run_operation_handler( ) -> Union[ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ], Callable[ [ Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ], ], ]: @@ -91,11 +91,11 @@ def workflow_run_operation_handler( def decorator( start: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ]: ( input_type, @@ -107,7 +107,7 @@ def operation_handler_factory( ) -> OperationHandler[InputT, OutputT]: async def _start( ctx: StartOperationContext, input: InputT - ) -> WorkflowOperationToken[OutputT]: + ) -> WorkflowHandle[OutputT]: return await start(self, ctx, input) _start.__doc__ = start.__doc__ diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 3bc9c8a03..bef7c65bf 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -28,7 +28,7 @@ from temporalio.nexus._operation_context import ( temporal_operation_context, ) -from temporalio.nexus._token import WorkflowOperationToken +from temporalio.nexus._token import WorkflowHandle from ._util import ( is_async_callable, @@ -68,7 +68,7 @@ def __init__( self, start: Callable[ [StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ], input_type: Optional[Type[InputT]], output_type: Optional[Type[OutputT]], @@ -92,7 +92,7 @@ async def start( Start the operation, by starting a workflow and completing asynchronously. """ token = await self._start(ctx, input) - if not isinstance(token, WorkflowOperationToken): + if not isinstance(token, WorkflowHandle): if isinstance(token, client.WorkflowHandle): raise RuntimeError( f"Expected {token} to be a WorkflowOperationToken, but got a client.WorkflowHandle. " @@ -124,7 +124,7 @@ async def fetch_result( ) # An implementation is provided for future reference: try: - workflow_token = WorkflowOperationToken[OutputT].decode(token) + workflow_token = WorkflowHandle[OutputT].decode(token) except Exception as err: raise HandlerError( "Failed to decode operation token as workflow operation token. " @@ -157,7 +157,7 @@ async def cancel_operation( client: The client to use to cancel the operation. """ try: - workflow_token = WorkflowOperationToken[Any].decode(token) + workflow_token = WorkflowHandle[Any].decode(token) except Exception as err: raise HandlerError( "Failed to decode operation token as workflow operation token. " diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index fde9fa5cf..b80d23738 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -14,8 +14,8 @@ @dataclass(frozen=True) -class WorkflowOperationToken(Generic[OutputT]): - """A Nexus operation token for an operation backed by a workflow.""" +class WorkflowHandle(Generic[OutputT]): + """A handle to a workflow that is backing a Nexus operation.""" namespace: str workflow_id: str @@ -40,8 +40,8 @@ def _to_client_workflow_handle( @classmethod def _unsafe_from_client_workflow_handle( cls, workflow_handle: client.WorkflowHandle[Any, OutputT] - ) -> WorkflowOperationToken[OutputT]: - """Create a :py:class:`WorkflowOperationToken` from a workflow handle. + ) -> WorkflowHandle[OutputT]: + """Create a :py:class:`WorkflowHandle` from a :py:class:`temporalio.client.WorkflowHandle`. This is a private method not intended to be used by users. It does not check that the supplied client.WorkflowHandle references a workflow that has been @@ -65,7 +65,7 @@ def encode(self) -> str: ) @classmethod - def decode(cls, token: str) -> WorkflowOperationToken[OutputT]: + def decode(cls, token: str) -> WorkflowHandle[OutputT]: """Decodes and validates a token from its base64url-encoded string representation.""" if not token: raise TypeError("invalid workflow token: token is empty") diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index b7adc9e2f..2f8277e13 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -23,14 +23,14 @@ ) from ._token import ( - WorkflowOperationToken as WorkflowOperationToken, + WorkflowHandle as WorkflowHandle, ) def get_workflow_run_start_method_input_and_output_type_annotations( start: Callable[ [ServiceHandlerT, StartOperationContext, InputT], - Awaitable[WorkflowOperationToken[OutputT]], + Awaitable[WorkflowHandle[OutputT]], ], ) -> tuple[ Optional[Type[InputT]], @@ -39,13 +39,13 @@ def get_workflow_run_start_method_input_and_output_type_annotations( """Return operation input and output types. `start` must be a type-annotated start method that returns a - :py:class:`WorkflowHandle`. + :py:class:`temporalio.nexus.WorkflowHandle`. """ input_type, output_type = get_start_method_input_and_output_type_annotations(start) origin_type = typing.get_origin(output_type) if not origin_type: output_type = None - elif not issubclass(origin_type, WorkflowOperationToken): + elif not issubclass(origin_type, WorkflowHandle): warnings.warn( f"Expected return type of {start.__name__} to be a subclass of WorkflowOperationToken, " f"but is {output_type}" diff --git a/temporalio/nexus/_workflow.py b/temporalio/nexus/_workflow.py index d022cecec..aa8c28f1c 100644 --- a/temporalio/nexus/_workflow.py +++ b/temporalio/nexus/_workflow.py @@ -13,7 +13,7 @@ import temporalio.api.enums.v1 import temporalio.common from temporalio.nexus._operation_context import temporal_operation_context -from temporalio.nexus._token import WorkflowOperationToken +from temporalio.nexus._token import WorkflowHandle from temporalio.types import ( MethodAsyncSingleParam, ParamType, @@ -54,7 +54,7 @@ async def start_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, -) -> WorkflowOperationToken[ReturnType]: +) -> WorkflowHandle[ReturnType]: """Start a workflow that will deliver the result of the Nexus operation. The workflow will be started in the same namespace as the Nexus worker, using @@ -134,6 +134,4 @@ async def start_workflow( start_operation_context.add_outbound_links(wf_handle) - return WorkflowOperationToken[ReturnType]._unsafe_from_client_workflow_handle( - wf_handle - ) + return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 6e703033d..59f8c8671 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -302,7 +302,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): Callable[[Any], nexusrpc.handler.OperationHandler[InputT, OutputT]], Callable[ [Any, nexusrpc.handler.StartOperationContext, InputT], - Awaitable[temporalio.nexus.WorkflowOperationToken[OutputT]], + Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], ], str, ] diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 98fcbadaf..cc7397601 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1503,7 +1503,7 @@ async def workflow_start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [Any, nexusrpc.handler.StartOperationContext, I], - Awaitable[temporalio.nexus.WorkflowOperationToken[O]], + Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, ], diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 65731fbb7..696d64524 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -209,7 +209,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: @nexus.workflow_run_operation_handler async def workflow_run_operation( self, ctx: StartOperationContext, input: Input - ) -> nexus.WorkflowOperationToken[Output]: + ) -> nexus.WorkflowHandle[Output]: return await nexus.start_workflow( MyWorkflow.run, input, @@ -268,7 +268,7 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): @nexus.workflow_run_operation_handler async def workflow_run_op_link_test( self, ctx: StartOperationContext, input: Input - ) -> nexus.WorkflowOperationToken[Output]: + ) -> nexus.WorkflowHandle[Output]: assert any( link.url == "http://inbound-link/" for link in ctx.inbound_links ), "Inbound link not found" @@ -1052,7 +1052,7 @@ class ServiceHandlerForRequestIdTest: @nexus.workflow_run_operation_handler async def operation_backed_by_a_workflow( self, ctx: StartOperationContext, input: Input - ) -> nexus.WorkflowOperationToken[Output]: + ) -> nexus.WorkflowHandle[Output]: return await nexus.start_workflow( EchoWorkflow.run, input, @@ -1063,7 +1063,7 @@ async def operation_backed_by_a_workflow( @nexus.workflow_run_operation_handler async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, ctx: StartOperationContext, input: Input - ) -> nexus.WorkflowOperationToken[Output]: + ) -> nexus.WorkflowHandle[Output]: tctx = nexus.temporal_operation_context.get() await tctx.client.start_workflow( EchoWorkflow.run, diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index b1e37dda9..b1fdefa2f 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -37,7 +37,7 @@ class Impl: @nexus.workflow_run_operation_handler async def op( self, ctx: StartOperationContext, input: str - ) -> nexus.WorkflowOperationToken[int]: ... + ) -> nexus.WorkflowHandle[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 61912988e..a2288a37c 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -35,7 +35,7 @@ class Service: @nexus.workflow_run_operation_handler async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input - ) -> nexus.WorkflowOperationToken[Output]: ... + ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { "my_workflow_run_operation_handler": nexusrpc.Operation( @@ -53,7 +53,7 @@ class Service: @nexus.workflow_run_operation_handler async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input - ) -> nexus.WorkflowOperationToken[Output]: ... + ) -> nexus.WorkflowHandle[Output]: ... expected_operations = NotCalled.expected_operations @@ -64,7 +64,7 @@ class Service: @nexus.workflow_run_operation_handler(name="operation-name") async def workflow_run_operation_with_name_override( self, ctx: StartOperationContext, input: Input - ) -> nexus.WorkflowOperationToken[Output]: ... + ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { "workflow_run_operation_with_name_override": nexusrpc.Operation( diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 8bdb9c8bd..28b412cf1 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -203,7 +203,7 @@ async def sync_operation( @nexus.workflow_run_operation_handler async def async_operation( self, ctx: StartOperationContext, input: OpInput - ) -> nexus.WorkflowOperationToken[HandlerWfOutput]: + ) -> nexus.WorkflowHandle[HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: raise RPCError( @@ -572,9 +572,9 @@ async def test_async_response( if op_definition_type == OpDefinitionType.SHORTHAND else "sync_or_async_operation" ) - assert nexus.WorkflowOperationToken.decode( + assert nexus.WorkflowHandle.decode( e.__cause__.operation_token - ) == nexus.WorkflowOperationToken( + ) == nexus.WorkflowHandle( namespace=handler_wf_handle._client.namespace, workflow_id=handler_wf_handle.id, ) @@ -914,7 +914,7 @@ class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: @nexus.workflow_run_operation_handler async def my_workflow_run_operation( self, ctx: StartOperationContext, input: None - ) -> nexus.WorkflowOperationToken[str]: + ) -> nexus.WorkflowHandle[str]: tctx = nexus.temporal_operation_context.get() result_1 = await tctx.client.execute_workflow( EchoWorkflow.run, From 0c1982bc8fab3c14d956e3bc62256f9aad4a1750 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 25 Jun 2025 23:10:19 -0400 Subject: [PATCH 084/183] nexus.WorkflowHandle.{to,from}_token() --- temporalio/nexus/_operation_handlers.py | 26 ++++++++++++---------- temporalio/nexus/_token.py | 4 ++-- tests/nexus/test_workflow_caller.py | 6 ++--- tests/nexus/test_workflow_run_operation.py | 4 ++-- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index bef7c65bf..06faf246a 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -91,19 +91,19 @@ async def start( """ Start the operation, by starting a workflow and completing asynchronously. """ - token = await self._start(ctx, input) - if not isinstance(token, WorkflowHandle): - if isinstance(token, client.WorkflowHandle): + handle = await self._start(ctx, input) + if not isinstance(handle, WorkflowHandle): + if isinstance(handle, client.WorkflowHandle): raise RuntimeError( - f"Expected {token} to be a WorkflowOperationToken, but got a client.WorkflowHandle. " + f"Expected {handle} to be a WorkflowOperationToken, but got a client.WorkflowHandle. " f"You must use temporalio.nexus.start_workflow " "to start a workflow that will deliver the result of the Nexus operation, " "not client.Client.start_workflow." ) raise RuntimeError( - f"Expected {token} to be a WorkflowOperationToken, but got {type(token)}. " + f"Expected {handle} to be a WorkflowOperationToken, but got {type(handle)}. " ) - return StartOperationResultAsync(token.encode()) + return StartOperationResultAsync(handle.to_token()) async def cancel(self, ctx: CancelOperationContext, token: str) -> None: """Cancel the operation, by cancelling the workflow.""" @@ -124,7 +124,7 @@ async def fetch_result( ) # An implementation is provided for future reference: try: - workflow_token = WorkflowHandle[OutputT].decode(token) + nexus_handle = WorkflowHandle[OutputT].from_token(token) except Exception as err: raise HandlerError( "Failed to decode operation token as workflow operation token. " @@ -134,7 +134,7 @@ async def fetch_result( ) ctx = temporal_operation_context.get() try: - handle = workflow_token.to_workflow_handle( + client_handle = nexus_handle.to_workflow_handle( ctx.client, result_type=self._output_type ) except Exception as err: @@ -143,7 +143,7 @@ async def fetch_result( type=HandlerErrorType.NOT_FOUND, cause=err, ) - return await handle.result() + return await client_handle.result() async def cancel_operation( @@ -157,7 +157,7 @@ async def cancel_operation( client: The client to use to cancel the operation. """ try: - workflow_token = WorkflowHandle[Any].decode(token) + nexus_workflow_handle = WorkflowHandle[Any].from_token(token) except Exception as err: raise HandlerError( "Failed to decode operation token as workflow operation token. " @@ -168,11 +168,13 @@ async def cancel_operation( ctx = temporal_operation_context.get() try: - handle = workflow_token._to_client_workflow_handle(ctx.client) + client_workflow_handle = nexus_workflow_handle._to_client_workflow_handle( + ctx.client + ) except Exception as err: raise HandlerError( "Failed to construct workflow handle from workflow operation token", type=HandlerErrorType.NOT_FOUND, cause=err, ) - await handle.cancel(**kwargs) + await client_workflow_handle.cancel(**kwargs) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index b80d23738..9f2957888 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -52,7 +52,7 @@ def _unsafe_from_client_workflow_handle( workflow_id=workflow_handle.id, ) - def encode(self) -> str: + def to_token(self) -> str: return _base64url_encode_no_padding( json.dumps( { @@ -65,7 +65,7 @@ def encode(self) -> str: ) @classmethod - def decode(cls, token: str) -> WorkflowHandle[OutputT]: + def from_token(cls, token: str) -> WorkflowHandle[OutputT]: """Decodes and validates a token from its base64url-encoded string representation.""" if not token: raise TypeError("invalid workflow token: token is empty") diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 28b412cf1..3dae44e61 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -156,12 +156,12 @@ async def start( if isinstance(input.response_type, SyncResponse): return StartOperationResultSync(value=OpOutput(value="sync response")) elif isinstance(input.response_type, AsyncResponse): - token = await nexus.start_workflow( + handle = await nexus.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, ) - return StartOperationResultAsync(token.encode()) + return StartOperationResultAsync(handle.to_token()) else: raise TypeError @@ -572,7 +572,7 @@ async def test_async_response( if op_definition_type == OpDefinitionType.SHORTHAND else "sync_or_async_operation" ) - assert nexus.WorkflowHandle.decode( + assert nexus.WorkflowHandle.from_token( e.__cause__.operation_token ) == nexus.WorkflowHandle( namespace=handler_wf_handle._client.namespace, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 1fcd3b976..ebf3a7ba9 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -43,12 +43,12 @@ def __init__(self): async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: - token = await nexus.start_workflow( + handle = await nexus.start_workflow( EchoWorkflow.run, input.value, id=str(uuid.uuid4()), ) - return StartOperationResultAsync(token.encode()) + return StartOperationResultAsync(handle.to_token()) @service_handler From fa0344b91566123588dd3c26e0bf0be397cc65ed Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 06:50:38 -0400 Subject: [PATCH 085/183] Respond to upstream: sync_operation_handler -> sync_operation --- ...ynamic_creation_of_user_handler_classes.py | 4 +-- tests/nexus/test_handler.py | 34 +++++++++---------- .../test_handler_interface_implementation.py | 4 +-- tests/nexus/test_workflow_caller.py | 12 +++---- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 39d0b8f72..b15257a45 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -3,7 +3,7 @@ import httpx import nexusrpc.handler import pytest -from nexusrpc.handler import sync_operation_handler +from nexusrpc.handler import sync_operation from nexusrpc.handler._util import get_operation_factory from temporalio.client import Client @@ -33,7 +33,7 @@ def make_incrementer_user_service_definition_and_service_handler_classes( # # service handler # - @sync_operation_handler + @sync_operation async def _increment_op( self, ctx: nexusrpc.handler.StartOperationContext, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 696d64524..bdfc28bfb 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -42,7 +42,7 @@ StartOperationContext, operation_handler, service_handler, - sync_operation_handler, + sync_operation, ) import temporalio.api.failure.v1 @@ -134,7 +134,7 @@ async def run(self, input: Input) -> Output: # The service_handler decorator is applied by the test class MyServiceHandler: - @sync_operation_handler + @sync_operation async def echo(self, ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) @@ -142,12 +142,12 @@ async def echo(self, ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) - @sync_operation_handler + @sync_operation async def hang(self, ctx: StartOperationContext, input: Input) -> Output: await asyncio.Future() return Output(value="won't reach here") - @sync_operation_handler + @sync_operation async def non_retryable_application_error( self, ctx: StartOperationContext, input: Input ) -> Output: @@ -159,7 +159,7 @@ async def non_retryable_application_error( non_retryable=True, ) - @sync_operation_handler + @sync_operation async def retryable_application_error( self, ctx: StartOperationContext, input: Input ) -> Output: @@ -170,7 +170,7 @@ async def retryable_application_error( non_retryable=False, ) - @sync_operation_handler + @sync_operation async def handler_error_internal( self, ctx: StartOperationContext, input: Input ) -> Output: @@ -181,7 +181,7 @@ async def handler_error_internal( cause=RuntimeError("cause message"), ) - @sync_operation_handler + @sync_operation async def operation_error_failed( self, ctx: StartOperationContext, input: Input ) -> Output: @@ -190,7 +190,7 @@ async def operation_error_failed( state=OperationErrorState.FAILED, ) - @sync_operation_handler + @sync_operation async def check_operation_timeout_header( self, ctx: StartOperationContext, input: Input ) -> Output: @@ -199,7 +199,7 @@ async def check_operation_timeout_header( value=f"from start method on {self.__class__.__name__}: {input.value}" ) - @sync_operation_handler + @sync_operation async def log(self, ctx: StartOperationContext, input: Input) -> Output: nexus.logger.info( "Logging from start method", extra={"input_value": input.value} @@ -217,7 +217,7 @@ async def workflow_run_operation( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - @sync_operation_handler + @sync_operation async def sync_operation_with_non_async_def( self, ctx: StartOperationContext, input: Input ) -> Output: @@ -240,7 +240,7 @@ def __call__( value=f"from start method on {self.__class__.__name__}: {input.value}" ) - return sync_operation_handler(start()) + return sync_operation(start()) _sync_operation_with_non_async_callable_instance = operation_handler( name="sync_operation_with_non_async_callable_instance", @@ -248,7 +248,7 @@ def __call__( sync_operation_with_non_async_callable_instance, ) - @sync_operation_handler + @sync_operation async def sync_operation_without_type_annotations( self, ctx: StartOperationContext, input: Input ) -> Output: @@ -312,13 +312,13 @@ def operation_returning_unwrapped_result_at_runtime_error( ) -> OperationHandler[Input, Output]: return MyServiceHandler.OperationHandlerReturningUnwrappedResult() - @sync_operation_handler + @sync_operation async def idempotency_check( self, ctx: StartOperationContext, input: None ) -> Output: return Output(value=f"request_id: {ctx.request_id}") - @sync_operation_handler + @sync_operation async def non_serializable_output( self, ctx: StartOperationContext, input: Input ) -> NonSerializableOutput: @@ -890,7 +890,7 @@ class EchoService: @service_handler(service=EchoService) class SyncStartHandler: # TODO(nexus-prerelease): why is this test passing? start is not `async def` - @sync_operation_handler + @sync_operation def echo(self, ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) @@ -901,7 +901,7 @@ def echo(self, ctx: StartOperationContext, input: Input) -> Output: @service_handler(service=EchoService) class DefaultCancelHandler: - @sync_operation_handler + @sync_operation async def echo(self, ctx: StartOperationContext, input: Input) -> Output: return Output( value=f"from start method on {self.__class__.__name__}: {input.value}" @@ -1012,7 +1012,7 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): assert "failed to decode operation token" in failure.message.lower() -async def test_request_id_is_received_by_sync_operation_handler( +async def test_request_id_is_received_by_sync_operation( env: WorkflowEnvironment, ): task_queue = str(uuid.uuid4()) diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index b1fdefa2f..e561c054d 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -3,7 +3,7 @@ import nexusrpc import nexusrpc.handler import pytest -from nexusrpc.handler import StartOperationContext, sync_operation_handler +from nexusrpc.handler import StartOperationContext, sync_operation from temporalio import nexus @@ -22,7 +22,7 @@ class Interface: op: nexusrpc.Operation[None, None] class Impl: - @sync_operation_handler + @sync_operation async def op(self, ctx: StartOperationContext, input: None) -> None: ... error_message = None diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3dae44e61..37915e536 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -17,7 +17,7 @@ StartOperationResultSync, operation_handler, service_handler, - sync_operation_handler, + sync_operation, ) import temporalio.api @@ -187,7 +187,7 @@ def sync_or_async_operation( ) -> OperationHandler[OpInput, OpOutput]: return SyncOrAsyncOperation() - @sync_operation_handler + @sync_operation async def sync_operation( self, ctx: StartOperationContext, input: OpInput ) -> OpOutput: @@ -743,7 +743,7 @@ class ServiceInterfaceWithNameOverride: @service_handler class ServiceImplInterfaceWithNeitherInterfaceNorNameOverride: - @sync_operation_handler + @sync_operation async def op( self, ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: @@ -752,7 +752,7 @@ async def op( @service_handler(service=ServiceInterfaceWithoutNameOverride) class ServiceImplInterfaceWithoutNameOverride: - @sync_operation_handler + @sync_operation async def op( self, ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: @@ -761,7 +761,7 @@ async def op( @service_handler(service=ServiceInterfaceWithNameOverride) class ServiceImplInterfaceWithNameOverride: - @sync_operation_handler + @sync_operation async def op( self, ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: @@ -770,7 +770,7 @@ async def op( @service_handler(name="service-impl-🌈") class ServiceImplWithNameOverride: - @sync_operation_handler + @sync_operation async def op( self, ctx: StartOperationContext, input: None ) -> ServiceClassNameOutput: From f1bd90d9c147fca0ec5e1bc4e6cc060cb802971f Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 06:51:06 -0400 Subject: [PATCH 086/183] workflow_run_operation_handler -> workflow_run_operation --- temporalio/nexus/__init__.py | 8 ++------ temporalio/nexus/_decorators.py | 6 +++--- tests/nexus/test_handler.py | 10 +++++----- tests/nexus/test_handler_interface_implementation.py | 2 +- tests/nexus/test_handler_operation_definitions.py | 6 +++--- tests/nexus/test_workflow_caller.py | 4 ++-- 6 files changed, 16 insertions(+), 20 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 0bdf56e5c..e4b69325d 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -6,15 +6,11 @@ Optional, ) -from ._decorators import ( - workflow_run_operation_handler as workflow_run_operation_handler, -) +from ._decorators import workflow_run_operation as workflow_run_operation from ._operation_context import ( _TemporalNexusOperationContext as _TemporalNexusOperationContext, ) -from ._operation_context import ( - temporal_operation_context as temporal_operation_context, -) +from ._operation_context import temporal_operation_context as temporal_operation_context from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowHandle as WorkflowHandle from ._workflow import start_workflow as start_workflow diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 6f274495c..868c68da2 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -28,7 +28,7 @@ @overload -def workflow_run_operation_handler( +def workflow_run_operation( start: Callable[ [ServiceHandlerT, StartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], @@ -40,7 +40,7 @@ def workflow_run_operation_handler( @overload -def workflow_run_operation_handler( +def workflow_run_operation( *, name: Optional[str] = None, ) -> Callable[ @@ -57,7 +57,7 @@ def workflow_run_operation_handler( ]: ... -def workflow_run_operation_handler( +def workflow_run_operation( start: Optional[ Callable[ [ServiceHandlerT, StartOperationContext, InputT], diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index bdfc28bfb..a82c86bc6 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -206,7 +206,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: ) return Output(value=f"logged: {input.value}") - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def workflow_run_operation( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: @@ -257,7 +257,7 @@ async def sync_operation_without_type_annotations( value=f"from start method on {self.__class__.__name__} without type annotations: {input}" ) - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def workflow_run_operation_without_type_annotations(self, ctx, input): return await nexus.start_workflow( WorkflowWithoutTypeAnnotations.run, @@ -265,7 +265,7 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): id=str(uuid.uuid4()), ) - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def workflow_run_op_link_test( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: @@ -1049,7 +1049,7 @@ async def run(self, input: Input) -> Output: @service_handler class ServiceHandlerForRequestIdTest: - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def operation_backed_by_a_workflow( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: @@ -1060,7 +1060,7 @@ async def operation_backed_by_a_workflow( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index e561c054d..d8ece15b2 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -34,7 +34,7 @@ class Interface: op: nexusrpc.Operation[str, int] class Impl: - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def op( self, ctx: StartOperationContext, input: str ) -> nexus.WorkflowHandle[int]: ... diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index a2288a37c..e564cfd76 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -32,7 +32,7 @@ class _TestCase: class NotCalled(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... @@ -50,7 +50,7 @@ async def my_workflow_run_operation_handler( class CalledWithoutArgs(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... @@ -61,7 +61,7 @@ async def my_workflow_run_operation_handler( class CalledWithNameOverride(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexus.workflow_run_operation_handler(name="operation-name") + @nexus.workflow_run_operation(name="operation-name") async def workflow_run_operation_with_name_override( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 37915e536..2295e0c40 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -200,7 +200,7 @@ async def sync_operation( ) return OpOutput(value="sync response") - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def async_operation( self, ctx: StartOperationContext, input: OpInput ) -> nexus.WorkflowHandle[HandlerWfOutput]: @@ -911,7 +911,7 @@ async def run(self, input: str) -> str: @service_handler class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: - @nexus.workflow_run_operation_handler + @nexus.workflow_run_operation async def my_workflow_run_operation( self, ctx: StartOperationContext, input: None ) -> nexus.WorkflowHandle[str]: From 3526b89db58d4f53aeff8e36f3abf405ff3ec7fc Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 06:54:19 -0400 Subject: [PATCH 087/183] from nexus import workflow_run_operation --- tests/nexus/test_handler.py | 19 ++++++++++--------- .../test_handler_interface_implementation.py | 3 ++- .../test_handler_operation_definitions.py | 7 ++++--- tests/nexus/test_workflow_caller.py | 5 +++-- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index a82c86bc6..8c3a653c2 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -51,6 +51,7 @@ from temporalio.common import WorkflowIDReusePolicy from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError +from temporalio.nexus import workflow_run_operation from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict @@ -89,7 +90,7 @@ class MyService: # ) hang: nexusrpc.Operation[Input, Output] log: nexusrpc.Operation[Input, Output] - workflow_run_operation: nexusrpc.Operation[Input, Output] + workflow_run_operation_happy_path: nexusrpc.Operation[Input, Output] workflow_run_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_with_non_async_def: nexusrpc.Operation[Input, Output] @@ -206,8 +207,8 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: ) return Output(value=f"logged: {input.value}") - @nexus.workflow_run_operation - async def workflow_run_operation( + @workflow_run_operation + async def workflow_run_operation_happy_path( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: return await nexus.start_workflow( @@ -257,7 +258,7 @@ async def sync_operation_without_type_annotations( value=f"from start method on {self.__class__.__name__} without type annotations: {input}" ) - @nexus.workflow_run_operation + @workflow_run_operation async def workflow_run_operation_without_type_annotations(self, ctx, input): return await nexus.start_workflow( WorkflowWithoutTypeAnnotations.run, @@ -265,7 +266,7 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): id=str(uuid.uuid4()), ) - @nexus.workflow_run_operation + @workflow_run_operation async def workflow_run_op_link_test( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: @@ -537,7 +538,7 @@ class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): class AsyncHandlerHappyPath(_TestCase): - operation = "workflow_run_operation" + operation = "workflow_run_operation_happy_path" input = Input("hello") headers = {"Operation-Timeout": "777s"} expected = SuccessfulResponse( @@ -1004,7 +1005,7 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): cancel_response = await service_client.cancel_operation( - "workflow_run_operation", + "workflow_run_operation_happy_path", token="this-is-not-a-valid-token", ) assert cancel_response.status_code == 404 @@ -1049,7 +1050,7 @@ async def run(self, input: Input) -> Output: @service_handler class ServiceHandlerForRequestIdTest: - @nexus.workflow_run_operation + @workflow_run_operation async def operation_backed_by_a_workflow( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: @@ -1060,7 +1061,7 @@ async def operation_backed_by_a_workflow( id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, ) - @nexus.workflow_run_operation + @workflow_run_operation async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index d8ece15b2..114b1dacc 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -6,6 +6,7 @@ from nexusrpc.handler import StartOperationContext, sync_operation from temporalio import nexus +from temporalio.nexus import workflow_run_operation HTTP_PORT = 7243 @@ -34,7 +35,7 @@ class Interface: op: nexusrpc.Operation[str, int] class Impl: - @nexus.workflow_run_operation + @workflow_run_operation async def op( self, ctx: StartOperationContext, input: str ) -> nexus.WorkflowHandle[int]: ... diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index e564cfd76..cb0b41c99 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -11,6 +11,7 @@ from nexusrpc.handler import StartOperationContext from temporalio import nexus +from temporalio.nexus import workflow_run_operation @dataclass @@ -32,7 +33,7 @@ class _TestCase: class NotCalled(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexus.workflow_run_operation + @workflow_run_operation async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... @@ -50,7 +51,7 @@ async def my_workflow_run_operation_handler( class CalledWithoutArgs(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexus.workflow_run_operation + @workflow_run_operation async def my_workflow_run_operation_handler( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... @@ -61,7 +62,7 @@ async def my_workflow_run_operation_handler( class CalledWithNameOverride(_TestCase): @nexusrpc.handler.service_handler class Service: - @nexus.workflow_run_operation(name="operation-name") + @workflow_run_operation(name="operation-name") async def workflow_run_operation_with_name_override( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 2295e0c40..e9e96ce5f 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -38,6 +38,7 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError +from temporalio.nexus import workflow_run_operation from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -200,7 +201,7 @@ async def sync_operation( ) return OpOutput(value="sync response") - @nexus.workflow_run_operation + @workflow_run_operation async def async_operation( self, ctx: StartOperationContext, input: OpInput ) -> nexus.WorkflowHandle[HandlerWfOutput]: @@ -911,7 +912,7 @@ async def run(self, input: str) -> str: @service_handler class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: - @nexus.workflow_run_operation + @workflow_run_operation async def my_workflow_run_operation( self, ctx: StartOperationContext, input: None ) -> nexus.WorkflowHandle[str]: From 29f11cace5550d6f6b14e7d18c13e11c131145ec Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 07:32:05 -0400 Subject: [PATCH 088/183] Respond to upstream: operation_handler is not in the public API --- tests/nexus/test_handler.py | 2 +- tests/nexus/test_handler_async_operation.py | 10 ++++++---- tests/nexus/test_workflow_caller.py | 2 +- tests/nexus/test_workflow_run_operation.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 8c3a653c2..e85471fe8 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -40,10 +40,10 @@ OperationErrorState, OperationHandler, StartOperationContext, - operation_handler, service_handler, sync_operation, ) +from nexusrpc.handler._decorators import operation_handler import temporalio.api.failure.v1 from temporalio import nexus, workflow diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py index 19d4f0ae1..82280f5bd 100644 --- a/tests/nexus/test_handler_async_operation.py +++ b/tests/nexus/test_handler_async_operation.py @@ -23,7 +23,9 @@ OperationHandler, StartOperationContext, StartOperationResultAsync, + service_handler, ) +from nexusrpc.handler._decorators import operation_handler from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -109,21 +111,21 @@ def cancel(self, ctx: CancelOperationContext, token: str) -> None: @dataclass -@nexusrpc.handler.service_handler +@service_handler class MyServiceHandlerWithAsyncDefs: executor: TaskExecutor - @nexusrpc.handler.operation_handler + @operation_handler def async_operation(self) -> OperationHandler[Input, Output]: return AsyncOperationWithAsyncDefs(self.executor) @dataclass -@nexusrpc.handler.service_handler +@service_handler class MyServiceHandlerWithNonAsyncDefs: executor: TaskExecutor - @nexusrpc.handler.operation_handler + @operation_handler def async_operation(self) -> OperationHandler[Input, Output]: return AsyncOperationWithNonAsyncDefs(self.executor) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index e9e96ce5f..09043eba1 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -15,10 +15,10 @@ StartOperationContext, StartOperationResultAsync, StartOperationResultSync, - operation_handler, service_handler, sync_operation, ) +from nexusrpc.handler._decorators import operation_handler import temporalio.api import temporalio.api.common diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index ebf3a7ba9..d45895222 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -8,9 +8,9 @@ OperationHandler, StartOperationContext, StartOperationResultAsync, - operation_handler, service_handler, ) +from nexusrpc.handler._decorators import operation_handler from temporalio import nexus, workflow from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler From 97f1d48772592e94f31646c8a455f0fc8357bc8b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 07:41:31 -0400 Subject: [PATCH 089/183] New nexus operation context API - Expose nexus.client(), nexus.info() - Don't expose contextvar --- temporalio/nexus/__init__.py | 11 +++-- temporalio/nexus/_operation_context.py | 55 ++++++++++++++++++++----- temporalio/nexus/_operation_handlers.py | 6 +-- temporalio/nexus/_workflow.py | 10 ++--- temporalio/worker/_nexus.py | 11 ++--- tests/nexus/test_handler.py | 32 +++++++++----- tests/nexus/test_workflow_caller.py | 5 +-- 7 files changed, 90 insertions(+), 40 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index e4b69325d..0c42dd564 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -7,10 +7,15 @@ ) from ._decorators import workflow_run_operation as workflow_run_operation +from ._operation_context import Info as Info +from ._operation_context import ( + _temporal_operation_context as _temporal_operation_context, +) from ._operation_context import ( _TemporalNexusOperationContext as _TemporalNexusOperationContext, ) -from ._operation_context import temporal_operation_context as temporal_operation_context +from ._operation_context import client as client +from ._operation_context import info as info from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowHandle as WorkflowHandle from ._workflow import start_workflow as start_workflow @@ -24,10 +29,10 @@ def process( self, msg: Any, kwargs: MutableMapping[str, Any] ) -> tuple[Any, MutableMapping[str, Any]]: extra = dict(self.extra or {}) - if tctx := temporal_operation_context.get(None): + if tctx := _temporal_operation_context.get(None): extra["service"] = tctx.nexus_operation_context.service extra["operation"] = tctx.nexus_operation_context.operation - extra["task_queue"] = tctx.task_queue + extra["task_queue"] = tctx.info().task_queue kwargs["extra"] = extra | kwargs.get("extra", {}) return msg, kwargs diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 52d17e643..e00f490fd 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import ( Any, + Callable, Optional, Union, ) @@ -16,31 +17,61 @@ import temporalio.api.common.v1 import temporalio.api.enums.v1 +import temporalio.client import temporalio.common -from temporalio import client logger = logging.getLogger(__name__) -# TODO(nexus-prerelease): Confirm how exactly we want to expose Temporal Nexus operation context - -temporal_operation_context: ContextVar[_TemporalNexusOperationContext] = ContextVar( +_temporal_operation_context: ContextVar[_TemporalNexusOperationContext] = ContextVar( "temporal-operation-context" ) +@dataclass(frozen=True) +class Info: + """Information about the running Nexus operation. + + Retrieved inside a Nexus operation handler via :py:func:`info`. + """ + + task_queue: str + """The task queue of the worker handling this Nexus operation.""" + + +def info() -> Info: + """ + Get the current Nexus operation information. + """ + return _TemporalNexusOperationContext.get().info() + + +def client() -> temporalio.client.Client: + """ + Get the Temporal client used by the worker handling the current Nexus operation. + """ + return _TemporalNexusOperationContext.get().client + + @dataclass class _TemporalNexusOperationContext: """ Context for a Nexus operation being handled by a Temporal Nexus Worker. """ + info: Callable[[], Info] + """Information about the running Nexus operation.""" + nexus_operation_context: Union[StartOperationContext, CancelOperationContext] - client: client.Client + client: temporalio.client.Client """The Temporal client in use by the worker handling this Nexus operation.""" - task_queue: str - """The task queue of the worker handling this Nexus operation.""" + @classmethod + def get(cls) -> _TemporalNexusOperationContext: + ctx = _temporal_operation_context.get(None) + if ctx is None: + raise RuntimeError("Not in Nexus operation context.") + return ctx @property def _temporal_start_operation_context( @@ -67,7 +98,7 @@ class _TemporalStartOperationContext: def get_completion_callbacks( self, - ) -> list[client.NexusCompletionCallback]: + ) -> list[temporalio.client.NexusCompletionCallback]: ctx = self.nexus_operation_context return ( [ @@ -76,7 +107,7 @@ def get_completion_callbacks( # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links # (for backwards compatibility). PR reference in Go SDK: # https://github.com/temporalio/sdk-go/pull/1945 - client.NexusCompletionCallback( + temporalio.client.NexusCompletionCallback( url=ctx.callback_url, header=ctx.callback_headers, ) @@ -94,7 +125,9 @@ def get_workflow_event_links( event_links.append(link) return event_links - def add_outbound_links(self, workflow_handle: client.WorkflowHandle[Any, Any]): + def add_outbound_links( + self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] + ): try: link = _workflow_event_to_nexus_link( _workflow_handle_to_workflow_execution_started_event_link( @@ -124,7 +157,7 @@ class _TemporalCancelOperationContext: def _workflow_handle_to_workflow_execution_started_event_link( - handle: client.WorkflowHandle[Any, Any], + handle: temporalio.client.WorkflowHandle[Any, Any], ) -> temporalio.api.common.v1.Link.WorkflowEvent: if handle.first_execution_run_id is None: raise ValueError( diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 06faf246a..31083dfe6 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -26,7 +26,7 @@ from temporalio import client from temporalio.nexus._operation_context import ( - temporal_operation_context, + _temporal_operation_context, ) from temporalio.nexus._token import WorkflowHandle @@ -132,7 +132,7 @@ async def fetch_result( type=HandlerErrorType.NOT_FOUND, cause=err, ) - ctx = temporal_operation_context.get() + ctx = _temporal_operation_context.get() try: client_handle = nexus_handle.to_workflow_handle( ctx.client, result_type=self._output_type @@ -166,7 +166,7 @@ async def cancel_operation( cause=err, ) - ctx = temporal_operation_context.get() + ctx = _temporal_operation_context.get() try: client_workflow_handle = nexus_workflow_handle._to_client_workflow_handle( ctx.client diff --git a/temporalio/nexus/_workflow.py b/temporalio/nexus/_workflow.py index aa8c28f1c..272610533 100644 --- a/temporalio/nexus/_workflow.py +++ b/temporalio/nexus/_workflow.py @@ -12,7 +12,7 @@ import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.common -from temporalio.nexus._operation_context import temporal_operation_context +from temporalio.nexus._operation_context import _TemporalNexusOperationContext from temporalio.nexus._token import WorkflowHandle from temporalio.types import ( MethodAsyncSingleParam, @@ -83,8 +83,8 @@ async def start_workflow( Nexus caller is itself a workflow, this means that the workflow in the caller namespace web UI will contain links to the started workflow, and vice versa. """ - ctx = temporal_operation_context.get() - start_operation_context = ctx._temporal_start_operation_context + tctx = _TemporalNexusOperationContext.get() + start_operation_context = tctx._temporal_start_operation_context if not start_operation_context: raise RuntimeError( "temporalio.nexus.start_workflow() must be called from " @@ -103,11 +103,11 @@ async def start_workflow( # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, # but these are deliberately not exposed in overloads, hence the type-check # violation. - wf_handle = await ctx.client.start_workflow( # type: ignore + wf_handle = await tctx.client.start_workflow( # type: ignore workflow=workflow, arg=arg, id=id, - task_queue=task_queue or ctx.task_queue, + task_queue=task_queue or tctx.info().task_queue, execution_timeout=execution_timeout, run_timeout=run_timeout, task_timeout=task_timeout, diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 4f1ff4df4..230433617 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -32,9 +32,10 @@ import temporalio.nexus from temporalio.exceptions import ApplicationError from temporalio.nexus import ( + Info, + _temporal_operation_context, _TemporalNexusOperationContext, logger, - temporal_operation_context, ) from temporalio.service import RPCError, RPCStatusCode @@ -166,11 +167,11 @@ async def _handle_cancel_operation_task( service=request.service, operation=request.operation, ) - temporal_operation_context.set( + _temporal_operation_context.set( _TemporalNexusOperationContext( + info=lambda: Info(task_queue=self._task_queue), nexus_operation_context=ctx, client=self._client, - task_queue=self._task_queue, ) ) # TODO(nexus-prerelease): headers @@ -263,11 +264,11 @@ async def _start_operation( ], callback_headers=dict(start_request.callback_header), ) - temporal_operation_context.set( + _temporal_operation_context.set( _TemporalNexusOperationContext( nexus_operation_context=ctx, client=self._client, - task_queue=self._task_queue, + info=lambda: Info(task_queue=self._task_queue), ) ) input = LazyValue( diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index e85471fe8..5014a48c8 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -1065,12 +1065,11 @@ async def operation_backed_by_a_workflow( async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( self, ctx: StartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: - tctx = nexus.temporal_operation_context.get() - await tctx.client.start_workflow( + await nexus.client().start_workflow( EchoWorkflow.run, input, id=input.value, - task_queue=tctx.task_queue, + task_queue=nexus.info().task_queue, ) # This should fail. It will not fail if the Nexus request ID was incorrectly # propagated to both StartWorkflow requests. @@ -1097,10 +1096,10 @@ async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnviron ) async def start_two_workflows_with_conflicting_workflow_ids( - request_ids: tuple[tuple[str, int], tuple[str, int]], + request_ids: tuple[tuple[str, int, str], tuple[str, int, str]], ): workflow_id = str(uuid.uuid4()) - for request_id, status_code in request_ids: + for request_id, status_code, error_message in request_ids: resp = await service_client.start_operation( "operation_backed_by_a_workflow", dataclass_as_dict(Input(workflow_id)), @@ -1111,13 +1110,18 @@ async def start_two_workflows_with_conflicting_workflow_ids( f"but got {resp.status_code} for response content " f"{pprint.pformat(resp.content.decode())}" ) - if status_code == 201: + if not error_message: + assert status_code == 201 op_info = resp.json() assert op_info["token"] assert op_info["state"] == nexusrpc.OperationState.RUNNING.value + else: + assert status_code >= 400 + failure = Failure(**resp.json()) + assert failure.message == error_message async def start_two_workflows_in_a_single_operation( - request_id: str, status_code: int + request_id: str, status_code: int, error_message: str ): resp = await service_client.start_operation( "operation_that_executes_a_workflow_before_starting_the_backing_workflow", @@ -1125,6 +1129,9 @@ async def start_two_workflows_in_a_single_operation( {"Nexus-Request-Id": request_id}, ) assert resp.status_code == status_code + if error_message: + failure = Failure(**resp.json()) + assert failure.message == error_message async with Worker( env.client, @@ -1135,17 +1142,22 @@ async def start_two_workflows_in_a_single_operation( request_id_1, request_id_2 = str(uuid.uuid4()), str(uuid.uuid4()) # Reusing the same request ID does not fail await start_two_workflows_with_conflicting_workflow_ids( - ((request_id_1, 201), (request_id_1, 201)) + ((request_id_1, 201, ""), (request_id_1, 201, "")) ) # Using a different request ID does fail # TODO(nexus-prerelease) I think that this should be a 409 per the spec. Go and # Java are not doing that. await start_two_workflows_with_conflicting_workflow_ids( - ((request_id_1, 201), (request_id_2, 500)) + ( + (request_id_1, 201, ""), + (request_id_2, 500, "Workflow execution already started"), + ) ) # Two workflows started in the same operation should fail, since the Nexus # request ID should be propagated to the backing workflow only. - await start_two_workflows_in_a_single_operation(request_id_1, 500) + await start_two_workflows_in_a_single_operation( + request_id_1, 500, "Workflow execution already started" + ) def server_address(env: WorkflowEnvironment) -> str: diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 09043eba1..3fe699608 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -916,12 +916,11 @@ class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: async def my_workflow_run_operation( self, ctx: StartOperationContext, input: None ) -> nexus.WorkflowHandle[str]: - tctx = nexus.temporal_operation_context.get() - result_1 = await tctx.client.execute_workflow( + result_1 = await nexus.client().execute_workflow( EchoWorkflow.run, "result-1", id=str(uuid.uuid4()), - task_queue=tctx.task_queue, + task_queue=nexus.info().task_queue, ) # In case result_1 is incorrectly being delivered to the caller as the operation # result, give time for that incorrect behavior to occur. From 009faca0916913f826901b38ee4f6f7cda45a284 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 08:35:31 -0400 Subject: [PATCH 090/183] Fix broken test Accidentally broken at f8077c5f4e8722a154207835f2f08bc602abbfca --- tests/nexus/test_handler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 5014a48c8..5d3ec66d4 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -250,10 +250,9 @@ def __call__( ) @sync_operation - async def sync_operation_without_type_annotations( - self, ctx: StartOperationContext, input: Input - ) -> Output: - # The input type from the op definition in the service definition is used to deserialize the input. + async def sync_operation_without_type_annotations(self, ctx, input): + # Despite the lack of type annotations, the input type from the op definition in + # the service definition is used to deserialize the input. return Output( value=f"from start method on {self.__class__.__name__} without type annotations: {input}" ) From 844a9c36b5c571d1abca4f0300439bf9e344fb45 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 08:54:54 -0400 Subject: [PATCH 091/183] Fix another test --- tests/nexus/test_workflow_run_operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index d45895222..7259fcb21 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -94,7 +94,7 @@ async def test_workflow_run_operation( service_client = ServiceClient( server_address=server_address(env), endpoint=endpoint, - service=service_handler_cls.__name__, + service=service_handler_cls.__nexus_service__.name, ) async with Worker( env.client, From af00209e5958d17b944a0e16d95edaf04784d0a7 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 09:16:07 -0400 Subject: [PATCH 092/183] Move Failure tests utility --- tests/helpers/nexus.py | 33 +++++++++++++++++++++++++++++ tests/nexus/test_handler.py | 42 ++++++------------------------------- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index f0f2f3410..5fb134140 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -3,11 +3,13 @@ from typing import Any, Mapping, Optional import httpx +from google.protobuf import json_format import temporalio.api import temporalio.api.nexus.v1 import temporalio.api.operatorservice.v1 from temporalio.client import Client +from temporalio.converter import FailureConverter, PayloadConverter def make_nexus_endpoint_name(task_queue: str) -> str: @@ -106,3 +108,34 @@ def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: field.name: getattr(dataclass, field.name) for field in dataclasses.fields(dataclass) } + + +@dataclass +class Failure: + """A Nexus Failure object, with details parsed into an exception. + + https://github.com/nexus-rpc/api/blob/main/SPEC.md#failure + """ + + message: str = "" + metadata: Optional[dict[str, str]] = None + details: Optional[dict[str, Any]] = None + + exception_from_details: Optional[BaseException] = dataclasses.field( + init=False, default=None + ) + + def __post_init__(self) -> None: + if self.metadata and (error_type := self.metadata.get("type")): + self.exception_from_details = self._instantiate_exception( + error_type, self.details + ) + + def _instantiate_exception( + self, error_type: str, details: Optional[dict[str, Any]] + ) -> BaseException: + proto = { + "temporal.api.failure.v1.Failure": temporalio.api.failure.v1.Failure, + }[error_type]() + json_format.ParseDict(self.details, proto, ignore_unknown_fields=True) + return FailureConverter.default.from_failure(proto, PayloadConverter.default) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 5d3ec66d4..e934d88b1 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -15,7 +15,6 @@ import asyncio import concurrent.futures -import dataclasses import json import logging import pprint @@ -28,7 +27,6 @@ import httpx import nexusrpc import pytest -from google.protobuf import json_format from nexusrpc import OperationInfo from nexusrpc.handler import ( CancelOperationContext, @@ -45,16 +43,19 @@ ) from nexusrpc.handler._decorators import operation_handler -import temporalio.api.failure.v1 from temporalio import nexus, workflow from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy -from temporalio.converter import FailureConverter, PayloadConverter from temporalio.exceptions import ApplicationError from temporalio.nexus import workflow_run_operation from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict +from tests.helpers.nexus import ( + Failure, + ServiceClient, + create_nexus_endpoint, + dataclass_as_dict, +) HTTP_PORT = 7243 @@ -325,37 +326,6 @@ async def non_serializable_output( return NonSerializableOutput() -@dataclass -class Failure: - """A Nexus Failure object, with details parsed into an exception. - - https://github.com/nexus-rpc/api/blob/main/SPEC.md#failure - """ - - message: str = "" - metadata: Optional[dict[str, str]] = None - details: Optional[dict[str, Any]] = None - - exception_from_details: Optional[BaseException] = dataclasses.field( - init=False, default=None - ) - - def __post_init__(self) -> None: - if self.metadata and (error_type := self.metadata.get("type")): - self.exception_from_details = self._instantiate_exception( - error_type, self.details - ) - - def _instantiate_exception( - self, error_type: str, details: Optional[dict[str, Any]] - ) -> BaseException: - proto = { - "temporal.api.failure.v1.Failure": temporalio.api.failure.v1.Failure, - }[error_type]() - json_format.ParseDict(self.details, proto, ignore_unknown_fields=True) - return FailureConverter.default.from_failure(proto, PayloadConverter.default) - - # Immutable dicts that can be used as dataclass field defaults SUCCESSFUL_RESPONSE_HEADERS = MappingProxyType( From 914b35ea0c0eaaa30f9fe8dba333a0551e01b54a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 09:18:12 -0400 Subject: [PATCH 093/183] Fix test --- tests/nexus/test_workflow_run_operation.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 7259fcb21..d121a53c3 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -16,7 +16,12 @@ from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import ServiceClient, create_nexus_endpoint, dataclass_as_dict +from tests.helpers.nexus import ( + Failure, + ServiceClient, + create_nexus_endpoint, + dataclass_as_dict, +) HTTP_PORT = 7243 @@ -69,11 +74,16 @@ class SubclassingNoInputOutputTypeAnnotationsWithoutServiceDefinition: def op(self) -> OperationHandler: return MyOperation() + __expected__error__ = 500, "'dict' object has no attribute 'value'" + @service_handler(service=Service) class SubclassingNoInputOutputTypeAnnotationsWithServiceDefinition: + # Despite the lack of annotations on the service impl, the service definition + # provides the type needed to deserialize the input into Input so that input.value + # succeeds. @operation_handler - def op(self) -> OperationHandler[Input, str]: + def op(self) -> OperationHandler: return MyOperation() @@ -105,7 +115,13 @@ async def test_workflow_run_operation( "op", dataclass_as_dict(Input(value="test")), ) - assert resp.status_code == 201 + if hasattr(service_handler_cls, "__expected__error__"): + status_code, message = service_handler_cls.__expected__error__ + assert resp.status_code == status_code + failure = Failure(**resp.json()) + assert failure.message == message + else: + assert resp.status_code == 201 def server_address(env: WorkflowEnvironment) -> str: From 16432f722ea564b33e1e7fabdb9bee84237847fa Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 11:45:13 -0400 Subject: [PATCH 094/183] RTU: relocate OperationError --- temporalio/worker/_nexus.py | 4 ++-- tests/nexus/test_handler.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 230433617..07536f343 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -306,7 +306,7 @@ async def _start_operation( "nexusrpc.handler.StartOperationResultAsync." ) ) - except nexusrpc.handler.OperationError as err: + except nexusrpc.OperationError as err: return temporalio.api.nexus.v1.StartOperationResponse( operation_error=await self._operation_error_to_proto(err), ) @@ -332,7 +332,7 @@ async def _exception_to_failure_proto( async def _operation_error_to_proto( self, - err: nexusrpc.handler.OperationError, + err: nexusrpc.OperationError, ) -> temporalio.api.nexus.v1.UnsuccessfulOperationError: cause = err.__cause__ if cause is None: diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index e934d88b1..8613e8812 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -27,15 +27,13 @@ import httpx import nexusrpc import pytest -from nexusrpc import OperationInfo +from nexusrpc import OperationError, OperationErrorState, OperationInfo from nexusrpc.handler import ( CancelOperationContext, FetchOperationInfoContext, FetchOperationResultContext, HandlerError, HandlerErrorType, - OperationError, - OperationErrorState, OperationHandler, StartOperationContext, service_handler, From 0f3b85e06953ccdeafc245543a359b86f94c066d Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 14:15:05 -0400 Subject: [PATCH 095/183] Copy get_types utility from nexusrpc --- temporalio/nexus/_util.py | 54 +++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 2f8277e13..67151555a 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -10,12 +10,10 @@ Callable, Optional, Type, + Union, ) -from nexusrpc.handler import ( - StartOperationContext, - get_start_method_input_and_output_type_annotations, -) +from nexusrpc.handler import StartOperationContext from nexusrpc.types import ( InputT, OutputT, @@ -41,7 +39,8 @@ def get_workflow_run_start_method_input_and_output_type_annotations( `start` must be a type-annotated start method that returns a :py:class:`temporalio.nexus.WorkflowHandle`. """ - input_type, output_type = get_start_method_input_and_output_type_annotations(start) + + input_type, output_type = _get_start_method_input_and_output_type_annotations(start) origin_type = typing.get_origin(output_type) if not origin_type: output_type = None @@ -66,6 +65,51 @@ def get_workflow_run_start_method_input_and_output_type_annotations( return input_type, output_type +def _get_start_method_input_and_output_type_annotations( + start: Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Union[OutputT, Awaitable[OutputT]], + ], +) -> tuple[ + Optional[Type[InputT]], + Optional[Type[OutputT]], +]: + """Return operation input and output types. + + `start` must be a type-annotated start method that returns a synchronous result. + """ + try: + type_annotations = typing.get_type_hints(start) + except TypeError: + # TODO(preview): stacklevel + warnings.warn( + f"Expected decorated start method {start} to have type annotations" + ) + return None, None + output_type = type_annotations.pop("return", None) + + if len(type_annotations) != 2: + # TODO(preview): stacklevel + suffix = f": {type_annotations}" if type_annotations else "" + warnings.warn( + f"Expected decorated start method {start} to have exactly 2 " + f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}" + f"{suffix}." + ) + input_type = None + else: + ctx_type, input_type = type_annotations.values() + if not issubclass(ctx_type, WorkflowRunOperationContext): + # TODO(preview): stacklevel + warnings.warn( + f"Expected first parameter of {start} to be an instance of " + f"WorkflowRunOperationContext, but is {ctx_type}." + ) + input_type = None + + return input_type, output_type + + def get_callable_name(fn: Callable[..., Any]) -> str: method_name = getattr(fn, "__name__", None) if not method_name and callable(fn) and hasattr(fn, "__call__"): From 686f1564a8424b26cc9292a53819dc4e129ff9e5 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 15:07:37 -0400 Subject: [PATCH 096/183] Fixup: eliminate references to WorkflowOperationToken --- temporalio/nexus/_operation_handlers.py | 26 +++---------------------- temporalio/nexus/_util.py | 2 +- temporalio/workflow.py | 8 ++++---- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 31083dfe6..721f9b1f7 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -42,26 +42,6 @@ class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]): Use this class to create an operation handler that starts a workflow by passing your ``start`` method to the constructor. Your ``start`` method must use :py:func:`temporalio.nexus.start_workflow` to start the workflow. - - Example: - - .. code-block:: python - - @service_handler(service=MyNexusService) - class MyNexusServiceHandler: - @operation_handler - def my_workflow_run_operation( - self, - ) -> OperationHandler[MyInput, MyOutput]: - async def start( - ctx: StartOperationContext, input: MyInput - ) -> WorkflowOperationToken[MyOutput]: - return await start_workflow( - WorkflowStartedByNexusOperation.run, input, - id=str(uuid.uuid4()), - ) - - return WorkflowRunOperationHandler.from_start_workflow(start) """ def __init__( @@ -95,13 +75,13 @@ async def start( if not isinstance(handle, WorkflowHandle): if isinstance(handle, client.WorkflowHandle): raise RuntimeError( - f"Expected {handle} to be a WorkflowOperationToken, but got a client.WorkflowHandle. " - f"You must use temporalio.nexus.start_workflow " + f"Expected {handle} to be a nexus.WorkflowHandle, but got a client.WorkflowHandle. " + f"You must use WorkflowRunOperationContext.start_workflow " "to start a workflow that will deliver the result of the Nexus operation, " "not client.Client.start_workflow." ) raise RuntimeError( - f"Expected {handle} to be a WorkflowOperationToken, but got {type(handle)}. " + f"Expected {handle} to be a nexus.WorkflowHandle, but got {type(handle)}. " ) return StartOperationResultAsync(handle.to_token()) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 67151555a..a46ff3151 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -46,7 +46,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( output_type = None elif not issubclass(origin_type, WorkflowHandle): warnings.warn( - f"Expected return type of {start.__name__} to be a subclass of WorkflowOperationToken, " + f"Expected return type of {start.__name__} to be a subclass of WorkflowHandle, " f"but is {output_type}" ) output_type = None diff --git a/temporalio/workflow.py b/temporalio/workflow.py index fe7752c9f..02d34fd90 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -859,7 +859,7 @@ async def workflow_start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[temporalio.nexus.WorkflowOperationToken[O]], + Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, ], @@ -4429,7 +4429,7 @@ async def start_nexus_operation( Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[temporalio.nexus.WorkflowOperationToken[O]], + Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, ], @@ -5215,7 +5215,7 @@ async def start_operation( Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[temporalio.nexus.WorkflowOperationToken[O]], + Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, ], @@ -5243,7 +5243,7 @@ async def execute_operation( Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ [S, nexusrpc.handler.StartOperationContext, I], - Awaitable[temporalio.nexus.WorkflowOperationToken[O]], + Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, ], From 33f4f82d0417d0f75f63a2cea450374926445e0b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 15:14:09 -0400 Subject: [PATCH 097/183] Move start_workflow to WorkflowRunOperationContext --- temporalio/nexus/__init__.py | 2 +- temporalio/nexus/_decorators.py | 23 +- temporalio/nexus/_operation_handlers.py | 3 +- temporalio/nexus/_util.py | 5 +- temporalio/nexus/_workflow.py | 227 +++++++++--------- tests/nexus/test_handler.py | 27 ++- .../test_handler_interface_implementation.py | 4 +- .../test_handler_operation_definitions.py | 9 +- tests/nexus/test_workflow_caller.py | 15 +- tests/nexus/test_workflow_run_operation.py | 6 +- 10 files changed, 168 insertions(+), 153 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 0c42dd564..5ac1c1599 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -18,7 +18,7 @@ from ._operation_context import info as info from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowHandle as WorkflowHandle -from ._workflow import start_workflow as start_workflow +from ._workflow import WorkflowRunOperationContext as WorkflowRunOperationContext class LoggerAdapter(logging.LoggerAdapter): diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 868c68da2..07044d151 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -25,16 +25,17 @@ get_callable_name, get_workflow_run_start_method_input_and_output_type_annotations, ) +from temporalio.nexus._workflow import WorkflowRunOperationContext @overload def workflow_run_operation( start: Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ... @@ -46,12 +47,12 @@ def workflow_run_operation( ) -> Callable[ [ Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ]: ... @@ -60,7 +61,7 @@ def workflow_run_operation( def workflow_run_operation( start: Optional[ Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ] = None, @@ -68,18 +69,18 @@ def workflow_run_operation( name: Optional[str] = None, ) -> Union[ Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], Callable[ [ Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ], @@ -90,11 +91,11 @@ def workflow_run_operation( def decorator( start: Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ( @@ -108,7 +109,7 @@ def operation_handler_factory( async def _start( ctx: StartOperationContext, input: InputT ) -> WorkflowHandle[OutputT]: - return await start(self, ctx, input) + return await start(self, WorkflowRunOperationContext(ctx), input) _start.__doc__ = start.__doc__ return WorkflowRunOperationHandler(_start, input_type, output_type) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 721f9b1f7..2a7c0d4d5 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -41,7 +41,8 @@ class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]): Use this class to create an operation handler that starts a workflow by passing your ``start`` method to the constructor. Your ``start`` method must use - :py:func:`temporalio.nexus.start_workflow` to start the workflow. + :py:func:`temporalio.nexus.WorkflowRunOperationContext.start_workflow` to start the + workflow. """ def __init__( diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index a46ff3151..e60094108 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -13,13 +13,14 @@ Union, ) -from nexusrpc.handler import StartOperationContext from nexusrpc.types import ( InputT, OutputT, ServiceHandlerT, ) +from temporalio.nexus._workflow import WorkflowRunOperationContext + from ._token import ( WorkflowHandle as WorkflowHandle, ) @@ -27,7 +28,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( start: Callable[ - [ServiceHandlerT, StartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> tuple[ diff --git a/temporalio/nexus/_workflow.py b/temporalio/nexus/_workflow.py index 272610533..61b5c5017 100644 --- a/temporalio/nexus/_workflow.py +++ b/temporalio/nexus/_workflow.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from datetime import timedelta from typing import ( Any, @@ -9,6 +10,8 @@ Union, ) +from nexusrpc.handler import StartOperationContext + import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.common @@ -22,116 +25,118 @@ ) -# Overload for single-param workflow -# TODO(nexus-prerelease): bring over other overloads -async def start_workflow( - workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], - arg: ParamType, - *, - id: str, - task_queue: Optional[str] = None, - execution_timeout: Optional[timedelta] = None, - run_timeout: Optional[timedelta] = None, - task_timeout: Optional[timedelta] = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, - retry_policy: Optional[temporalio.common.RetryPolicy] = None, - cron_schedule: str = "", - memo: Optional[Mapping[str, Any]] = None, - search_attributes: Optional[ - Union[ - temporalio.common.TypedSearchAttributes, - temporalio.common.SearchAttributes, - ] - ] = None, - static_summary: Optional[str] = None, - static_details: Optional[str] = None, - start_delay: Optional[timedelta] = None, - start_signal: Optional[str] = None, - start_signal_args: Sequence[Any] = [], - rpc_metadata: Mapping[str, str] = {}, - rpc_timeout: Optional[timedelta] = None, - request_eager_start: bool = False, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: Optional[temporalio.common.VersioningOverride] = None, -) -> WorkflowHandle[ReturnType]: - """Start a workflow that will deliver the result of the Nexus operation. - - The workflow will be started in the same namespace as the Nexus worker, using - the same client as the worker. If task queue is not specified, the worker's task - queue will be used. - - See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. - - The return value is :py:class:`temporalio.nexus.WorkflowOperationToken`. - Use :py:meth:`temporalio.nexus.WorkflowOperationToken.to_workflow_handle` - to get a :py:class:`temporalio.client.WorkflowHandle` for interacting with the - workflow. - - The workflow will be started as usual, with the following modifications: - - - On workflow completion, Temporal server will deliver the workflow result to - the Nexus operation caller, using the callback from the Nexus operation start - request. - - - The request ID from the Nexus operation start request will be used as the - request ID for the start workflow request. - - - Inbound links to the caller that were submitted in the Nexus start operation - request will be attached to the started workflow and, outbound links to the - started workflow will be added to the Nexus start operation response. If the - Nexus caller is itself a workflow, this means that the workflow in the caller - namespace web UI will contain links to the started workflow, and vice versa. - """ - tctx = _TemporalNexusOperationContext.get() - start_operation_context = tctx._temporal_start_operation_context - if not start_operation_context: - raise RuntimeError( - "temporalio.nexus.start_workflow() must be called from " - "within a Nexus start operation context" +@dataclass +class WorkflowRunOperationContext: + start_operation_context: StartOperationContext + + # Overload for single-param workflow + # TODO(nexus-prerelease): bring over other overloads + async def start_workflow( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: + """Start a workflow that will deliver the result of the Nexus operation. + + The workflow will be started in the same namespace as the Nexus worker, using + the same client as the worker. If task queue is not specified, the worker's task + queue will be used. + + See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. + + The return value is :py:class:`temporalio.nexus.WorkflowHandle`. + + The workflow will be started as usual, with the following modifications: + + - On workflow completion, Temporal server will deliver the workflow result to + the Nexus operation caller, using the callback from the Nexus operation start + request. + + - The request ID from the Nexus operation start request will be used as the + request ID for the start workflow request. + + - Inbound links to the caller that were submitted in the Nexus start operation + request will be attached to the started workflow and, outbound links to the + started workflow will be added to the Nexus start operation response. If the + Nexus caller is itself a workflow, this means that the workflow in the caller + namespace web UI will contain links to the started workflow, and vice versa. + """ + tctx = _TemporalNexusOperationContext.get() + start_operation_context = tctx._temporal_start_operation_context + if not start_operation_context: + raise RuntimeError( + "WorkflowRunOperationContext.start_workflow() must be called from " + "within a Nexus start operation context" + ) + + # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: + # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { + # internalOptions.onConflictOptions = { + # attachLinks: true, + # attachCompletionCallbacks: true, + # attachRequestId: true, + # }; + # } + + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. + wf_handle = await tctx.client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + id=id, + task_queue=task_queue or tctx.info().task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), + workflow_event_links=start_operation_context.get_workflow_event_links(), + request_id=start_operation_context.nexus_operation_context.request_id, ) - # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: - # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { - # internalOptions.onConflictOptions = { - # attachLinks: true, - # attachCompletionCallbacks: true, - # attachRequestId: true, - # }; - # } - - # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, - # but these are deliberately not exposed in overloads, hence the type-check - # violation. - wf_handle = await tctx.client.start_workflow( # type: ignore - workflow=workflow, - arg=arg, - id=id, - task_queue=task_queue or tctx.info().task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - start_signal=start_signal, - start_signal_args=start_signal_args, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - request_eager_start=request_eager_start, - priority=priority, - versioning_override=versioning_override, - nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), - workflow_event_links=start_operation_context.get_workflow_event_links(), - request_id=start_operation_context.nexus_operation_context.request_id, - ) - - start_operation_context.add_outbound_links(wf_handle) - - return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) + start_operation_context.add_outbound_links(wf_handle) + + return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 8613e8812..7cdd10f1a 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -45,7 +45,7 @@ from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy from temporalio.exceptions import ApplicationError -from temporalio.nexus import workflow_run_operation +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ( @@ -208,9 +208,9 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: @workflow_run_operation async def workflow_run_operation_happy_path( - self, ctx: StartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: - return await nexus.start_workflow( + return await ctx.start_workflow( MyWorkflow.run, input, id=str(uuid.uuid4()), @@ -258,7 +258,7 @@ async def sync_operation_without_type_annotations(self, ctx, input): @workflow_run_operation async def workflow_run_operation_without_type_annotations(self, ctx, input): - return await nexus.start_workflow( + return await ctx.start_workflow( WorkflowWithoutTypeAnnotations.run, input, id=str(uuid.uuid4()), @@ -266,15 +266,16 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): @workflow_run_operation async def workflow_run_op_link_test( - self, ctx: StartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: + nctx = ctx.start_operation_context assert any( - link.url == "http://inbound-link/" for link in ctx.inbound_links + link.url == "http://inbound-link/" for link in nctx.inbound_links ), "Inbound link not found" - assert ctx.request_id == "test-request-id-123", "Request ID mismatch" - ctx.outbound_links.extend(ctx.inbound_links) + assert nctx.request_id == "test-request-id-123", "Request ID mismatch" + nctx.outbound_links.extend(nctx.inbound_links) - return await nexus.start_workflow( + return await ctx.start_workflow( MyLinkTestWorkflow.run, input, id=str(uuid.uuid4()), @@ -1019,9 +1020,9 @@ async def run(self, input: Input) -> Output: class ServiceHandlerForRequestIdTest: @workflow_run_operation async def operation_backed_by_a_workflow( - self, ctx: StartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: - return await nexus.start_workflow( + return await ctx.start_workflow( EchoWorkflow.run, input, id=input.value, @@ -1030,7 +1031,7 @@ async def operation_backed_by_a_workflow( @workflow_run_operation async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( - self, ctx: StartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: await nexus.client().start_workflow( EchoWorkflow.run, @@ -1040,7 +1041,7 @@ async def operation_that_executes_a_workflow_before_starting_the_backing_workflo ) # This should fail. It will not fail if the Nexus request ID was incorrectly # propagated to both StartWorkflow requests. - return await nexus.start_workflow( + return await ctx.start_workflow( EchoWorkflow.run, input, id=input.value, diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 114b1dacc..be98ff6d6 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -6,7 +6,7 @@ from nexusrpc.handler import StartOperationContext, sync_operation from temporalio import nexus -from temporalio.nexus import workflow_run_operation +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation HTTP_PORT = 7243 @@ -37,7 +37,7 @@ class Interface: class Impl: @workflow_run_operation async def op( - self, ctx: StartOperationContext, input: str + self, ctx: WorkflowRunOperationContext, input: str ) -> nexus.WorkflowHandle[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index cb0b41c99..b0c1f2ac4 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -8,10 +8,9 @@ import nexusrpc.handler import pytest -from nexusrpc.handler import StartOperationContext from temporalio import nexus -from temporalio.nexus import workflow_run_operation +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation @dataclass @@ -35,7 +34,7 @@ class NotCalled(_TestCase): class Service: @workflow_run_operation async def my_workflow_run_operation_handler( - self, ctx: StartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { @@ -53,7 +52,7 @@ class CalledWithoutArgs(_TestCase): class Service: @workflow_run_operation async def my_workflow_run_operation_handler( - self, ctx: StartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = NotCalled.expected_operations @@ -64,7 +63,7 @@ class CalledWithNameOverride(_TestCase): class Service: @workflow_run_operation(name="operation-name") async def workflow_run_operation_with_name_override( - self, ctx: StartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3fe699608..4e953e554 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -39,6 +39,7 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus import workflow_run_operation +from temporalio.nexus._workflow import WorkflowRunOperationContext from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -157,7 +158,11 @@ async def start( if isinstance(input.response_type, SyncResponse): return StartOperationResultSync(value=OpOutput(value="sync response")) elif isinstance(input.response_type, AsyncResponse): - handle = await nexus.start_workflow( + # TODO(nexus-preview): this is a hack; perhaps it should be should be called + # temporalio.nexus.StartOperationContext instead of + # WorkflowRunOperationContext. + tctx = WorkflowRunOperationContext(ctx) + handle = await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, @@ -203,7 +208,7 @@ async def sync_operation( @workflow_run_operation async def async_operation( - self, ctx: StartOperationContext, input: OpInput + self, ctx: WorkflowRunOperationContext, input: OpInput ) -> nexus.WorkflowHandle[HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: @@ -212,7 +217,7 @@ async def async_operation( RPCStatusCode.INVALID_ARGUMENT, b"", ) - return await nexus.start_workflow( + return await ctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, @@ -914,7 +919,7 @@ async def run(self, input: str) -> str: class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: @workflow_run_operation async def my_workflow_run_operation( - self, ctx: StartOperationContext, input: None + self, ctx: WorkflowRunOperationContext, input: None ) -> nexus.WorkflowHandle[str]: result_1 = await nexus.client().execute_workflow( EchoWorkflow.run, @@ -925,7 +930,7 @@ async def my_workflow_run_operation( # In case result_1 is incorrectly being delivered to the caller as the operation # result, give time for that incorrect behavior to occur. await asyncio.sleep(0.5) - return await nexus.start_workflow( + return await ctx.start_workflow( EchoWorkflow.run, f"{result_1}-result-2", id=str(uuid.uuid4()), diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index d121a53c3..31d62fba4 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -12,8 +12,9 @@ ) from nexusrpc.handler._decorators import operation_handler -from temporalio import nexus, workflow +from temporalio import workflow from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler +from temporalio.nexus._workflow import WorkflowRunOperationContext from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ( @@ -48,7 +49,8 @@ def __init__(self): async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: - handle = await nexus.start_workflow( + tctx = WorkflowRunOperationContext(ctx) + handle = await tctx.start_workflow( EchoWorkflow.run, input.value, id=str(uuid.uuid4()), From 39f220c5b1efa0344261c40d267de144e76cee4a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 08:08:46 -0400 Subject: [PATCH 098/183] Move WorkflowRunOperationContext to _operation_context module --- temporalio/nexus/__init__.py | 4 +- temporalio/nexus/_decorators.py | 2 +- temporalio/nexus/_operation_context.py | 127 ++++++++++++++++++ temporalio/nexus/_util.py | 2 +- temporalio/nexus/_workflow.py | 142 --------------------- tests/nexus/test_workflow_caller.py | 3 +- tests/nexus/test_workflow_run_operation.py | 2 +- 7 files changed, 134 insertions(+), 148 deletions(-) delete mode 100644 temporalio/nexus/_workflow.py diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 5ac1c1599..bdcc0b7a9 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -8,6 +8,9 @@ from ._decorators import workflow_run_operation as workflow_run_operation from ._operation_context import Info as Info +from ._operation_context import ( + WorkflowRunOperationContext as WorkflowRunOperationContext, +) from ._operation_context import ( _temporal_operation_context as _temporal_operation_context, ) @@ -18,7 +21,6 @@ from ._operation_context import info as info from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowHandle as WorkflowHandle -from ._workflow import WorkflowRunOperationContext as WorkflowRunOperationContext class LoggerAdapter(logging.LoggerAdapter): diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 07044d151..eda23031b 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -15,6 +15,7 @@ ) from nexusrpc.types import InputT, OutputT, ServiceHandlerT +from temporalio.nexus._operation_context import WorkflowRunOperationContext from temporalio.nexus._operation_handlers import ( WorkflowRunOperationHandler, ) @@ -25,7 +26,6 @@ get_callable_name, get_workflow_run_start_method_input_and_output_type_annotations, ) -from temporalio.nexus._workflow import WorkflowRunOperationContext @overload diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index e00f490fd..0c47237c1 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -5,10 +5,13 @@ import urllib.parse from contextvars import ContextVar from dataclasses import dataclass +from datetime import timedelta from typing import ( Any, Callable, + Mapping, Optional, + Sequence, Union, ) @@ -19,6 +22,13 @@ import temporalio.api.enums.v1 import temporalio.client import temporalio.common +from temporalio.nexus._token import WorkflowHandle +from temporalio.types import ( + MethodAsyncSingleParam, + ParamType, + ReturnType, + SelfType, +) logger = logging.getLogger(__name__) @@ -92,6 +102,123 @@ def _temporal_cancel_operation_context( return _TemporalCancelOperationContext(ctx) +@dataclass +class WorkflowRunOperationContext: + start_operation_context: StartOperationContext + + # Overload for single-param workflow + # TODO(nexus-prerelease): bring over other overloads + async def start_workflow( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: + """Start a workflow that will deliver the result of the Nexus operation. + + The workflow will be started in the same namespace as the Nexus worker, using + the same client as the worker. If task queue is not specified, the worker's task + queue will be used. + + See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. + + The return value is :py:class:`temporalio.nexus.WorkflowHandle`. + + The workflow will be started as usual, with the following modifications: + + - On workflow completion, Temporal server will deliver the workflow result to + the Nexus operation caller, using the callback from the Nexus operation start + request. + + - The request ID from the Nexus operation start request will be used as the + request ID for the start workflow request. + + - Inbound links to the caller that were submitted in the Nexus start operation + request will be attached to the started workflow and, outbound links to the + started workflow will be added to the Nexus start operation response. If the + Nexus caller is itself a workflow, this means that the workflow in the caller + namespace web UI will contain links to the started workflow, and vice versa. + """ + tctx = _TemporalNexusOperationContext.get() + start_operation_context = tctx._temporal_start_operation_context + if not start_operation_context: + raise RuntimeError( + "WorkflowRunOperationContext.start_workflow() must be called from " + "within a Nexus start operation context" + ) + + # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: + # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { + # internalOptions.onConflictOptions = { + # attachLinks: true, + # attachCompletionCallbacks: true, + # attachRequestId: true, + # }; + # } + + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. + wf_handle = await tctx.client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + id=id, + task_queue=task_queue or tctx.info().task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), + workflow_event_links=start_operation_context.get_workflow_event_links(), + request_id=start_operation_context.nexus_operation_context.request_id, + ) + + start_operation_context.add_outbound_links(wf_handle) + + return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) + + @dataclass class _TemporalStartOperationContext: nexus_operation_context: StartOperationContext diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index e60094108..9cb4af50f 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -19,7 +19,7 @@ ServiceHandlerT, ) -from temporalio.nexus._workflow import WorkflowRunOperationContext +from temporalio.nexus._operation_context import WorkflowRunOperationContext from ._token import ( WorkflowHandle as WorkflowHandle, diff --git a/temporalio/nexus/_workflow.py b/temporalio/nexus/_workflow.py deleted file mode 100644 index 61b5c5017..000000000 --- a/temporalio/nexus/_workflow.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import timedelta -from typing import ( - Any, - Mapping, - Optional, - Sequence, - Union, -) - -from nexusrpc.handler import StartOperationContext - -import temporalio.api.common.v1 -import temporalio.api.enums.v1 -import temporalio.common -from temporalio.nexus._operation_context import _TemporalNexusOperationContext -from temporalio.nexus._token import WorkflowHandle -from temporalio.types import ( - MethodAsyncSingleParam, - ParamType, - ReturnType, - SelfType, -) - - -@dataclass -class WorkflowRunOperationContext: - start_operation_context: StartOperationContext - - # Overload for single-param workflow - # TODO(nexus-prerelease): bring over other overloads - async def start_workflow( - self, - workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], - arg: ParamType, - *, - id: str, - task_queue: Optional[str] = None, - execution_timeout: Optional[timedelta] = None, - run_timeout: Optional[timedelta] = None, - task_timeout: Optional[timedelta] = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, - retry_policy: Optional[temporalio.common.RetryPolicy] = None, - cron_schedule: str = "", - memo: Optional[Mapping[str, Any]] = None, - search_attributes: Optional[ - Union[ - temporalio.common.TypedSearchAttributes, - temporalio.common.SearchAttributes, - ] - ] = None, - static_summary: Optional[str] = None, - static_details: Optional[str] = None, - start_delay: Optional[timedelta] = None, - start_signal: Optional[str] = None, - start_signal_args: Sequence[Any] = [], - rpc_metadata: Mapping[str, str] = {}, - rpc_timeout: Optional[timedelta] = None, - request_eager_start: bool = False, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: Optional[temporalio.common.VersioningOverride] = None, - ) -> WorkflowHandle[ReturnType]: - """Start a workflow that will deliver the result of the Nexus operation. - - The workflow will be started in the same namespace as the Nexus worker, using - the same client as the worker. If task queue is not specified, the worker's task - queue will be used. - - See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. - - The return value is :py:class:`temporalio.nexus.WorkflowHandle`. - - The workflow will be started as usual, with the following modifications: - - - On workflow completion, Temporal server will deliver the workflow result to - the Nexus operation caller, using the callback from the Nexus operation start - request. - - - The request ID from the Nexus operation start request will be used as the - request ID for the start workflow request. - - - Inbound links to the caller that were submitted in the Nexus start operation - request will be attached to the started workflow and, outbound links to the - started workflow will be added to the Nexus start operation response. If the - Nexus caller is itself a workflow, this means that the workflow in the caller - namespace web UI will contain links to the started workflow, and vice versa. - """ - tctx = _TemporalNexusOperationContext.get() - start_operation_context = tctx._temporal_start_operation_context - if not start_operation_context: - raise RuntimeError( - "WorkflowRunOperationContext.start_workflow() must be called from " - "within a Nexus start operation context" - ) - - # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: - # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { - # internalOptions.onConflictOptions = { - # attachLinks: true, - # attachCompletionCallbacks: true, - # attachRequestId: true, - # }; - # } - - # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, - # but these are deliberately not exposed in overloads, hence the type-check - # violation. - wf_handle = await tctx.client.start_workflow( # type: ignore - workflow=workflow, - arg=arg, - id=id, - task_queue=task_queue or tctx.info().task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - start_signal=start_signal, - start_signal_args=start_signal_args, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - request_eager_start=request_eager_start, - priority=priority, - versioning_override=versioning_override, - nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), - workflow_event_links=start_operation_context.get_workflow_event_links(), - request_id=start_operation_context.nexus_operation_context.request_id, - ) - - start_operation_context.add_outbound_links(wf_handle) - - return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 4e953e554..c55a8d692 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -38,8 +38,7 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus import workflow_run_operation -from temporalio.nexus._workflow import WorkflowRunOperationContext +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 31d62fba4..740615f3e 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -13,8 +13,8 @@ from nexusrpc.handler._decorators import operation_handler from temporalio import workflow +from temporalio.nexus import WorkflowRunOperationContext from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler -from temporalio.nexus._workflow import WorkflowRunOperationContext from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ( From 6ce70ba6e8ea9f8c76074d05b95cd6f8d6cf4c99 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 15:18:16 -0400 Subject: [PATCH 099/183] Wire through additional context type in union --- temporalio/worker/_interceptor.py | 10 ++++++- temporalio/worker/_workflow_instance.py | 10 ++++++- temporalio/workflow.py | 36 ++++++++++++++++++++++--- 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 59f8c8671..5e61e85e8 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -29,6 +29,7 @@ import temporalio.common import temporalio.nexus import temporalio.workflow +from temporalio.nexus import WorkflowRunOperationContext from temporalio.workflow import VersioningIntent @@ -301,7 +302,14 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): nexusrpc.Operation[InputT, OutputT], Callable[[Any], nexusrpc.handler.OperationHandler[InputT, OutputT]], Callable[ - [Any, nexusrpc.handler.StartOperationContext, InputT], + [ + Any, + Union[ + nexusrpc.handler.StartOperationContext, + WorkflowRunOperationContext, + ], + InputT, + ], Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], ], str, diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index cc7397601..7af15fc21 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -61,6 +61,7 @@ import temporalio.exceptions import temporalio.nexus import temporalio.workflow +from temporalio.nexus import WorkflowRunOperationContext from temporalio.service import __version__ from ._interceptor import ( @@ -1502,7 +1503,14 @@ async def workflow_start_nexus_operation( nexusrpc.Operation[I, O], Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ - [Any, nexusrpc.handler.StartOperationContext, I], + [ + Any, + Union[ + nexusrpc.handler.StartOperationContext, + WorkflowRunOperationContext, + ], + I, + ], Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 02d34fd90..eb125dfb1 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -58,6 +58,7 @@ import temporalio.exceptions import temporalio.nexus import temporalio.workflow +from temporalio.nexus import WorkflowRunOperationContext from .types import ( AnyType, @@ -858,7 +859,14 @@ async def workflow_start_nexus_operation( nexusrpc.Operation[I, O], Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ - [S, nexusrpc.handler.StartOperationContext, I], + [ + S, + Union[ + nexusrpc.handler.StartOperationContext, + WorkflowRunOperationContext, + ], + I, + ], Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, @@ -4428,7 +4436,13 @@ async def start_nexus_operation( nexusrpc.Operation[I, O], Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], Callable[ - [S, nexusrpc.handler.StartOperationContext, I], + [ + S, + Union[ + nexusrpc.handler.StartOperationContext, WorkflowRunOperationContext + ], + I, + ], Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, @@ -5214,7 +5228,14 @@ async def start_operation( nexusrpc.Operation[I, O], Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ - [S, nexusrpc.handler.StartOperationContext, I], + [ + S, + Union[ + nexusrpc.handler.StartOperationContext, + WorkflowRunOperationContext, + ], + I, + ], Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, @@ -5242,7 +5263,14 @@ async def execute_operation( nexusrpc.Operation[I, O], Callable[[S], nexusrpc.handler.OperationHandler[I, O]], Callable[ - [S, nexusrpc.handler.StartOperationContext, I], + [ + S, + Union[ + nexusrpc.handler.StartOperationContext, + WorkflowRunOperationContext, + ], + I, + ], Awaitable[temporalio.nexus.WorkflowHandle[O]], ], str, From f723602dad1eca43271c9ebd34766b6de468d70f Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 15:25:18 -0400 Subject: [PATCH 100/183] Eliminate unnecessary modeling of callable types --- temporalio/worker/_interceptor.py | 26 ++-------- temporalio/worker/_workflow_instance.py | 18 +------ temporalio/workflow.py | 68 ++----------------------- 3 files changed, 10 insertions(+), 102 deletions(-) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 5e61e85e8..c19771921 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -29,7 +29,6 @@ import temporalio.common import temporalio.nexus import temporalio.workflow -from temporalio.nexus import WorkflowRunOperationContext from temporalio.workflow import VersioningIntent @@ -298,22 +297,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): endpoint: str service: str - operation: Union[ - nexusrpc.Operation[InputT, OutputT], - Callable[[Any], nexusrpc.handler.OperationHandler[InputT, OutputT]], - Callable[ - [ - Any, - Union[ - nexusrpc.handler.StartOperationContext, - WorkflowRunOperationContext, - ], - InputT, - ], - Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], - ], - str, - ] + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]] input: InputT schedule_to_close_timeout: Optional[timedelta] headers: Optional[Mapping[str, str]] @@ -324,13 +308,13 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): # TODO(nexus-prerelease): update this logic to handle service impl start methods def __post_init__(self) -> None: - if isinstance(self.operation, str): - self._operation_name = self.operation - self._input_type = None - elif isinstance(self.operation, nexusrpc.Operation): + if isinstance(self.operation, nexusrpc.Operation): self._operation_name = self.operation.name self._input_type = self.operation.input_type self.output_type = self.operation.output_type + elif isinstance(self.operation, str): + self._operation_name = self.operation + self._input_type = None elif isinstance(self.operation, Callable): _, op = nexusrpc.handler.get_operation_factory(self.operation) if isinstance(op, nexusrpc.Operation): diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 7af15fc21..f22c6d4c6 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -61,7 +61,6 @@ import temporalio.exceptions import temporalio.nexus import temporalio.workflow -from temporalio.nexus import WorkflowRunOperationContext from temporalio.service import __version__ from ._interceptor import ( @@ -1499,22 +1498,7 @@ async def workflow_start_nexus_operation( self, endpoint: str, service: str, - operation: Union[ - nexusrpc.Operation[I, O], - Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], - Callable[ - [ - Any, - Union[ - nexusrpc.handler.StartOperationContext, - WorkflowRunOperationContext, - ], - I, - ], - Awaitable[temporalio.nexus.WorkflowHandle[O]], - ], - str, - ], + operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], input: Any, output_type: Optional[Type[O]] = None, schedule_to_close_timeout: Optional[timedelta] = None, diff --git a/temporalio/workflow.py b/temporalio/workflow.py index eb125dfb1..3a14989e3 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -58,7 +58,6 @@ import temporalio.exceptions import temporalio.nexus import temporalio.workflow -from temporalio.nexus import WorkflowRunOperationContext from .types import ( AnyType, @@ -855,22 +854,7 @@ async def workflow_start_nexus_operation( self, endpoint: str, service: str, - operation: Union[ - nexusrpc.Operation[I, O], - Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], - Callable[ - [ - S, - Union[ - nexusrpc.handler.StartOperationContext, - WorkflowRunOperationContext, - ], - I, - ], - Awaitable[temporalio.nexus.WorkflowHandle[O]], - ], - str, - ], + operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], input: Any, output_type: Optional[Type[O]] = None, schedule_to_close_timeout: Optional[timedelta] = None, @@ -4432,21 +4416,7 @@ def operation_token(self) -> Optional[str]: async def start_nexus_operation( endpoint: str, service: str, - operation: Union[ - nexusrpc.Operation[I, O], - Callable[[Any], nexusrpc.handler.OperationHandler[I, O]], - Callable[ - [ - S, - Union[ - nexusrpc.handler.StartOperationContext, WorkflowRunOperationContext - ], - I, - ], - Awaitable[temporalio.nexus.WorkflowHandle[O]], - ], - str, - ], + operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], input: Any, *, output_type: Optional[Type[O]] = None, @@ -5224,22 +5194,7 @@ def __init__( # TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied? async def start_operation( self, - operation: Union[ - nexusrpc.Operation[I, O], - Callable[[S], nexusrpc.handler.OperationHandler[I, O]], - Callable[ - [ - S, - Union[ - nexusrpc.handler.StartOperationContext, - WorkflowRunOperationContext, - ], - I, - ], - Awaitable[temporalio.nexus.WorkflowHandle[O]], - ], - str, - ], + operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], input: I, *, output_type: Optional[Type[O]] = None, @@ -5259,22 +5214,7 @@ async def start_operation( # TODO(nexus-prerelease): overloads: no-input, ret type async def execute_operation( self, - operation: Union[ - nexusrpc.Operation[I, O], - Callable[[S], nexusrpc.handler.OperationHandler[I, O]], - Callable[ - [ - S, - Union[ - nexusrpc.handler.StartOperationContext, - WorkflowRunOperationContext, - ], - I, - ], - Awaitable[temporalio.nexus.WorkflowHandle[O]], - ], - str, - ], + operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], input: I, *, output_type: Optional[Type[O]] = None, From ddaee4ff12e6d7dda8e070a4f84c8f3cda7b624b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 16:28:53 -0400 Subject: [PATCH 101/183] Fix passing Nexus context headers/request ID from worker --- temporalio/worker/_nexus.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 07536f343..54b94a6e7 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -9,6 +9,7 @@ from typing import ( Any, Callable, + Mapping, NoReturn, Optional, Sequence, @@ -105,7 +106,9 @@ async def raise_from_exception_queue() -> NoReturn: # tasks as we do start operation tasks? asyncio.create_task( self._handle_cancel_operation_task( - task.request.cancel_operation, task.task_token + task.task_token, + task.request.cancel_operation, + dict(task.request.header), ) ) else: @@ -155,7 +158,10 @@ async def wait_all_completed(self) -> None: # "Any call up to this function and including this one will be trimmed out of stack traces."" async def _handle_cancel_operation_task( - self, request: temporalio.api.nexus.v1.CancelOperationRequest, task_token: bytes + self, + task_token: bytes, + request: temporalio.api.nexus.v1.CancelOperationRequest, + headers: Mapping[str, str], ) -> None: """ Handle a cancel operation task. @@ -163,9 +169,11 @@ async def _handle_cancel_operation_task( Attempt to execute the user cancel_operation method. Handle errors and send the task completion. """ + # TODO(nexus-prerelease): headers ctx = CancelOperationContext( service=request.service, operation=request.operation, + headers=headers, ) _temporal_operation_context.set( _TemporalNexusOperationContext( @@ -174,7 +182,6 @@ async def _handle_cancel_operation_task( client=self._client, ) ) - # TODO(nexus-prerelease): headers try: await self._handler.cancel_operation(ctx, request.operation_token) except Exception as err: @@ -202,7 +209,7 @@ async def _handle_start_operation_task( self, task_token: bytes, start_request: temporalio.api.nexus.v1.StartOperationRequest, - headers: dict[str, str], + headers: Mapping[str, str], ) -> None: """ Handle a start operation task. @@ -243,7 +250,7 @@ async def _handle_start_operation_task( async def _start_operation( self, start_request: temporalio.api.nexus.v1.StartOperationRequest, - headers: dict[str, str], + headers: Mapping[str, str], ) -> temporalio.api.nexus.v1.StartOperationResponse: """ Invoke the Nexus handler's start_operation method and construct the StartOperationResponse. From fc285e001218ce50bde6c5b1c2bc81f3ec23a16b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 17:07:48 -0400 Subject: [PATCH 102/183] Always passthrough nexusrpc --- temporalio/worker/workflow_sandbox/_restrictions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/temporalio/worker/workflow_sandbox/_restrictions.py b/temporalio/worker/workflow_sandbox/_restrictions.py index fdc126809..32f7ba012 100644 --- a/temporalio/worker/workflow_sandbox/_restrictions.py +++ b/temporalio/worker/workflow_sandbox/_restrictions.py @@ -471,6 +471,7 @@ def with_child_unrestricted(self, *child_path: str) -> SandboxMatcher: # https://wrapt.readthedocs.io/en/latest/issues.html#using-issubclass-on-abstract-classes "asyncio", "abc", + "nexusrpc", "temporalio", # Due to pkg_resources use of base classes caused by the ABC issue # above, and Otel's use of pkg_resources, we pass it through From bd888670deb50de40397dadf148ace63227c5c91 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 17:09:16 -0400 Subject: [PATCH 103/183] Revert disabling of sandbox for nexus workflow tests --- tests/nexus/test_workflow_caller.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index c55a8d692..cefe7b46a 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -40,7 +40,7 @@ from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.service import RPCError, RPCStatusCode -from temporalio.worker import UnsandboxedWorkflowRunner, Worker +from temporalio.worker import Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name # TODO(dan): test availability of Temporal client etc in async context set by worker @@ -444,8 +444,6 @@ async def test_sync_response( nexus_service_handlers=[ServiceImpl()], workflows=[CallerWorkflow, HandlerWorkflow], task_queue=task_queue, - # TODO(dan): enable sandbox - workflow_runner=UnsandboxedWorkflowRunner(), workflow_failure_exception_types=[Exception], ): await create_nexus_endpoint(task_queue, client) @@ -517,7 +515,6 @@ async def test_async_response( nexus_service_handlers=[ServiceImpl()], workflows=[CallerWorkflow, HandlerWorkflow], task_queue=task_queue, - workflow_runner=UnsandboxedWorkflowRunner(), workflow_failure_exception_types=[Exception], ): caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op( @@ -673,7 +670,6 @@ async def test_untyped_caller( workflows=[UntypedCallerWorkflow, HandlerWorkflow], nexus_service_handlers=[ServiceImpl()], task_queue=task_queue, - workflow_runner=UnsandboxedWorkflowRunner(), workflow_failure_exception_types=[Exception], ): if response_type == SyncResponse: @@ -851,7 +847,6 @@ async def test_service_interface_and_implementation_names(client: Client): ], workflows=[ServiceInterfaceAndImplCallerWorkflow], task_queue=task_queue, - workflow_runner=UnsandboxedWorkflowRunner(), workflow_failure_exception_types=[Exception], ): await create_nexus_endpoint(task_queue, client) @@ -965,7 +960,6 @@ async def test_workflow_run_operation_can_execute_workflow_before_starting_backi ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow(), ], task_queue=task_queue, - workflow_runner=UnsandboxedWorkflowRunner(), ): await create_nexus_endpoint(task_queue, client) result = await client.execute_workflow( From d153d94d981bfcda264194f8c6ac920cedf27792 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 17:12:36 -0400 Subject: [PATCH 104/183] Passthrough 3rd-party imports in tests helpers module --- tests/helpers/nexus.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index 5fb134140..46460d77c 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -2,15 +2,17 @@ from dataclasses import dataclass from typing import Any, Mapping, Optional -import httpx -from google.protobuf import json_format - import temporalio.api import temporalio.api.nexus.v1 import temporalio.api.operatorservice.v1 +import temporalio.workflow from temporalio.client import Client from temporalio.converter import FailureConverter, PayloadConverter +with temporalio.workflow.unsafe.imports_passed_through(): + import httpx + from google.protobuf import json_format + def make_nexus_endpoint_name(task_queue: str) -> str: # Create endpoints for different task queues without name collisions. From 7358f0036961af8129b95ee820dc71e77ad58c4e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 17:43:41 -0400 Subject: [PATCH 105/183] uv.lock --- uv.lock | 131 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 127 insertions(+), 4 deletions(-) diff --git a/uv.lock b/uv.lock index f753830c7..fdea8c415 100644 --- a/uv.lock +++ b/uv.lock @@ -287,6 +287,85 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/40/c199d095151addf69efdb4b9ca3a4f20f70e20508d6222bffb9b76f58573/constantly-23.10.4-py3-none-any.whl", hash = "sha256:3fd9b4d1c3dc1ec9757f3c52aef7e53ad9323dbe39f51dfd4c43853b68dfa3f9", size = 13547, upload-time = "2023-10-28T23:18:23.038Z" }, ] +[[package]] +name = "coverage" +version = "7.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/e0/98670a80884f64578f0c22cd70c5e81a6e07b08167721c7487b4d70a7ca0/coverage-7.9.1.tar.gz", hash = "sha256:6cf43c78c4282708a28e466316935ec7489a9c487518a77fa68f716c67909cec", size = 813650, upload-time = "2025-06-13T13:02:28.627Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/78/1c1c5ec58f16817c09cbacb39783c3655d54a221b6552f47ff5ac9297603/coverage-7.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cc94d7c5e8423920787c33d811c0be67b7be83c705f001f7180c7b186dcf10ca", size = 212028, upload-time = "2025-06-13T13:00:29.293Z" }, + { url = "https://files.pythonhosted.org/packages/98/db/e91b9076f3a888e3b4ad7972ea3842297a52cc52e73fd1e529856e473510/coverage-7.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16aa0830d0c08a2c40c264cef801db8bc4fc0e1892782e45bcacbd5889270509", size = 212420, upload-time = "2025-06-13T13:00:34.027Z" }, + { url = "https://files.pythonhosted.org/packages/0e/d0/2b3733412954576b0aea0a16c3b6b8fbe95eb975d8bfa10b07359ead4252/coverage-7.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf95981b126f23db63e9dbe4cf65bd71f9a6305696fa5e2262693bc4e2183f5b", size = 241529, upload-time = "2025-06-13T13:00:35.786Z" }, + { url = "https://files.pythonhosted.org/packages/b3/00/5e2e5ae2e750a872226a68e984d4d3f3563cb01d1afb449a17aa819bc2c4/coverage-7.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f05031cf21699785cd47cb7485f67df619e7bcdae38e0fde40d23d3d0210d3c3", size = 239403, upload-time = "2025-06-13T13:00:37.399Z" }, + { url = "https://files.pythonhosted.org/packages/37/3b/a2c27736035156b0a7c20683afe7df498480c0dfdf503b8c878a21b6d7fb/coverage-7.9.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4fbcab8764dc072cb651a4bcda4d11fb5658a1d8d68842a862a6610bd8cfa3", size = 240548, upload-time = "2025-06-13T13:00:39.647Z" }, + { url = "https://files.pythonhosted.org/packages/98/f5/13d5fc074c3c0e0dc80422d9535814abf190f1254d7c3451590dc4f8b18c/coverage-7.9.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0f16649a7330ec307942ed27d06ee7e7a38417144620bb3d6e9a18ded8a2d3e5", size = 240459, upload-time = "2025-06-13T13:00:40.934Z" }, + { url = "https://files.pythonhosted.org/packages/36/24/24b9676ea06102df824c4a56ffd13dc9da7904478db519efa877d16527d5/coverage-7.9.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:cea0a27a89e6432705fffc178064503508e3c0184b4f061700e771a09de58187", size = 239128, upload-time = "2025-06-13T13:00:42.343Z" }, + { url = "https://files.pythonhosted.org/packages/be/05/242b7a7d491b369ac5fee7908a6e5ba42b3030450f3ad62c645b40c23e0e/coverage-7.9.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e980b53a959fa53b6f05343afbd1e6f44a23ed6c23c4b4c56c6662bbb40c82ce", size = 239402, upload-time = "2025-06-13T13:00:43.634Z" }, + { url = "https://files.pythonhosted.org/packages/73/e0/4de7f87192fa65c9c8fbaeb75507e124f82396b71de1797da5602898be32/coverage-7.9.1-cp310-cp310-win32.whl", hash = "sha256:70760b4c5560be6ca70d11f8988ee6542b003f982b32f83d5ac0b72476607b70", size = 214518, upload-time = "2025-06-13T13:00:45.622Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ab/5e4e2fe458907d2a65fab62c773671cfc5ac704f1e7a9ddd91996f66e3c2/coverage-7.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:a66e8f628b71f78c0e0342003d53b53101ba4e00ea8dabb799d9dba0abbbcebe", size = 215436, upload-time = "2025-06-13T13:00:47.245Z" }, + { url = "https://files.pythonhosted.org/packages/60/34/fa69372a07d0903a78ac103422ad34db72281c9fc625eba94ac1185da66f/coverage-7.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:95c765060e65c692da2d2f51a9499c5e9f5cf5453aeaf1420e3fc847cc060582", size = 212146, upload-time = "2025-06-13T13:00:48.496Z" }, + { url = "https://files.pythonhosted.org/packages/27/f0/da1894915d2767f093f081c42afeba18e760f12fdd7a2f4acbe00564d767/coverage-7.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba383dc6afd5ec5b7a0d0c23d38895db0e15bcba7fb0fa8901f245267ac30d86", size = 212536, upload-time = "2025-06-13T13:00:51.535Z" }, + { url = "https://files.pythonhosted.org/packages/10/d5/3fc33b06e41e390f88eef111226a24e4504d216ab8e5d1a7089aa5a3c87a/coverage-7.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37ae0383f13cbdcf1e5e7014489b0d71cc0106458878ccde52e8a12ced4298ed", size = 245092, upload-time = "2025-06-13T13:00:52.883Z" }, + { url = "https://files.pythonhosted.org/packages/0a/39/7aa901c14977aba637b78e95800edf77f29f5a380d29768c5b66f258305b/coverage-7.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:69aa417a030bf11ec46149636314c24c8d60fadb12fc0ee8f10fda0d918c879d", size = 242806, upload-time = "2025-06-13T13:00:54.571Z" }, + { url = "https://files.pythonhosted.org/packages/43/fc/30e5cfeaf560b1fc1989227adedc11019ce4bb7cce59d65db34fe0c2d963/coverage-7.9.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a4be2a28656afe279b34d4f91c3e26eccf2f85500d4a4ff0b1f8b54bf807338", size = 244610, upload-time = "2025-06-13T13:00:56.932Z" }, + { url = "https://files.pythonhosted.org/packages/bf/15/cca62b13f39650bc87b2b92bb03bce7f0e79dd0bf2c7529e9fc7393e4d60/coverage-7.9.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:382e7ddd5289f140259b610e5f5c58f713d025cb2f66d0eb17e68d0a94278875", size = 244257, upload-time = "2025-06-13T13:00:58.545Z" }, + { url = "https://files.pythonhosted.org/packages/cd/1a/c0f2abe92c29e1464dbd0ff9d56cb6c88ae2b9e21becdb38bea31fcb2f6c/coverage-7.9.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e5532482344186c543c37bfad0ee6069e8ae4fc38d073b8bc836fc8f03c9e250", size = 242309, upload-time = "2025-06-13T13:00:59.836Z" }, + { url = "https://files.pythonhosted.org/packages/57/8d/c6fd70848bd9bf88fa90df2af5636589a8126d2170f3aade21ed53f2b67a/coverage-7.9.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a39d18b3f50cc121d0ce3838d32d58bd1d15dab89c910358ebefc3665712256c", size = 242898, upload-time = "2025-06-13T13:01:02.506Z" }, + { url = "https://files.pythonhosted.org/packages/c2/9e/6ca46c7bff4675f09a66fe2797cd1ad6a24f14c9c7c3b3ebe0470a6e30b8/coverage-7.9.1-cp311-cp311-win32.whl", hash = "sha256:dd24bd8d77c98557880def750782df77ab2b6885a18483dc8588792247174b32", size = 214561, upload-time = "2025-06-13T13:01:04.012Z" }, + { url = "https://files.pythonhosted.org/packages/a1/30/166978c6302010742dabcdc425fa0f938fa5a800908e39aff37a7a876a13/coverage-7.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:6b55ad10a35a21b8015eabddc9ba31eb590f54adc9cd39bcf09ff5349fd52125", size = 215493, upload-time = "2025-06-13T13:01:05.702Z" }, + { url = "https://files.pythonhosted.org/packages/60/07/a6d2342cd80a5be9f0eeab115bc5ebb3917b4a64c2953534273cf9bc7ae6/coverage-7.9.1-cp311-cp311-win_arm64.whl", hash = "sha256:6ad935f0016be24c0e97fc8c40c465f9c4b85cbbe6eac48934c0dc4d2568321e", size = 213869, upload-time = "2025-06-13T13:01:09.345Z" }, + { url = "https://files.pythonhosted.org/packages/68/d9/7f66eb0a8f2fce222de7bdc2046ec41cb31fe33fb55a330037833fb88afc/coverage-7.9.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8de12b4b87c20de895f10567639c0797b621b22897b0af3ce4b4e204a743626", size = 212336, upload-time = "2025-06-13T13:01:10.909Z" }, + { url = "https://files.pythonhosted.org/packages/20/20/e07cb920ef3addf20f052ee3d54906e57407b6aeee3227a9c91eea38a665/coverage-7.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5add197315a054e92cee1b5f686a2bcba60c4c3e66ee3de77ace6c867bdee7cb", size = 212571, upload-time = "2025-06-13T13:01:12.518Z" }, + { url = "https://files.pythonhosted.org/packages/78/f8/96f155de7e9e248ca9c8ff1a40a521d944ba48bec65352da9be2463745bf/coverage-7.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:600a1d4106fe66f41e5d0136dfbc68fe7200a5cbe85610ddf094f8f22e1b0300", size = 246377, upload-time = "2025-06-13T13:01:14.87Z" }, + { url = "https://files.pythonhosted.org/packages/3e/cf/1d783bd05b7bca5c10ded5f946068909372e94615a4416afadfe3f63492d/coverage-7.9.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a876e4c3e5a2a1715a6608906aa5a2e0475b9c0f68343c2ada98110512ab1d8", size = 243394, upload-time = "2025-06-13T13:01:16.23Z" }, + { url = "https://files.pythonhosted.org/packages/02/dd/e7b20afd35b0a1abea09fb3998e1abc9f9bd953bee548f235aebd2b11401/coverage-7.9.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81f34346dd63010453922c8e628a52ea2d2ccd73cb2487f7700ac531b247c8a5", size = 245586, upload-time = "2025-06-13T13:01:17.532Z" }, + { url = "https://files.pythonhosted.org/packages/4e/38/b30b0006fea9d617d1cb8e43b1bc9a96af11eff42b87eb8c716cf4d37469/coverage-7.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:888f8eee13f2377ce86d44f338968eedec3291876b0b8a7289247ba52cb984cd", size = 245396, upload-time = "2025-06-13T13:01:19.164Z" }, + { url = "https://files.pythonhosted.org/packages/31/e4/4d8ec1dc826e16791f3daf1b50943e8e7e1eb70e8efa7abb03936ff48418/coverage-7.9.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9969ef1e69b8c8e1e70d591f91bbc37fc9a3621e447525d1602801a24ceda898", size = 243577, upload-time = "2025-06-13T13:01:22.433Z" }, + { url = "https://files.pythonhosted.org/packages/25/f4/b0e96c5c38e6e40ef465c4bc7f138863e2909c00e54a331da335faf0d81a/coverage-7.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:60c458224331ee3f1a5b472773e4a085cc27a86a0b48205409d364272d67140d", size = 244809, upload-time = "2025-06-13T13:01:24.143Z" }, + { url = "https://files.pythonhosted.org/packages/8a/65/27e0a1fa5e2e5079bdca4521be2f5dabf516f94e29a0defed35ac2382eb2/coverage-7.9.1-cp312-cp312-win32.whl", hash = "sha256:5f646a99a8c2b3ff4c6a6e081f78fad0dde275cd59f8f49dc4eab2e394332e74", size = 214724, upload-time = "2025-06-13T13:01:25.435Z" }, + { url = "https://files.pythonhosted.org/packages/9b/a8/d5b128633fd1a5e0401a4160d02fa15986209a9e47717174f99dc2f7166d/coverage-7.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:30f445f85c353090b83e552dcbbdad3ec84c7967e108c3ae54556ca69955563e", size = 215535, upload-time = "2025-06-13T13:01:27.861Z" }, + { url = "https://files.pythonhosted.org/packages/a3/37/84bba9d2afabc3611f3e4325ee2c6a47cd449b580d4a606b240ce5a6f9bf/coverage-7.9.1-cp312-cp312-win_arm64.whl", hash = "sha256:af41da5dca398d3474129c58cb2b106a5d93bbb196be0d307ac82311ca234342", size = 213904, upload-time = "2025-06-13T13:01:29.202Z" }, + { url = "https://files.pythonhosted.org/packages/d0/a7/a027970c991ca90f24e968999f7d509332daf6b8c3533d68633930aaebac/coverage-7.9.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:31324f18d5969feef7344a932c32428a2d1a3e50b15a6404e97cba1cc9b2c631", size = 212358, upload-time = "2025-06-13T13:01:30.909Z" }, + { url = "https://files.pythonhosted.org/packages/f2/48/6aaed3651ae83b231556750280682528fea8ac7f1232834573472d83e459/coverage-7.9.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0c804506d624e8a20fb3108764c52e0eef664e29d21692afa375e0dd98dc384f", size = 212620, upload-time = "2025-06-13T13:01:32.256Z" }, + { url = "https://files.pythonhosted.org/packages/6c/2a/f4b613f3b44d8b9f144847c89151992b2b6b79cbc506dee89ad0c35f209d/coverage-7.9.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef64c27bc40189f36fcc50c3fb8f16ccda73b6a0b80d9bd6e6ce4cffcd810bbd", size = 245788, upload-time = "2025-06-13T13:01:33.948Z" }, + { url = "https://files.pythonhosted.org/packages/04/d2/de4fdc03af5e4e035ef420ed26a703c6ad3d7a07aff2e959eb84e3b19ca8/coverage-7.9.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4fe2348cc6ec372e25adec0219ee2334a68d2f5222e0cba9c0d613394e12d86", size = 243001, upload-time = "2025-06-13T13:01:35.285Z" }, + { url = "https://files.pythonhosted.org/packages/f5/e8/eed18aa5583b0423ab7f04e34659e51101135c41cd1dcb33ac1d7013a6d6/coverage-7.9.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34ed2186fe52fcc24d4561041979a0dec69adae7bce2ae8d1c49eace13e55c43", size = 244985, upload-time = "2025-06-13T13:01:36.712Z" }, + { url = "https://files.pythonhosted.org/packages/17/f8/ae9e5cce8885728c934eaa58ebfa8281d488ef2afa81c3dbc8ee9e6d80db/coverage-7.9.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:25308bd3d00d5eedd5ae7d4357161f4df743e3c0240fa773ee1b0f75e6c7c0f1", size = 245152, upload-time = "2025-06-13T13:01:39.303Z" }, + { url = "https://files.pythonhosted.org/packages/5a/c8/272c01ae792bb3af9b30fac14d71d63371db227980682836ec388e2c57c0/coverage-7.9.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:73e9439310f65d55a5a1e0564b48e34f5369bee943d72c88378f2d576f5a5751", size = 243123, upload-time = "2025-06-13T13:01:40.727Z" }, + { url = "https://files.pythonhosted.org/packages/8c/d0/2819a1e3086143c094ab446e3bdf07138527a7b88cb235c488e78150ba7a/coverage-7.9.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:37ab6be0859141b53aa89412a82454b482c81cf750de4f29223d52268a86de67", size = 244506, upload-time = "2025-06-13T13:01:42.184Z" }, + { url = "https://files.pythonhosted.org/packages/8b/4e/9f6117b89152df7b6112f65c7a4ed1f2f5ec8e60c4be8f351d91e7acc848/coverage-7.9.1-cp313-cp313-win32.whl", hash = "sha256:64bdd969456e2d02a8b08aa047a92d269c7ac1f47e0c977675d550c9a0863643", size = 214766, upload-time = "2025-06-13T13:01:44.482Z" }, + { url = "https://files.pythonhosted.org/packages/27/0f/4b59f7c93b52c2c4ce7387c5a4e135e49891bb3b7408dcc98fe44033bbe0/coverage-7.9.1-cp313-cp313-win_amd64.whl", hash = "sha256:be9e3f68ca9edb897c2184ad0eee815c635565dbe7a0e7e814dc1f7cbab92c0a", size = 215568, upload-time = "2025-06-13T13:01:45.772Z" }, + { url = "https://files.pythonhosted.org/packages/09/1e/9679826336f8c67b9c39a359352882b24a8a7aee48d4c9cad08d38d7510f/coverage-7.9.1-cp313-cp313-win_arm64.whl", hash = "sha256:1c503289ffef1d5105d91bbb4d62cbe4b14bec4d13ca225f9c73cde9bb46207d", size = 213939, upload-time = "2025-06-13T13:01:47.087Z" }, + { url = "https://files.pythonhosted.org/packages/bb/5b/5c6b4e7a407359a2e3b27bf9c8a7b658127975def62077d441b93a30dbe8/coverage-7.9.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0b3496922cb5f4215bf5caaef4cf12364a26b0be82e9ed6d050f3352cf2d7ef0", size = 213079, upload-time = "2025-06-13T13:01:48.554Z" }, + { url = "https://files.pythonhosted.org/packages/a2/22/1e2e07279fd2fd97ae26c01cc2186e2258850e9ec125ae87184225662e89/coverage-7.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9565c3ab1c93310569ec0d86b017f128f027cab0b622b7af288696d7ed43a16d", size = 213299, upload-time = "2025-06-13T13:01:49.997Z" }, + { url = "https://files.pythonhosted.org/packages/14/c0/4c5125a4b69d66b8c85986d3321520f628756cf524af810baab0790c7647/coverage-7.9.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2241ad5dbf79ae1d9c08fe52b36d03ca122fb9ac6bca0f34439e99f8327ac89f", size = 256535, upload-time = "2025-06-13T13:01:51.314Z" }, + { url = "https://files.pythonhosted.org/packages/81/8b/e36a04889dda9960be4263e95e777e7b46f1bb4fc32202612c130a20c4da/coverage-7.9.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bb5838701ca68b10ebc0937dbd0eb81974bac54447c55cd58dea5bca8451029", size = 252756, upload-time = "2025-06-13T13:01:54.403Z" }, + { url = "https://files.pythonhosted.org/packages/98/82/be04eff8083a09a4622ecd0e1f31a2c563dbea3ed848069e7b0445043a70/coverage-7.9.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b30a25f814591a8c0c5372c11ac8967f669b97444c47fd794926e175c4047ece", size = 254912, upload-time = "2025-06-13T13:01:56.769Z" }, + { url = "https://files.pythonhosted.org/packages/0f/25/c26610a2c7f018508a5ab958e5b3202d900422cf7cdca7670b6b8ca4e8df/coverage-7.9.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2d04b16a6062516df97969f1ae7efd0de9c31eb6ebdceaa0d213b21c0ca1a683", size = 256144, upload-time = "2025-06-13T13:01:58.19Z" }, + { url = "https://files.pythonhosted.org/packages/c5/8b/fb9425c4684066c79e863f1e6e7ecebb49e3a64d9f7f7860ef1688c56f4a/coverage-7.9.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7931b9e249edefb07cd6ae10c702788546341d5fe44db5b6108a25da4dca513f", size = 254257, upload-time = "2025-06-13T13:01:59.645Z" }, + { url = "https://files.pythonhosted.org/packages/93/df/27b882f54157fc1131e0e215b0da3b8d608d9b8ef79a045280118a8f98fe/coverage-7.9.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:52e92b01041151bf607ee858e5a56c62d4b70f4dac85b8c8cb7fb8a351ab2c10", size = 255094, upload-time = "2025-06-13T13:02:01.37Z" }, + { url = "https://files.pythonhosted.org/packages/41/5f/cad1c3dbed8b3ee9e16fa832afe365b4e3eeab1fb6edb65ebbf745eabc92/coverage-7.9.1-cp313-cp313t-win32.whl", hash = "sha256:684e2110ed84fd1ca5f40e89aa44adf1729dc85444004111aa01866507adf363", size = 215437, upload-time = "2025-06-13T13:02:02.905Z" }, + { url = "https://files.pythonhosted.org/packages/99/4d/fad293bf081c0e43331ca745ff63673badc20afea2104b431cdd8c278b4c/coverage-7.9.1-cp313-cp313t-win_amd64.whl", hash = "sha256:437c576979e4db840539674e68c84b3cda82bc824dd138d56bead1435f1cb5d7", size = 216605, upload-time = "2025-06-13T13:02:05.638Z" }, + { url = "https://files.pythonhosted.org/packages/1f/56/4ee027d5965fc7fc126d7ec1187529cc30cc7d740846e1ecb5e92d31b224/coverage-7.9.1-cp313-cp313t-win_arm64.whl", hash = "sha256:18a0912944d70aaf5f399e350445738a1a20b50fbea788f640751c2ed9208b6c", size = 214392, upload-time = "2025-06-13T13:02:07.642Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d6/c41dd9b02bf16ec001aaf1cbef665537606899a3db1094e78f5ae17540ca/coverage-7.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f424507f57878e424d9a95dc4ead3fbdd72fd201e404e861e465f28ea469951", size = 212029, upload-time = "2025-06-13T13:02:09.058Z" }, + { url = "https://files.pythonhosted.org/packages/f8/c0/40420d81d731f84c3916dcdf0506b3e6c6570817bff2576b83f780914ae6/coverage-7.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:535fde4001b2783ac80865d90e7cc7798b6b126f4cd8a8c54acfe76804e54e58", size = 212407, upload-time = "2025-06-13T13:02:11.151Z" }, + { url = "https://files.pythonhosted.org/packages/9b/87/f0db7d62d0e09f14d6d2f6ae8c7274a2f09edf74895a34b412a0601e375a/coverage-7.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02532fd3290bb8fa6bec876520842428e2a6ed6c27014eca81b031c2d30e3f71", size = 241160, upload-time = "2025-06-13T13:02:12.864Z" }, + { url = "https://files.pythonhosted.org/packages/a9/b7/3337c064f058a5d7696c4867159651a5b5fb01a5202bcf37362f0c51400e/coverage-7.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56f5eb308b17bca3bbff810f55ee26d51926d9f89ba92707ee41d3c061257e55", size = 239027, upload-time = "2025-06-13T13:02:14.294Z" }, + { url = "https://files.pythonhosted.org/packages/7e/a9/5898a283f66d1bd413c32c2e0e05408196fd4f37e206e2b06c6e0c626e0e/coverage-7.9.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfa447506c1a52271f1b0de3f42ea0fa14676052549095e378d5bff1c505ff7b", size = 240145, upload-time = "2025-06-13T13:02:15.745Z" }, + { url = "https://files.pythonhosted.org/packages/e0/33/d96e3350078a3c423c549cb5b2ba970de24c5257954d3e4066e2b2152d30/coverage-7.9.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9ca8e220006966b4a7b68e8984a6aee645a0384b0769e829ba60281fe61ec4f7", size = 239871, upload-time = "2025-06-13T13:02:17.344Z" }, + { url = "https://files.pythonhosted.org/packages/1d/6e/6fb946072455f71a820cac144d49d11747a0f1a21038060a68d2d0200499/coverage-7.9.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:49f1d0788ba5b7ba65933f3a18864117c6506619f5ca80326b478f72acf3f385", size = 238122, upload-time = "2025-06-13T13:02:18.849Z" }, + { url = "https://files.pythonhosted.org/packages/e4/5c/bc43f25c8586840ce25a796a8111acf6a2b5f0909ba89a10d41ccff3920d/coverage-7.9.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:68cd53aec6f45b8e4724c0950ce86eacb775c6be01ce6e3669fe4f3a21e768ed", size = 239058, upload-time = "2025-06-13T13:02:21.423Z" }, + { url = "https://files.pythonhosted.org/packages/11/d8/ce2007418dd7fd00ff8c8b898bb150bb4bac2d6a86df05d7b88a07ff595f/coverage-7.9.1-cp39-cp39-win32.whl", hash = "sha256:95335095b6c7b1cc14c3f3f17d5452ce677e8490d101698562b2ffcacc304c8d", size = 214532, upload-time = "2025-06-13T13:02:22.857Z" }, + { url = "https://files.pythonhosted.org/packages/20/21/334e76fa246e92e6d69cab217f7c8a70ae0cc8f01438bd0544103f29528e/coverage-7.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:e1b5191d1648acc439b24721caab2fd0c86679d8549ed2c84d5a7ec1bedcc244", size = 215439, upload-time = "2025-06-13T13:02:24.268Z" }, + { url = "https://files.pythonhosted.org/packages/3e/e5/c723545c3fd3204ebde3b4cc4b927dce709d3b6dc577754bb57f63ca4a4a/coverage-7.9.1-pp39.pp310.pp311-none-any.whl", hash = "sha256:db0f04118d1db74db6c9e1cb1898532c7dcc220f1d2718f058601f7c3f499514", size = 204009, upload-time = "2025-06-13T13:02:25.787Z" }, + { url = "https://files.pythonhosted.org/packages/08/b8/7ddd1e8ba9701dea08ce22029917140e6f66a859427406579fd8d0ca7274/coverage-7.9.1-py3-none-any.whl", hash = "sha256:66b974b145aa189516b6bf2d8423e888b742517d37872f6ee4c5be0073bd9a3c", size = 204000, upload-time = "2025-06-13T13:02:27.173Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + [[package]] name = "cryptography" version = "45.0.4" @@ -962,6 +1041,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/73/d6b999782ae22f16971cc05378b3b33f6a89ede3b9619e8366aa23484bca/mypy_protobuf-3.6.0-py3-none-any.whl", hash = "sha256:56176e4d569070e7350ea620262478b49b7efceba4103d468448f1d21492fd6c", size = 16434, upload-time = "2024-04-01T20:24:40.583Z" }, ] +[[package]] +name = "nexus-rpc" +version = "0.1.0" +source = { editable = "../nexus-sdk-python" } +dependencies = [ + { name = "typing-extensions" }, +] + +[package.metadata] +requires-dist = [{ name = "typing-extensions", specifier = ">=4.12.2" }] + +[package.metadata.requires-dev] +dev = [ + { name = "mypy", specifier = ">=1.15.0" }, + { name = "pydoctor", specifier = ">=25.4.0" }, + { name = "pyright", specifier = ">=1.1.400" }, + { name = "pytest", specifier = ">=8.3.5" }, + { name = "pytest-asyncio", specifier = ">=0.26.0" }, + { name = "pytest-cov", specifier = ">=6.1.1" }, + { name = "pytest-pretty", specifier = ">=1.3.0" }, + { name = "ruff", specifier = ">=0.12.0" }, +] + [[package]] name = "nh3" version = "0.2.21" @@ -1336,14 +1438,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.377" +version = "1.1.400" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/49/f0/25b0db363d6888164adb7c828b877bbf2c30936955fb9513922ae03e70e4/pyright-1.1.377.tar.gz", hash = "sha256:aabc30fedce0ded34baa0c49b24f10e68f4bfc8f68ae7f3d175c4b0f256b4fcf", size = 17484, upload-time = "2024-08-21T02:25:15.74Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/cb/c306618a02d0ee8aed5fb8d0fe0ecfed0dbf075f71468f03a30b5f4e1fe0/pyright-1.1.400.tar.gz", hash = "sha256:b8a3ba40481aa47ba08ffb3228e821d22f7d391f83609211335858bf05686bdb", size = 3846546, upload-time = "2025-04-24T12:55:18.907Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/c9/89c40c4de44fe9463e77dddd0c4e2d2dd7a93e8ddc6858dfe7d5f75d263d/pyright-1.1.377-py3-none-any.whl", hash = "sha256:af0dd2b6b636c383a6569a083f8c5a8748ae4dcde5df7914b3f3f267e14dd162", size = 18223, upload-time = "2024-08-21T02:25:14.585Z" }, + { url = "https://files.pythonhosted.org/packages/c8/a5/5d285e4932cf149c90e3c425610c5efaea005475d5f96f1bfdb452956c62/pyright-1.1.400-py3-none-any.whl", hash = "sha256:c80d04f98b5a4358ad3a35e241dbf2a408eee33a40779df365644f8054d2517e", size = 5563460, upload-time = "2025-04-24T12:55:17.002Z" }, ] [[package]] @@ -1375,6 +1478,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/ce/1e4b53c213dce25d6e8b163697fbce2d43799d76fa08eea6ad270451c370/pytest_asyncio-0.21.2-py3-none-any.whl", hash = "sha256:ab664c88bb7998f711d8039cacd4884da6430886ae8bbd4eded552ed2004f16b", size = 13368, upload-time = "2024-04-29T13:23:23.126Z" }, ] +[[package]] +name = "pytest-cov" +version = "6.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/18/99/668cade231f434aaa59bbfbf49469068d2ddd945000621d3d165d2e7dd7b/pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2", size = 69432, upload-time = "2025-06-12T10:47:47.684Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644, upload-time = "2025-06-12T10:47:45.932Z" }, +] + [[package]] name = "pytest-pretty" version = "1.3.0" @@ -1606,6 +1723,7 @@ name = "temporalio" version = "1.13.0" source = { virtual = "." } dependencies = [ + { name = "nexus-rpc" }, { name = "protobuf" }, { name = "python-dateutil", marker = "python_full_version < '3.11'" }, { name = "types-protobuf" }, @@ -1632,6 +1750,7 @@ pydantic = [ dev = [ { name = "cibuildwheel" }, { name = "grpcio-tools" }, + { name = "httpx" }, { name = "maturin" }, { name = "mypy" }, { name = "mypy-protobuf" }, @@ -1641,6 +1760,7 @@ dev = [ { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-cov" }, { name = "pytest-pretty" }, { name = "pytest-timeout" }, { name = "ruff" }, @@ -1652,6 +1772,7 @@ dev = [ requires-dist = [ { name = "eval-type-backport", marker = "python_full_version < '3.10' and extra == 'openai-agents'", specifier = ">=0.2.2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, + { name = "nexus-rpc", editable = "../nexus-sdk-python" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.0.19,<0.1" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, @@ -1667,15 +1788,17 @@ provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents"] dev = [ { name = "cibuildwheel", specifier = ">=2.22.0,<3" }, { name = "grpcio-tools", specifier = ">=1.48.2,<2" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "maturin", specifier = ">=1.8.2" }, { name = "mypy", specifier = "==1.4.1" }, { name = "mypy-protobuf", specifier = ">=3.3.0,<4" }, { name = "psutil", specifier = ">=5.9.3,<6" }, { name = "pydocstyle", specifier = ">=6.3.0,<7" }, { name = "pydoctor", specifier = ">=24.11.1,<25" }, - { name = "pyright", specifier = "==1.1.377" }, + { name = "pyright", specifier = "==1.1.400" }, { name = "pytest", specifier = "~=7.4" }, { name = "pytest-asyncio", specifier = ">=0.21,<0.22" }, + { name = "pytest-cov", specifier = ">=6.1.1" }, { name = "pytest-pretty", specifier = ">=1.3.0" }, { name = "pytest-timeout", specifier = "~=2.2" }, { name = "ruff", specifier = ">=0.5.0,<0.6" }, From 2f160aa142bbad10eb8ee05f37b4dd40ef280638 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 18:38:02 -0400 Subject: [PATCH 106/183] Strengthen warning note --- temporalio/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/temporalio/client.py b/temporalio/client.py index 6a67e328e..aff72ff9c 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -466,7 +466,9 @@ async def start_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, - # The following options are deliberately not exposed in overloads + # The following options should not be considered part of the public API. They + # are deliberately not exposed in overloads, and are not subject to any + # backwards compatibility guarantees. nexus_completion_callbacks: Sequence[NexusCompletionCallback] = [], workflow_event_links: Sequence[ temporalio.api.common.v1.Link.WorkflowEvent From 7e53850cab50a01f234811c73cb47e668c2a39ba Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 18:43:54 -0400 Subject: [PATCH 107/183] Docstrings, comments --- temporalio/client.py | 7 ++++++- temporalio/nexus/_decorators.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index aff72ff9c..2d091626a 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -5203,6 +5203,7 @@ class StartWorkflowInput: rpc_timeout: Optional[timedelta] request_eager_start: bool priority: temporalio.common.Priority + # The following options are experimental and unstable. nexus_completion_callbacks: Sequence[NexusCompletionCallback] workflow_event_links: Sequence[temporalio.api.common.v1.Link.WorkflowEvent] request_id: Optional[str] @@ -7264,7 +7265,11 @@ def api_key(self, value: Optional[str]) -> None: @dataclass(frozen=True) class NexusCompletionCallback: - """Nexus callback to attach to events such as workflow completion.""" + """Nexus callback to attach to events such as workflow completion. + + .. warning:: + This option is experimental and unstable. + """ url: str """Callback URL.""" diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index eda23031b..30a26ceb2 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -115,7 +115,7 @@ async def _start( return WorkflowRunOperationHandler(_start, input_type, output_type) method_name = get_callable_name(start) - # TODO(preview): make double-underscore attrs private to nexusrpc and expose getters/setters + # TODO(nexus-preview): make double-underscore attrs private to nexusrpc and expose getters/setters operation_handler_factory.__nexus_operation__ = nexusrpc.Operation( name=name or method_name, method_name=method_name, From 15beaffdf74bc5d77da5a704445a97aad5530030 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 26 Jun 2025 19:46:30 -0400 Subject: [PATCH 108/183] Type-level cleanup/evolution in workflow caller --- temporalio/nexus/_decorators.py | 5 ++- temporalio/nexus/_operation_handlers.py | 10 ++--- temporalio/nexus/_token.py | 2 +- temporalio/nexus/_util.py | 6 ++- temporalio/worker/_interceptor.py | 9 ++--- temporalio/worker/_workflow_instance.py | 37 ++++++++--------- temporalio/workflow.py | 53 ++++++++++--------------- 7 files changed, 56 insertions(+), 66 deletions(-) diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 30a26ceb2..68267fa3e 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -4,16 +4,17 @@ Awaitable, Callable, Optional, + TypeVar, Union, overload, ) import nexusrpc +from nexusrpc import InputT, OutputT from nexusrpc.handler import ( OperationHandler, StartOperationContext, ) -from nexusrpc.types import InputT, OutputT, ServiceHandlerT from temporalio.nexus._operation_context import WorkflowRunOperationContext from temporalio.nexus._operation_handlers import ( @@ -27,6 +28,8 @@ get_workflow_run_start_method_input_and_output_type_annotations, ) +ServiceHandlerT = TypeVar("ServiceHandlerT") + @overload def workflow_run_operation( diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 2a7c0d4d5..efdc41bac 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -8,7 +8,11 @@ Type, ) -from nexusrpc import OperationInfo +from nexusrpc import ( + InputT, + OperationInfo, + OutputT, +) from nexusrpc.handler import ( CancelOperationContext, FetchOperationInfoContext, @@ -19,10 +23,6 @@ StartOperationContext, StartOperationResultAsync, ) -from nexusrpc.types import ( - InputT, - OutputT, -) from temporalio import client from temporalio.nexus._operation_context import ( diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index 9f2957888..a6290111c 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Any, Generic, Literal, Optional, Type -from nexusrpc.types import OutputT +from nexusrpc import OutputT from temporalio import client diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 9cb4af50f..8b24383ad 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -10,13 +10,13 @@ Callable, Optional, Type, + TypeVar, Union, ) -from nexusrpc.types import ( +from nexusrpc import ( InputT, OutputT, - ServiceHandlerT, ) from temporalio.nexus._operation_context import WorkflowRunOperationContext @@ -25,6 +25,8 @@ WorkflowHandle as WorkflowHandle, ) +ServiceHandlerT = TypeVar("ServiceHandlerT") + def get_workflow_run_start_method_input_and_output_type_annotations( start: Callable[ diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index c19771921..2b20dcb46 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -19,10 +19,7 @@ ) import nexusrpc.handler -from nexusrpc.types import ( - InputT, - OutputT, -) +from nexusrpc import InputT, OutputT import temporalio.activity import temporalio.api.common.v1 @@ -464,7 +461,7 @@ def start_local_activity( return self.next.start_local_activity(input) async def start_nexus_operation( - self, input: StartNexusOperationInput - ) -> temporalio.workflow.NexusOperationHandle[Any]: + self, input: StartNexusOperationInput[InputT, OutputT] + ) -> temporalio.workflow.NexusOperationHandle[OutputT]: """Called for every :py:func:`temporalio.workflow.start_nexus_operation` call.""" return await self.next.start_nexus_operation(input) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index f22c6d4c6..0ab367459 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -44,6 +44,7 @@ ) import nexusrpc.handler +from nexusrpc import InputT, OutputT from typing_extensions import Self, TypeAlias, TypedDict import temporalio.activity @@ -1498,12 +1499,12 @@ async def workflow_start_nexus_operation( self, endpoint: str, service: str, - operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], input: Any, - output_type: Optional[Type[O]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, - ) -> temporalio.workflow.NexusOperationHandle[Any]: + ) -> temporalio.workflow.NexusOperationHandle[OutputT]: # start_nexus_operation return await self._outbound.start_nexus_operation( StartNexusOperationInput( @@ -1822,8 +1823,8 @@ async def run_child() -> Any: apply_child_cancel_error() async def _outbound_start_nexus_operation( - self, input: StartNexusOperationInput - ) -> _NexusOperationHandle[Any]: + self, input: StartNexusOperationInput[Any, OutputT] + ) -> _NexusOperationHandle[OutputT]: # A Nexus operation handle contains two futures: self._start_fut is resolved as a # result of the Nexus operation starting (activation job: # resolve_nexus_operation_start), and self._result_fut is resolved as a result of @@ -1838,9 +1839,9 @@ async def _outbound_start_nexus_operation( # and start will be resolved with an operation token). See comments in # tests/worker/test_nexus.py for worked examples of the evolution of the resulting # handle state machine in the sync and async Nexus response cases. - handle: _NexusOperationHandle + handle: _NexusOperationHandle[OutputT] - async def operation_handle_fn() -> Any: + async def operation_handle_fn() -> OutputT: while True: try: return await asyncio.shield(handle._result_fut) @@ -2599,8 +2600,8 @@ async def start_child_workflow( return await self._instance._outbound_start_child_workflow(input) async def start_nexus_operation( - self, input: StartNexusOperationInput - ) -> temporalio.workflow.NexusOperationHandle[Any]: + self, input: StartNexusOperationInput[Any, OutputT] + ) -> _NexusOperationHandle[OutputT]: return await self._instance._outbound_start_nexus_operation(input) def start_local_activity( @@ -2989,27 +2990,23 @@ async def cancel(self) -> None: await self._instance._cancel_external_workflow(command) -I = TypeVar("I") -O = TypeVar("O") - - # TODO(dan): are we sure we don't want to inherit from asyncio.Task as ActivityHandle and # ChildWorkflowHandle do? I worry that we should provide .done(), .result(), .exception() # etc for consistency. -class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[O]): +class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[OutputT]): def __init__( self, instance: _WorkflowInstanceImpl, seq: int, - input: StartNexusOperationInput, - fn: Coroutine[Any, Any, O], + input: StartNexusOperationInput[Any, OutputT], + fn: Coroutine[Any, Any, OutputT], ): self._instance = instance self._seq = seq self._input = input self._task = asyncio.Task(fn) self._start_fut: asyncio.Future[Optional[str]] = instance.create_future() - self._result_fut: asyncio.Future[Optional[O]] = instance.create_future() + self._result_fut: asyncio.Future[Optional[OutputT]] = instance.create_future() @property def operation_token(self) -> Optional[str]: @@ -3023,10 +3020,10 @@ def operation_token(self) -> Optional[str]: except BaseException: return None - async def result(self) -> O: + async def result(self) -> OutputT: return await self._task - def __await__(self) -> Generator[Any, Any, O]: + def __await__(self) -> Generator[Any, Any, OutputT]: return self._task.__await__() def __repr__(self) -> str: @@ -3043,7 +3040,7 @@ def _resolve_start_success(self, operation_token: Optional[str]) -> None: # We intentionally let this error if already done self._start_fut.set_result(operation_token) - def _resolve_success(self, result: Any) -> None: + def _resolve_success(self, result: OutputT) -> None: # We intentionally let this error if already done self._result_fut.set_result(result) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 3a14989e3..fd18f0128 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -42,6 +42,7 @@ import nexusrpc import nexusrpc.handler +from nexusrpc import InputT, OutputT from typing_extensions import ( Concatenate, Literal, @@ -854,12 +855,12 @@ async def workflow_start_nexus_operation( self, endpoint: str, service: str, - operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], input: Any, - output_type: Optional[Type[O]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, - ) -> NexusOperationHandle[Any]: ... + ) -> NexusOperationHandle[OutputT]: ... @abstractmethod def workflow_time_ns(self) -> int: ... @@ -4383,14 +4384,8 @@ async def execute_child_workflow( return await handle -# TODO(nexus-prerelease): use types from nexusrpc -I = TypeVar("I") -O = TypeVar("O") -S = TypeVar("S") - - # TODO(nexus-prerelease): ABC / inherit from asyncio.Task? -class NexusOperationHandle(Generic[O]): +class NexusOperationHandle(Generic[OutputT]): def cancel(self) -> bool: # TODO(nexus-prerelease): docstring """ @@ -4404,7 +4399,7 @@ def cancel(self) -> bool: """ raise NotImplementedError - def __await__(self) -> Generator[Any, Any, O]: + def __await__(self) -> Generator[Any, Any, OutputT]: raise NotImplementedError # TODO(nexus-prerelease): check SDK-wide consistency for @property vs nullary accessor methods. @@ -4416,13 +4411,13 @@ def operation_token(self) -> Optional[str]: async def start_nexus_operation( endpoint: str, service: str, - operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], input: Any, *, - output_type: Optional[Type[O]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, -) -> NexusOperationHandle[Any]: +) -> NexusOperationHandle[OutputT]: """Start a Nexus operation and return its handle. Args: @@ -5161,17 +5156,13 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType # Nexus +ServiceT = TypeVar("ServiceT") -class NexusClient(Generic[S]): + +class NexusClient(Generic[ServiceT]): def __init__( self, - service: Union[ - # TODO(nexus-prerelease): Type[S] is modeling the interface case as well the impl case, but - # the typevar S is used below only in the impl case. I think this is OK, but - # think about it again before deleting this TODO. - Type[S], - str, - ], + service: Union[Type[ServiceT], str], *, endpoint: str, ) -> None: @@ -5194,13 +5185,13 @@ def __init__( # TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied? async def start_operation( self, - operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], - input: I, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, *, - output_type: Optional[Type[O]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, - ) -> NexusOperationHandle[O]: + ) -> NexusOperationHandle[OutputT]: return await temporalio.workflow.start_nexus_operation( endpoint=self._endpoint, service=self._service_name, @@ -5214,14 +5205,14 @@ async def start_operation( # TODO(nexus-prerelease): overloads: no-input, ret type async def execute_operation( self, - operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]], - input: I, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, *, - output_type: Optional[Type[O]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, - ) -> O: - handle: NexusOperationHandle[O] = await self.start_operation( + ) -> OutputT: + handle: NexusOperationHandle[OutputT] = await self.start_operation( operation, input, output_type=output_type, From 132f693e2533346d1793e162bd4bb2ed11e5e5c5 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 10:46:51 -0400 Subject: [PATCH 109/183] TODOs - fetch result - error handling --- temporalio/nexus/_operation_handlers.py | 1 + temporalio/worker/_nexus.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index efdc41bac..e9a2631fc 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -104,6 +104,7 @@ async def fetch_result( "Temporal Nexus operation handlers do not support fetching the operation result." ) # An implementation is provided for future reference: + # TODO: honor `wait` param and Request-Timeout header try: nexus_handle = WorkflowHandle[OutputT].from_token(token) except Exception as err: diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 54b94a6e7..33239ce55 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -341,6 +341,8 @@ async def _operation_error_to_proto( self, err: nexusrpc.OperationError, ) -> temporalio.api.nexus.v1.UnsuccessfulOperationError: + # TODO(nexus-prerelease): why are we accessing __cause__ here for OperationError + # and not for HandlerError? cause = err.__cause__ if cause is None: cause = Exception(*err.args).with_traceback(err.__traceback__) From d057268124c373228ee9b4b9897ce86130650df4 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 10:12:47 -0400 Subject: [PATCH 110/183] Move logger --- temporalio/nexus/__init__.py | 28 +------------------------- temporalio/nexus/_operation_context.py | 21 ++++++++++++++++++- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index bdcc0b7a9..c3baa50b7 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -1,11 +1,3 @@ -import logging -from typing import ( - Any, - Mapping, - MutableMapping, - Optional, -) - from ._decorators import workflow_run_operation as workflow_run_operation from ._operation_context import Info as Info from ._operation_context import ( @@ -19,29 +11,11 @@ ) from ._operation_context import client as client from ._operation_context import info as info +from ._operation_context import logger as logger from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowHandle as WorkflowHandle -class LoggerAdapter(logging.LoggerAdapter): - def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): - super().__init__(logger, extra or {}) - - def process( - self, msg: Any, kwargs: MutableMapping[str, Any] - ) -> tuple[Any, MutableMapping[str, Any]]: - extra = dict(self.extra or {}) - if tctx := _temporal_operation_context.get(None): - extra["service"] = tctx.nexus_operation_context.service - extra["operation"] = tctx.nexus_operation_context.operation - extra["task_queue"] = tctx.info().task_queue - kwargs["extra"] = extra | kwargs.get("extra", {}) - return msg, kwargs - - -logger = LoggerAdapter(logging.getLogger(__name__), None) -"""Logger that emits additional data describing the current Nexus operation.""" - # TODO(nexus-prerelease) WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' # 2025-06-25T12:58:05.749589Z WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' # 2025-06-25T12:58:05.763052Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 0c47237c1..d11ef56ff 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -30,7 +30,6 @@ SelfType, ) -logger = logging.getLogger(__name__) _temporal_operation_context: ContextVar[_TemporalNexusOperationContext] = ContextVar( "temporal-operation-context" @@ -367,3 +366,23 @@ def _nexus_link_to_workflow_event( run_id=urllib.parse.unquote(groups["run_id"]), event_ref=event_ref, ) + + +class _LoggerAdapter(logging.LoggerAdapter): + def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): + super().__init__(logger, extra or {}) + + def process( + self, msg: Any, kwargs: MutableMapping[str, Any] + ) -> tuple[Any, MutableMapping[str, Any]]: + extra = dict(self.extra or {}) + if tctx := _temporal_operation_context.get(None): + extra["service"] = tctx.nexus_operation_context.service + extra["operation"] = tctx.nexus_operation_context.operation + extra["task_queue"] = tctx.info().task_queue + kwargs["extra"] = extra | kwargs.get("extra", {}) + return msg, kwargs + + +logger = _LoggerAdapter(logging.getLogger("temporalio.nexus"), None) +"""Logger that emits additional data describing the current Nexus operation.""" From 0ec14d821edf6f460d882c1deae9f216ed067d25 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 10:13:05 -0400 Subject: [PATCH 111/183] Separate Temporal context for each operation verb --- temporalio/nexus/__init__.py | 4 +- temporalio/nexus/_decorators.py | 8 +- temporalio/nexus/_operation_context.py | 227 ++++++++++++--------- temporalio/nexus/_operation_handlers.py | 6 +- temporalio/worker/_nexus.py | 28 ++- tests/nexus/test_handler.py | 10 +- tests/nexus/test_workflow_caller.py | 9 +- tests/nexus/test_workflow_run_operation.py | 3 +- 8 files changed, 162 insertions(+), 133 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index c3baa50b7..5573df4a6 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -4,10 +4,10 @@ WorkflowRunOperationContext as WorkflowRunOperationContext, ) from ._operation_context import ( - _temporal_operation_context as _temporal_operation_context, + _TemporalCancelOperationContext as _TemporalCancelOperationContext, ) from ._operation_context import ( - _TemporalNexusOperationContext as _TemporalNexusOperationContext, + _TemporalStartOperationContext as _TemporalStartOperationContext, ) from ._operation_context import client as client from ._operation_context import info as info diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 68267fa3e..978b69dcc 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -16,7 +16,10 @@ StartOperationContext, ) -from temporalio.nexus._operation_context import WorkflowRunOperationContext +from temporalio.nexus._operation_context import ( + WorkflowRunOperationContext, + _TemporalStartOperationContext, +) from temporalio.nexus._operation_handlers import ( WorkflowRunOperationHandler, ) @@ -112,7 +115,8 @@ def operation_handler_factory( async def _start( ctx: StartOperationContext, input: InputT ) -> WorkflowHandle[OutputT]: - return await start(self, WorkflowRunOperationContext(ctx), input) + tctx = _TemporalStartOperationContext.get() + return await start(self, WorkflowRunOperationContext(tctx), input) _start.__doc__ = start.__doc__ return WorkflowRunOperationHandler(_start, input_type, output_type) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index d11ef56ff..fae56fc91 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -10,6 +10,7 @@ Any, Callable, Mapping, + MutableMapping, Optional, Sequence, Union, @@ -30,9 +31,16 @@ SelfType, ) +# The Temporal Nexus worker always builds a nexusrpc StartOperationContext or +# CancelOperationContext and passes it as the first parameter to the nexusrpc operation +# handler. In addition, it sets one of the following context vars. -_temporal_operation_context: ContextVar[_TemporalNexusOperationContext] = ContextVar( - "temporal-operation-context" +_temporal_start_operation_context: ContextVar[_TemporalStartOperationContext] = ( + ContextVar("temporal-start-operation-context") +) + +_temporal_cancel_operation_context: ContextVar[_TemporalCancelOperationContext] = ( + ContextVar("temporal-cancel-operation-context") ) @@ -51,59 +59,126 @@ def info() -> Info: """ Get the current Nexus operation information. """ - return _TemporalNexusOperationContext.get().info() + return _temporal_context().info() def client() -> temporalio.client.Client: """ Get the Temporal client used by the worker handling the current Nexus operation. """ - return _TemporalNexusOperationContext.get().client + return _temporal_context().client + + +def _temporal_context() -> ( + Union[_TemporalStartOperationContext, _TemporalCancelOperationContext] +): + ctx = _try_temporal_context() + if ctx is None: + raise RuntimeError("Not in Nexus operation context.") + return ctx + + +def _try_temporal_context() -> ( + Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]] +): + start_ctx = _temporal_start_operation_context.get(None) + cancel_ctx = _temporal_cancel_operation_context.get(None) + if start_ctx and cancel_ctx: + raise RuntimeError("Cannot be in both start and cancel operation contexts.") + return start_ctx or cancel_ctx @dataclass -class _TemporalNexusOperationContext: +class _TemporalStartOperationContext: """ - Context for a Nexus operation being handled by a Temporal Nexus Worker. + Context for a Nexus start operation being handled by a Temporal Nexus Worker. """ - info: Callable[[], Info] - """Information about the running Nexus operation.""" + nexus_context: StartOperationContext + """Nexus-specific start operation context.""" - nexus_operation_context: Union[StartOperationContext, CancelOperationContext] + info: Callable[[], Info] + """Temporal information about the running Nexus operation.""" client: temporalio.client.Client """The Temporal client in use by the worker handling this Nexus operation.""" @classmethod - def get(cls) -> _TemporalNexusOperationContext: - ctx = _temporal_operation_context.get(None) + def get(cls) -> _TemporalStartOperationContext: + ctx = _temporal_start_operation_context.get(None) if ctx is None: raise RuntimeError("Not in Nexus operation context.") return ctx - @property - def _temporal_start_operation_context( + def set(self) -> None: + _temporal_start_operation_context.set(self) + + def get_completion_callbacks( self, - ) -> Optional[_TemporalStartOperationContext]: - ctx = self.nexus_operation_context - if not isinstance(ctx, StartOperationContext): - return None - return _TemporalStartOperationContext(ctx) + ) -> list[temporalio.client.NexusCompletionCallback]: + ctx = self.nexus_context + return ( + [ + # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus + # request, it needs to copy the links to the callback in + # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links + # (for backwards compatibility). PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1945 + temporalio.client.NexusCompletionCallback( + url=ctx.callback_url, + header=ctx.callback_headers, + ) + ] + if ctx.callback_url + else [] + ) - @property - def _temporal_cancel_operation_context( + def get_workflow_event_links( self, - ) -> Optional[_TemporalCancelOperationContext]: - ctx = self.nexus_operation_context - if not isinstance(ctx, CancelOperationContext): - return None - return _TemporalCancelOperationContext(ctx) + ) -> list[temporalio.api.common.v1.Link.WorkflowEvent]: + event_links = [] + for inbound_link in self.nexus_context.inbound_links: + if link := _nexus_link_to_workflow_event(inbound_link): + event_links.append(link) + return event_links + + def add_outbound_links( + self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] + ): + try: + link = _workflow_event_to_nexus_link( + _workflow_handle_to_workflow_execution_started_event_link( + workflow_handle + ) + ) + except Exception as e: + logger.warning( + f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}" + ) + else: + self.nexus_context.outbound_links.append( + # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference + # link to send back to the caller. Now, it checks if the server returned + # the link in the StartWorkflowExecutionResponse, and if so, send the link + # from the response to the caller. Fallback to generating the link for + # backwards compatibility. PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1934 + link + ) + return workflow_handle @dataclass class WorkflowRunOperationContext: - start_operation_context: StartOperationContext + temporal_context: _TemporalStartOperationContext + + @property + def nexus_context(self) -> StartOperationContext: + return self.temporal_context.nexus_context + + @classmethod + def get(cls) -> WorkflowRunOperationContext: + return cls(_TemporalStartOperationContext.get()) # Overload for single-param workflow # TODO(nexus-prerelease): bring over other overloads @@ -164,14 +239,6 @@ async def start_workflow( Nexus caller is itself a workflow, this means that the workflow in the caller namespace web UI will contain links to the started workflow, and vice versa. """ - tctx = _TemporalNexusOperationContext.get() - start_operation_context = tctx._temporal_start_operation_context - if not start_operation_context: - raise RuntimeError( - "WorkflowRunOperationContext.start_workflow() must be called from " - "within a Nexus start operation context" - ) - # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { # internalOptions.onConflictOptions = { @@ -184,11 +251,11 @@ async def start_workflow( # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, # but these are deliberately not exposed in overloads, hence the type-check # violation. - wf_handle = await tctx.client.start_workflow( # type: ignore + wf_handle = await self.temporal_context.client.start_workflow( # type: ignore workflow=workflow, arg=arg, id=id, - task_queue=task_queue or tctx.info().task_queue, + task_queue=task_queue or self.temporal_context.info().task_queue, execution_timeout=execution_timeout, run_timeout=run_timeout, task_timeout=task_timeout, @@ -208,78 +275,40 @@ async def start_workflow( request_eager_start=request_eager_start, priority=priority, versioning_override=versioning_override, - nexus_completion_callbacks=start_operation_context.get_completion_callbacks(), - workflow_event_links=start_operation_context.get_workflow_event_links(), - request_id=start_operation_context.nexus_operation_context.request_id, + nexus_completion_callbacks=self.temporal_context.get_completion_callbacks(), + workflow_event_links=self.temporal_context.get_workflow_event_links(), + request_id=self.temporal_context.nexus_context.request_id, ) - start_operation_context.add_outbound_links(wf_handle) + self.temporal_context.add_outbound_links(wf_handle) return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) @dataclass -class _TemporalStartOperationContext: - nexus_operation_context: StartOperationContext +class _TemporalCancelOperationContext: + """ + Context for a Nexus cancel operation being handled by a Temporal Nexus Worker. + """ - def get_completion_callbacks( - self, - ) -> list[temporalio.client.NexusCompletionCallback]: - ctx = self.nexus_operation_context - return ( - [ - # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus - # request, it needs to copy the links to the callback in - # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links - # (for backwards compatibility). PR reference in Go SDK: - # https://github.com/temporalio/sdk-go/pull/1945 - temporalio.client.NexusCompletionCallback( - url=ctx.callback_url, - header=ctx.callback_headers, - ) - ] - if ctx.callback_url - else [] - ) + nexus_context: CancelOperationContext + """Nexus-specific cancel operation context.""" - def get_workflow_event_links( - self, - ) -> list[temporalio.api.common.v1.Link.WorkflowEvent]: - event_links = [] - for inbound_link in self.nexus_operation_context.inbound_links: - if link := _nexus_link_to_workflow_event(inbound_link): - event_links.append(link) - return event_links + info: Callable[[], Info] + """Temporal information about the running Nexus cancel operation.""" - def add_outbound_links( - self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] - ): - try: - link = _workflow_event_to_nexus_link( - _workflow_handle_to_workflow_execution_started_event_link( - workflow_handle - ) - ) - except Exception as e: - logger.warning( - f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}" - ) - else: - self.nexus_operation_context.outbound_links.append( - # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference - # link to send back to the caller. Now, it checks if the server returned - # the link in the StartWorkflowExecutionResponse, and if so, send the link - # from the response to the caller. Fallback to generating the link for - # backwards compatibility. PR reference in Go SDK: - # https://github.com/temporalio/sdk-go/pull/1934 - link - ) - return workflow_handle + client: temporalio.client.Client + """The Temporal client in use by the worker handling the current Nexus operation.""" + @classmethod + def get(cls) -> _TemporalCancelOperationContext: + ctx = _temporal_cancel_operation_context.get(None) + if ctx is None: + raise RuntimeError("Not in Nexus cancel operation context.") + return ctx -@dataclass -class _TemporalCancelOperationContext: - nexus_operation_context: CancelOperationContext + def set(self) -> None: + _temporal_cancel_operation_context.set(self) def _workflow_handle_to_workflow_execution_started_event_link( @@ -376,9 +405,9 @@ def process( self, msg: Any, kwargs: MutableMapping[str, Any] ) -> tuple[Any, MutableMapping[str, Any]]: extra = dict(self.extra or {}) - if tctx := _temporal_operation_context.get(None): - extra["service"] = tctx.nexus_operation_context.service - extra["operation"] = tctx.nexus_operation_context.operation + if tctx := _try_temporal_context(): + extra["service"] = tctx.nexus_context.service + extra["operation"] = tctx.nexus_context.operation extra["task_queue"] = tctx.info().task_queue kwargs["extra"] = extra | kwargs.get("extra", {}) return msg, kwargs diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index e9a2631fc..436e2e478 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -26,7 +26,7 @@ from temporalio import client from temporalio.nexus._operation_context import ( - _temporal_operation_context, + _temporal_start_operation_context, ) from temporalio.nexus._token import WorkflowHandle @@ -114,7 +114,7 @@ async def fetch_result( type=HandlerErrorType.NOT_FOUND, cause=err, ) - ctx = _temporal_operation_context.get() + ctx = _temporal_start_operation_context.get() try: client_handle = nexus_handle.to_workflow_handle( ctx.client, result_type=self._output_type @@ -148,7 +148,7 @@ async def cancel_operation( cause=err, ) - ctx = _temporal_operation_context.get() + ctx = _temporal_start_operation_context.get() try: client_workflow_handle = nexus_workflow_handle._to_client_workflow_handle( ctx.client diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 33239ce55..58a8b14f7 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -34,8 +34,8 @@ from temporalio.exceptions import ApplicationError from temporalio.nexus import ( Info, - _temporal_operation_context, - _TemporalNexusOperationContext, + _TemporalCancelOperationContext, + _TemporalStartOperationContext, logger, ) from temporalio.service import RPCError, RPCStatusCode @@ -175,13 +175,11 @@ async def _handle_cancel_operation_task( operation=request.operation, headers=headers, ) - _temporal_operation_context.set( - _TemporalNexusOperationContext( - info=lambda: Info(task_queue=self._task_queue), - nexus_operation_context=ctx, - client=self._client, - ) - ) + _TemporalCancelOperationContext( + info=lambda: Info(task_queue=self._task_queue), + nexus_context=ctx, + client=self._client, + ).set() try: await self._handler.cancel_operation(ctx, request.operation_token) except Exception as err: @@ -271,13 +269,11 @@ async def _start_operation( ], callback_headers=dict(start_request.callback_header), ) - _temporal_operation_context.set( - _TemporalNexusOperationContext( - nexus_operation_context=ctx, - client=self._client, - info=lambda: Info(task_queue=self._task_queue), - ) - ) + _TemporalStartOperationContext( + nexus_context=ctx, + client=self._client, + info=lambda: Info(task_queue=self._task_queue), + ).set() input = LazyValue( serializer=_DummyPayloadSerializer( data_converter=self._data_converter, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 7cdd10f1a..54d16abe3 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -268,12 +268,14 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): async def workflow_run_op_link_test( self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: - nctx = ctx.start_operation_context assert any( - link.url == "http://inbound-link/" for link in nctx.inbound_links + link.url == "http://inbound-link/" + for link in ctx.nexus_context.inbound_links ), "Inbound link not found" - assert nctx.request_id == "test-request-id-123", "Request ID mismatch" - nctx.outbound_links.extend(nctx.inbound_links) + assert ( + ctx.nexus_context.request_id == "test-request-id-123" + ), "Request ID mismatch" + ctx.nexus_context.outbound_links.extend(ctx.nexus_context.inbound_links) return await ctx.start_workflow( MyLinkTestWorkflow.run, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index cefe7b46a..e5dbd9f9c 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -157,11 +157,10 @@ async def start( if isinstance(input.response_type, SyncResponse): return StartOperationResultSync(value=OpOutput(value="sync response")) elif isinstance(input.response_type, AsyncResponse): - # TODO(nexus-preview): this is a hack; perhaps it should be should be called - # temporalio.nexus.StartOperationContext instead of - # WorkflowRunOperationContext. - tctx = WorkflowRunOperationContext(ctx) - handle = await tctx.start_workflow( + # TODO(nexus-preview): what do we want the DX to be for a user who is + # starting a Nexus backing workflow from a custom start method? (They may + # need to do this in order to customize the cancel method). + handle = await WorkflowRunOperationContext.get().start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 740615f3e..8c21c10d4 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -49,8 +49,7 @@ def __init__(self): async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: - tctx = WorkflowRunOperationContext(ctx) - handle = await tctx.start_workflow( + handle = await WorkflowRunOperationContext.get().start_workflow( EchoWorkflow.run, input.value, id=str(uuid.uuid4()), From b1b9ea37609fef9ab25171e636c6044187a514c9 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 10:22:55 -0400 Subject: [PATCH 112/183] Make Temporal context classes non-private --- temporalio/nexus/__init__.py | 4 ++-- temporalio/nexus/_decorators.py | 4 ++-- temporalio/nexus/_operation_context.py | 14 +++++++------- temporalio/worker/_nexus.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 5573df4a6..fe35f8e34 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -1,13 +1,13 @@ from ._decorators import workflow_run_operation as workflow_run_operation from ._operation_context import Info as Info from ._operation_context import ( - WorkflowRunOperationContext as WorkflowRunOperationContext, + TemporalStartOperationContext as TemporalStartOperationContext, ) from ._operation_context import ( _TemporalCancelOperationContext as _TemporalCancelOperationContext, ) from ._operation_context import ( - _TemporalStartOperationContext as _TemporalStartOperationContext, + WorkflowRunOperationContext as WorkflowRunOperationContext, ) from ._operation_context import client as client from ._operation_context import info as info diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 978b69dcc..020c75e30 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -17,8 +17,8 @@ ) from temporalio.nexus._operation_context import ( + TemporalStartOperationContext, WorkflowRunOperationContext, - _TemporalStartOperationContext, ) from temporalio.nexus._operation_handlers import ( WorkflowRunOperationHandler, @@ -115,7 +115,7 @@ def operation_handler_factory( async def _start( ctx: StartOperationContext, input: InputT ) -> WorkflowHandle[OutputT]: - tctx = _TemporalStartOperationContext.get() + tctx = TemporalStartOperationContext.get() return await start(self, WorkflowRunOperationContext(tctx), input) _start.__doc__ = start.__doc__ diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index fae56fc91..a93b1aa49 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -35,7 +35,7 @@ # CancelOperationContext and passes it as the first parameter to the nexusrpc operation # handler. In addition, it sets one of the following context vars. -_temporal_start_operation_context: ContextVar[_TemporalStartOperationContext] = ( +_temporal_start_operation_context: ContextVar[TemporalStartOperationContext] = ( ContextVar("temporal-start-operation-context") ) @@ -70,7 +70,7 @@ def client() -> temporalio.client.Client: def _temporal_context() -> ( - Union[_TemporalStartOperationContext, _TemporalCancelOperationContext] + Union[TemporalStartOperationContext, _TemporalCancelOperationContext] ): ctx = _try_temporal_context() if ctx is None: @@ -79,7 +79,7 @@ def _temporal_context() -> ( def _try_temporal_context() -> ( - Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]] + Optional[Union[TemporalStartOperationContext, _TemporalCancelOperationContext]] ): start_ctx = _temporal_start_operation_context.get(None) cancel_ctx = _temporal_cancel_operation_context.get(None) @@ -89,7 +89,7 @@ def _try_temporal_context() -> ( @dataclass -class _TemporalStartOperationContext: +class TemporalStartOperationContext: """ Context for a Nexus start operation being handled by a Temporal Nexus Worker. """ @@ -104,7 +104,7 @@ class _TemporalStartOperationContext: """The Temporal client in use by the worker handling this Nexus operation.""" @classmethod - def get(cls) -> _TemporalStartOperationContext: + def get(cls) -> TemporalStartOperationContext: ctx = _temporal_start_operation_context.get(None) if ctx is None: raise RuntimeError("Not in Nexus operation context.") @@ -170,7 +170,7 @@ def add_outbound_links( @dataclass class WorkflowRunOperationContext: - temporal_context: _TemporalStartOperationContext + temporal_context: TemporalStartOperationContext @property def nexus_context(self) -> StartOperationContext: @@ -178,7 +178,7 @@ def nexus_context(self) -> StartOperationContext: @classmethod def get(cls) -> WorkflowRunOperationContext: - return cls(_TemporalStartOperationContext.get()) + return cls(TemporalStartOperationContext.get()) # Overload for single-param workflow # TODO(nexus-prerelease): bring over other overloads diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 58a8b14f7..2d87b0ea6 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -34,8 +34,8 @@ from temporalio.exceptions import ApplicationError from temporalio.nexus import ( Info, + TemporalStartOperationContext, _TemporalCancelOperationContext, - _TemporalStartOperationContext, logger, ) from temporalio.service import RPCError, RPCStatusCode @@ -269,7 +269,7 @@ async def _start_operation( ], callback_headers=dict(start_request.callback_header), ) - _TemporalStartOperationContext( + TemporalStartOperationContext( nexus_context=ctx, client=self._client, info=lambda: Info(task_queue=self._task_queue), From 2ba2bc1ebc1e9c8d4edaa3d57e94cbc6540f13aa Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 10:24:35 -0400 Subject: [PATCH 113/183] Use TemporalStartOperationContext instead of WorkflowRunOperationContext --- temporalio/nexus/__init__.py | 3 --- temporalio/nexus/_decorators.py | 24 ++++++++---------- temporalio/nexus/_operation_context.py | 25 +++++-------------- temporalio/nexus/_operation_handlers.py | 4 +-- temporalio/nexus/_util.py | 10 ++++---- tests/nexus/test_handler.py | 10 ++++---- .../test_handler_interface_implementation.py | 4 +-- .../test_handler_operation_definitions.py | 8 +++--- tests/nexus/test_workflow_caller.py | 8 +++--- tests/nexus/test_workflow_run_operation.py | 4 +-- 10 files changed, 41 insertions(+), 59 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index fe35f8e34..2352da25f 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -6,9 +6,6 @@ from ._operation_context import ( _TemporalCancelOperationContext as _TemporalCancelOperationContext, ) -from ._operation_context import ( - WorkflowRunOperationContext as WorkflowRunOperationContext, -) from ._operation_context import client as client from ._operation_context import info as info from ._operation_context import logger as logger diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 020c75e30..cf59e46d6 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -18,7 +18,6 @@ from temporalio.nexus._operation_context import ( TemporalStartOperationContext, - WorkflowRunOperationContext, ) from temporalio.nexus._operation_handlers import ( WorkflowRunOperationHandler, @@ -37,11 +36,11 @@ @overload def workflow_run_operation( start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ... @@ -53,12 +52,12 @@ def workflow_run_operation( ) -> Callable[ [ Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ]: ... @@ -67,7 +66,7 @@ def workflow_run_operation( def workflow_run_operation( start: Optional[ Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ] = None, @@ -75,18 +74,18 @@ def workflow_run_operation( name: Optional[str] = None, ) -> Union[ Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], Callable[ [ Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ], @@ -97,11 +96,11 @@ def workflow_run_operation( def decorator( start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ( @@ -115,8 +114,7 @@ def operation_handler_factory( async def _start( ctx: StartOperationContext, input: InputT ) -> WorkflowHandle[OutputT]: - tctx = TemporalStartOperationContext.get() - return await start(self, WorkflowRunOperationContext(tctx), input) + return await start(self, TemporalStartOperationContext.get(), input) _start.__doc__ = start.__doc__ return WorkflowRunOperationHandler(_start, input_type, output_type) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index a93b1aa49..9e77ef257 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -167,19 +167,6 @@ def add_outbound_links( ) return workflow_handle - -@dataclass -class WorkflowRunOperationContext: - temporal_context: TemporalStartOperationContext - - @property - def nexus_context(self) -> StartOperationContext: - return self.temporal_context.nexus_context - - @classmethod - def get(cls) -> WorkflowRunOperationContext: - return cls(TemporalStartOperationContext.get()) - # Overload for single-param workflow # TODO(nexus-prerelease): bring over other overloads async def start_workflow( @@ -251,11 +238,11 @@ async def start_workflow( # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, # but these are deliberately not exposed in overloads, hence the type-check # violation. - wf_handle = await self.temporal_context.client.start_workflow( # type: ignore + wf_handle = await self.client.start_workflow( # type: ignore workflow=workflow, arg=arg, id=id, - task_queue=task_queue or self.temporal_context.info().task_queue, + task_queue=task_queue or self.info().task_queue, execution_timeout=execution_timeout, run_timeout=run_timeout, task_timeout=task_timeout, @@ -275,12 +262,12 @@ async def start_workflow( request_eager_start=request_eager_start, priority=priority, versioning_override=versioning_override, - nexus_completion_callbacks=self.temporal_context.get_completion_callbacks(), - workflow_event_links=self.temporal_context.get_workflow_event_links(), - request_id=self.temporal_context.nexus_context.request_id, + nexus_completion_callbacks=self.get_completion_callbacks(), + workflow_event_links=self.get_workflow_event_links(), + request_id=self.nexus_context.request_id, ) - self.temporal_context.add_outbound_links(wf_handle) + self.add_outbound_links(wf_handle) return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 436e2e478..d71eece35 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -41,7 +41,7 @@ class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]): Use this class to create an operation handler that starts a workflow by passing your ``start`` method to the constructor. Your ``start`` method must use - :py:func:`temporalio.nexus.WorkflowRunOperationContext.start_workflow` to start the + :py:func:`temporalio.nexus.TemporalStartOperationContext.start_workflow` to start the workflow. """ @@ -77,7 +77,7 @@ async def start( if isinstance(handle, client.WorkflowHandle): raise RuntimeError( f"Expected {handle} to be a nexus.WorkflowHandle, but got a client.WorkflowHandle. " - f"You must use WorkflowRunOperationContext.start_workflow " + f"You must use TemporalStartOperationContext.start_workflow " "to start a workflow that will deliver the result of the Nexus operation, " "not client.Client.start_workflow." ) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 8b24383ad..77de7b2b4 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -19,7 +19,7 @@ OutputT, ) -from temporalio.nexus._operation_context import WorkflowRunOperationContext +from temporalio.nexus._operation_context import TemporalStartOperationContext from ._token import ( WorkflowHandle as WorkflowHandle, @@ -30,7 +30,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> tuple[ @@ -70,7 +70,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( def _get_start_method_input_and_output_type_annotations( start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [ServiceHandlerT, TemporalStartOperationContext, InputT], Union[OutputT, Awaitable[OutputT]], ], ) -> tuple[ @@ -102,11 +102,11 @@ def _get_start_method_input_and_output_type_annotations( input_type = None else: ctx_type, input_type = type_annotations.values() - if not issubclass(ctx_type, WorkflowRunOperationContext): + if not issubclass(ctx_type, TemporalStartOperationContext): # TODO(preview): stacklevel warnings.warn( f"Expected first parameter of {start} to be an instance of " - f"WorkflowRunOperationContext, but is {ctx_type}." + f"TemporalStartOperationContext, but is {ctx_type}." ) input_type = None diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 54d16abe3..090759cb5 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -45,7 +45,7 @@ from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy from temporalio.exceptions import ApplicationError -from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ( @@ -208,7 +208,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: @workflow_run_operation async def workflow_run_operation_happy_path( - self, ctx: WorkflowRunOperationContext, input: Input + self, ctx: TemporalStartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: return await ctx.start_workflow( MyWorkflow.run, @@ -266,7 +266,7 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): @workflow_run_operation async def workflow_run_op_link_test( - self, ctx: WorkflowRunOperationContext, input: Input + self, ctx: TemporalStartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: assert any( link.url == "http://inbound-link/" @@ -1022,7 +1022,7 @@ async def run(self, input: Input) -> Output: class ServiceHandlerForRequestIdTest: @workflow_run_operation async def operation_backed_by_a_workflow( - self, ctx: WorkflowRunOperationContext, input: Input + self, ctx: TemporalStartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: return await ctx.start_workflow( EchoWorkflow.run, @@ -1033,7 +1033,7 @@ async def operation_backed_by_a_workflow( @workflow_run_operation async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( - self, ctx: WorkflowRunOperationContext, input: Input + self, ctx: TemporalStartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: await nexus.client().start_workflow( EchoWorkflow.run, diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index be98ff6d6..881d4da9f 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -6,7 +6,7 @@ from nexusrpc.handler import StartOperationContext, sync_operation from temporalio import nexus -from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation HTTP_PORT = 7243 @@ -37,7 +37,7 @@ class Interface: class Impl: @workflow_run_operation async def op( - self, ctx: WorkflowRunOperationContext, input: str + self, ctx: TemporalStartOperationContext, input: str ) -> nexus.WorkflowHandle[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index b0c1f2ac4..c734f97e4 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -10,7 +10,7 @@ import pytest from temporalio import nexus -from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation @dataclass @@ -34,7 +34,7 @@ class NotCalled(_TestCase): class Service: @workflow_run_operation async def my_workflow_run_operation_handler( - self, ctx: WorkflowRunOperationContext, input: Input + self, ctx: TemporalStartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { @@ -52,7 +52,7 @@ class CalledWithoutArgs(_TestCase): class Service: @workflow_run_operation async def my_workflow_run_operation_handler( - self, ctx: WorkflowRunOperationContext, input: Input + self, ctx: TemporalStartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = NotCalled.expected_operations @@ -63,7 +63,7 @@ class CalledWithNameOverride(_TestCase): class Service: @workflow_run_operation(name="operation-name") async def workflow_run_operation_with_name_override( - self, ctx: WorkflowRunOperationContext, input: Input + self, ctx: TemporalStartOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index e5dbd9f9c..3c38f14ec 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -38,7 +38,7 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -160,7 +160,7 @@ async def start( # TODO(nexus-preview): what do we want the DX to be for a user who is # starting a Nexus backing workflow from a custom start method? (They may # need to do this in order to customize the cancel method). - handle = await WorkflowRunOperationContext.get().start_workflow( + handle = await TemporalStartOperationContext.get().start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, @@ -206,7 +206,7 @@ async def sync_operation( @workflow_run_operation async def async_operation( - self, ctx: WorkflowRunOperationContext, input: OpInput + self, ctx: TemporalStartOperationContext, input: OpInput ) -> nexus.WorkflowHandle[HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: @@ -912,7 +912,7 @@ async def run(self, input: str) -> str: class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: @workflow_run_operation async def my_workflow_run_operation( - self, ctx: WorkflowRunOperationContext, input: None + self, ctx: TemporalStartOperationContext, input: None ) -> nexus.WorkflowHandle[str]: result_1 = await nexus.client().execute_workflow( EchoWorkflow.run, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 8c21c10d4..d1b1e33fd 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -13,7 +13,7 @@ from nexusrpc.handler._decorators import operation_handler from temporalio import workflow -from temporalio.nexus import WorkflowRunOperationContext +from temporalio.nexus import TemporalStartOperationContext from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -49,7 +49,7 @@ def __init__(self): async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: - handle = await WorkflowRunOperationContext.get().start_workflow( + handle = await TemporalStartOperationContext.get().start_workflow( EchoWorkflow.run, input.value, id=str(uuid.uuid4()), From f36c21534ec1751cfbf4f0f1062b0a765ece9d91 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 13:54:07 -0400 Subject: [PATCH 114/183] Revert "Use TemporalStartOperationContext instead of WorkflowRunOperationContext" This reverts commit 75d16b021b7bd79babc44e70feb99cdaf3918cec. --- temporalio/nexus/__init__.py | 3 +++ temporalio/nexus/_decorators.py | 24 ++++++++++-------- temporalio/nexus/_operation_context.py | 25 ++++++++++++++----- temporalio/nexus/_operation_handlers.py | 4 +-- temporalio/nexus/_util.py | 10 ++++---- tests/nexus/test_handler.py | 10 ++++---- .../test_handler_interface_implementation.py | 4 +-- .../test_handler_operation_definitions.py | 8 +++--- tests/nexus/test_workflow_caller.py | 8 +++--- tests/nexus/test_workflow_run_operation.py | 4 +-- 10 files changed, 59 insertions(+), 41 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 2352da25f..fe35f8e34 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -6,6 +6,9 @@ from ._operation_context import ( _TemporalCancelOperationContext as _TemporalCancelOperationContext, ) +from ._operation_context import ( + WorkflowRunOperationContext as WorkflowRunOperationContext, +) from ._operation_context import client as client from ._operation_context import info as info from ._operation_context import logger as logger diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index cf59e46d6..020c75e30 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -18,6 +18,7 @@ from temporalio.nexus._operation_context import ( TemporalStartOperationContext, + WorkflowRunOperationContext, ) from temporalio.nexus._operation_handlers import ( WorkflowRunOperationHandler, @@ -36,11 +37,11 @@ @overload def workflow_run_operation( start: Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ... @@ -52,12 +53,12 @@ def workflow_run_operation( ) -> Callable[ [ Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ]: ... @@ -66,7 +67,7 @@ def workflow_run_operation( def workflow_run_operation( start: Optional[ Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ] = None, @@ -74,18 +75,18 @@ def workflow_run_operation( name: Optional[str] = None, ) -> Union[ Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], Callable[ [ Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ], @@ -96,11 +97,11 @@ def workflow_run_operation( def decorator( start: Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ( @@ -114,7 +115,8 @@ def operation_handler_factory( async def _start( ctx: StartOperationContext, input: InputT ) -> WorkflowHandle[OutputT]: - return await start(self, TemporalStartOperationContext.get(), input) + tctx = TemporalStartOperationContext.get() + return await start(self, WorkflowRunOperationContext(tctx), input) _start.__doc__ = start.__doc__ return WorkflowRunOperationHandler(_start, input_type, output_type) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 9e77ef257..a93b1aa49 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -167,6 +167,19 @@ def add_outbound_links( ) return workflow_handle + +@dataclass +class WorkflowRunOperationContext: + temporal_context: TemporalStartOperationContext + + @property + def nexus_context(self) -> StartOperationContext: + return self.temporal_context.nexus_context + + @classmethod + def get(cls) -> WorkflowRunOperationContext: + return cls(TemporalStartOperationContext.get()) + # Overload for single-param workflow # TODO(nexus-prerelease): bring over other overloads async def start_workflow( @@ -238,11 +251,11 @@ async def start_workflow( # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, # but these are deliberately not exposed in overloads, hence the type-check # violation. - wf_handle = await self.client.start_workflow( # type: ignore + wf_handle = await self.temporal_context.client.start_workflow( # type: ignore workflow=workflow, arg=arg, id=id, - task_queue=task_queue or self.info().task_queue, + task_queue=task_queue or self.temporal_context.info().task_queue, execution_timeout=execution_timeout, run_timeout=run_timeout, task_timeout=task_timeout, @@ -262,12 +275,12 @@ async def start_workflow( request_eager_start=request_eager_start, priority=priority, versioning_override=versioning_override, - nexus_completion_callbacks=self.get_completion_callbacks(), - workflow_event_links=self.get_workflow_event_links(), - request_id=self.nexus_context.request_id, + nexus_completion_callbacks=self.temporal_context.get_completion_callbacks(), + workflow_event_links=self.temporal_context.get_workflow_event_links(), + request_id=self.temporal_context.nexus_context.request_id, ) - self.add_outbound_links(wf_handle) + self.temporal_context.add_outbound_links(wf_handle) return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index d71eece35..436e2e478 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -41,7 +41,7 @@ class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]): Use this class to create an operation handler that starts a workflow by passing your ``start`` method to the constructor. Your ``start`` method must use - :py:func:`temporalio.nexus.TemporalStartOperationContext.start_workflow` to start the + :py:func:`temporalio.nexus.WorkflowRunOperationContext.start_workflow` to start the workflow. """ @@ -77,7 +77,7 @@ async def start( if isinstance(handle, client.WorkflowHandle): raise RuntimeError( f"Expected {handle} to be a nexus.WorkflowHandle, but got a client.WorkflowHandle. " - f"You must use TemporalStartOperationContext.start_workflow " + f"You must use WorkflowRunOperationContext.start_workflow " "to start a workflow that will deliver the result of the Nexus operation, " "not client.Client.start_workflow." ) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 77de7b2b4..8b24383ad 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -19,7 +19,7 @@ OutputT, ) -from temporalio.nexus._operation_context import TemporalStartOperationContext +from temporalio.nexus._operation_context import WorkflowRunOperationContext from ._token import ( WorkflowHandle as WorkflowHandle, @@ -30,7 +30,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( start: Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> tuple[ @@ -70,7 +70,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations( def _get_start_method_input_and_output_type_annotations( start: Callable[ - [ServiceHandlerT, TemporalStartOperationContext, InputT], + [ServiceHandlerT, WorkflowRunOperationContext, InputT], Union[OutputT, Awaitable[OutputT]], ], ) -> tuple[ @@ -102,11 +102,11 @@ def _get_start_method_input_and_output_type_annotations( input_type = None else: ctx_type, input_type = type_annotations.values() - if not issubclass(ctx_type, TemporalStartOperationContext): + if not issubclass(ctx_type, WorkflowRunOperationContext): # TODO(preview): stacklevel warnings.warn( f"Expected first parameter of {start} to be an instance of " - f"TemporalStartOperationContext, but is {ctx_type}." + f"WorkflowRunOperationContext, but is {ctx_type}." ) input_type = None diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 090759cb5..54d16abe3 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -45,7 +45,7 @@ from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy from temporalio.exceptions import ApplicationError -from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ( @@ -208,7 +208,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: @workflow_run_operation async def workflow_run_operation_happy_path( - self, ctx: TemporalStartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: return await ctx.start_workflow( MyWorkflow.run, @@ -266,7 +266,7 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input): @workflow_run_operation async def workflow_run_op_link_test( - self, ctx: TemporalStartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: assert any( link.url == "http://inbound-link/" @@ -1022,7 +1022,7 @@ async def run(self, input: Input) -> Output: class ServiceHandlerForRequestIdTest: @workflow_run_operation async def operation_backed_by_a_workflow( - self, ctx: TemporalStartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: return await ctx.start_workflow( EchoWorkflow.run, @@ -1033,7 +1033,7 @@ async def operation_backed_by_a_workflow( @workflow_run_operation async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( - self, ctx: TemporalStartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: await nexus.client().start_workflow( EchoWorkflow.run, diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 881d4da9f..be98ff6d6 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -6,7 +6,7 @@ from nexusrpc.handler import StartOperationContext, sync_operation from temporalio import nexus -from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation HTTP_PORT = 7243 @@ -37,7 +37,7 @@ class Interface: class Impl: @workflow_run_operation async def op( - self, ctx: TemporalStartOperationContext, input: str + self, ctx: WorkflowRunOperationContext, input: str ) -> nexus.WorkflowHandle[int]: ... error_message = None diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index c734f97e4..b0c1f2ac4 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -10,7 +10,7 @@ import pytest from temporalio import nexus -from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation @dataclass @@ -34,7 +34,7 @@ class NotCalled(_TestCase): class Service: @workflow_run_operation async def my_workflow_run_operation_handler( - self, ctx: TemporalStartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { @@ -52,7 +52,7 @@ class CalledWithoutArgs(_TestCase): class Service: @workflow_run_operation async def my_workflow_run_operation_handler( - self, ctx: TemporalStartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = NotCalled.expected_operations @@ -63,7 +63,7 @@ class CalledWithNameOverride(_TestCase): class Service: @workflow_run_operation(name="operation-name") async def workflow_run_operation_with_name_override( - self, ctx: TemporalStartOperationContext, input: Input + self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: ... expected_operations = { diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3c38f14ec..e5dbd9f9c 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -38,7 +38,7 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError -from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -160,7 +160,7 @@ async def start( # TODO(nexus-preview): what do we want the DX to be for a user who is # starting a Nexus backing workflow from a custom start method? (They may # need to do this in order to customize the cancel method). - handle = await TemporalStartOperationContext.get().start_workflow( + handle = await WorkflowRunOperationContext.get().start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, @@ -206,7 +206,7 @@ async def sync_operation( @workflow_run_operation async def async_operation( - self, ctx: TemporalStartOperationContext, input: OpInput + self, ctx: WorkflowRunOperationContext, input: OpInput ) -> nexus.WorkflowHandle[HandlerWfOutput]: assert isinstance(input.response_type, AsyncResponse) if input.response_type.exception_in_operation_start: @@ -912,7 +912,7 @@ async def run(self, input: str) -> str: class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: @workflow_run_operation async def my_workflow_run_operation( - self, ctx: TemporalStartOperationContext, input: None + self, ctx: WorkflowRunOperationContext, input: None ) -> nexus.WorkflowHandle[str]: result_1 = await nexus.client().execute_workflow( EchoWorkflow.run, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index d1b1e33fd..8c21c10d4 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -13,7 +13,7 @@ from nexusrpc.handler._decorators import operation_handler from temporalio import workflow -from temporalio.nexus import TemporalStartOperationContext +from temporalio.nexus import WorkflowRunOperationContext from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -49,7 +49,7 @@ def __init__(self): async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: - handle = await TemporalStartOperationContext.get().start_workflow( + handle = await WorkflowRunOperationContext.get().start_workflow( EchoWorkflow.run, input.value, id=str(uuid.uuid4()), From 38c1c5758ac3a55f6e95e32699b7d32149071c52 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 14:19:50 -0400 Subject: [PATCH 115/183] Make WorkflowRunOperationContext subclass StartOperationContext --- temporalio/nexus/_decorators.py | 8 +++++--- temporalio/nexus/_operation_context.py | 22 +++++++++++++++++----- tests/nexus/test_workflow_caller.py | 3 ++- tests/nexus/test_workflow_run_operation.py | 3 ++- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 020c75e30..b40fb7634 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -17,7 +17,6 @@ ) from temporalio.nexus._operation_context import ( - TemporalStartOperationContext, WorkflowRunOperationContext, ) from temporalio.nexus._operation_handlers import ( @@ -115,8 +114,11 @@ def operation_handler_factory( async def _start( ctx: StartOperationContext, input: InputT ) -> WorkflowHandle[OutputT]: - tctx = TemporalStartOperationContext.get() - return await start(self, WorkflowRunOperationContext(tctx), input) + return await start( + self, + WorkflowRunOperationContext.from_start_operation_context(ctx), + input, + ) _start.__doc__ = start.__doc__ return WorkflowRunOperationHandler(_start, input_type, output_type) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index a93b1aa49..b67e65b41 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses import logging import re import urllib.parse @@ -168,17 +169,28 @@ def add_outbound_links( return workflow_handle -@dataclass -class WorkflowRunOperationContext: - temporal_context: TemporalStartOperationContext +@dataclass(frozen=True) +class WorkflowRunOperationContext(StartOperationContext): + _temporal_context: Optional[TemporalStartOperationContext] = None + + @property + def temporal_context(self) -> TemporalStartOperationContext: + if not self._temporal_context: + raise RuntimeError("Temporal context not set") + return self._temporal_context @property def nexus_context(self) -> StartOperationContext: return self.temporal_context.nexus_context @classmethod - def get(cls) -> WorkflowRunOperationContext: - return cls(TemporalStartOperationContext.get()) + def from_start_operation_context( + cls, ctx: StartOperationContext + ) -> WorkflowRunOperationContext: + return cls( + _temporal_context=TemporalStartOperationContext.get(), + **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, + ) # Overload for single-param workflow # TODO(nexus-prerelease): bring over other overloads diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index e5dbd9f9c..7ca4d004b 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -160,7 +160,8 @@ async def start( # TODO(nexus-preview): what do we want the DX to be for a user who is # starting a Nexus backing workflow from a custom start method? (They may # need to do this in order to customize the cancel method). - handle = await WorkflowRunOperationContext.get().start_workflow( + tctx = WorkflowRunOperationContext.from_start_operation_context(ctx) + handle = await tctx.start_workflow( HandlerWorkflow.run, HandlerWfInput(op_input=input), id=input.response_type.operation_workflow_id, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 8c21c10d4..9faa39b3e 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -49,7 +49,8 @@ def __init__(self): async def start( self, ctx: StartOperationContext, input: Input ) -> StartOperationResultAsync: - handle = await WorkflowRunOperationContext.get().start_workflow( + tctx = WorkflowRunOperationContext.from_start_operation_context(ctx) + handle = await tctx.start_workflow( EchoWorkflow.run, input.value, id=str(uuid.uuid4()), From 3320f28d7af2cd66f67526ed58728639f9c8b748 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 14:22:14 -0400 Subject: [PATCH 116/183] Mark TemporalStartOperationContext as private --- temporalio/nexus/__init__.py | 4 ++-- temporalio/nexus/_operation_context.py | 16 ++++++++-------- temporalio/worker/_nexus.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index fe35f8e34..5573df4a6 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -1,13 +1,13 @@ from ._decorators import workflow_run_operation as workflow_run_operation from ._operation_context import Info as Info from ._operation_context import ( - TemporalStartOperationContext as TemporalStartOperationContext, + WorkflowRunOperationContext as WorkflowRunOperationContext, ) from ._operation_context import ( _TemporalCancelOperationContext as _TemporalCancelOperationContext, ) from ._operation_context import ( - WorkflowRunOperationContext as WorkflowRunOperationContext, + _TemporalStartOperationContext as _TemporalStartOperationContext, ) from ._operation_context import client as client from ._operation_context import info as info diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index b67e65b41..79afd5a91 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -36,7 +36,7 @@ # CancelOperationContext and passes it as the first parameter to the nexusrpc operation # handler. In addition, it sets one of the following context vars. -_temporal_start_operation_context: ContextVar[TemporalStartOperationContext] = ( +_temporal_start_operation_context: ContextVar[_TemporalStartOperationContext] = ( ContextVar("temporal-start-operation-context") ) @@ -71,7 +71,7 @@ def client() -> temporalio.client.Client: def _temporal_context() -> ( - Union[TemporalStartOperationContext, _TemporalCancelOperationContext] + Union[_TemporalStartOperationContext, _TemporalCancelOperationContext] ): ctx = _try_temporal_context() if ctx is None: @@ -80,7 +80,7 @@ def _temporal_context() -> ( def _try_temporal_context() -> ( - Optional[Union[TemporalStartOperationContext, _TemporalCancelOperationContext]] + Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]] ): start_ctx = _temporal_start_operation_context.get(None) cancel_ctx = _temporal_cancel_operation_context.get(None) @@ -90,7 +90,7 @@ def _try_temporal_context() -> ( @dataclass -class TemporalStartOperationContext: +class _TemporalStartOperationContext: """ Context for a Nexus start operation being handled by a Temporal Nexus Worker. """ @@ -105,7 +105,7 @@ class TemporalStartOperationContext: """The Temporal client in use by the worker handling this Nexus operation.""" @classmethod - def get(cls) -> TemporalStartOperationContext: + def get(cls) -> _TemporalStartOperationContext: ctx = _temporal_start_operation_context.get(None) if ctx is None: raise RuntimeError("Not in Nexus operation context.") @@ -171,10 +171,10 @@ def add_outbound_links( @dataclass(frozen=True) class WorkflowRunOperationContext(StartOperationContext): - _temporal_context: Optional[TemporalStartOperationContext] = None + _temporal_context: Optional[_TemporalStartOperationContext] = None @property - def temporal_context(self) -> TemporalStartOperationContext: + def temporal_context(self) -> _TemporalStartOperationContext: if not self._temporal_context: raise RuntimeError("Temporal context not set") return self._temporal_context @@ -188,7 +188,7 @@ def from_start_operation_context( cls, ctx: StartOperationContext ) -> WorkflowRunOperationContext: return cls( - _temporal_context=TemporalStartOperationContext.get(), + _temporal_context=_TemporalStartOperationContext.get(), **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, ) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 2d87b0ea6..58a8b14f7 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -34,8 +34,8 @@ from temporalio.exceptions import ApplicationError from temporalio.nexus import ( Info, - TemporalStartOperationContext, _TemporalCancelOperationContext, + _TemporalStartOperationContext, logger, ) from temporalio.service import RPCError, RPCStatusCode @@ -269,7 +269,7 @@ async def _start_operation( ], callback_headers=dict(start_request.callback_header), ) - TemporalStartOperationContext( + _TemporalStartOperationContext( nexus_context=ctx, client=self._client, info=lambda: Info(task_queue=self._task_queue), From 5e563c0d2be03b66a0d2fd1f95a6c9c4a4528c50 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 15:17:32 -0400 Subject: [PATCH 117/183] Handle OperationError consistently with HandlerError --- temporalio/worker/_nexus.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 58a8b14f7..661ea0892 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -337,14 +337,9 @@ async def _operation_error_to_proto( self, err: nexusrpc.OperationError, ) -> temporalio.api.nexus.v1.UnsuccessfulOperationError: - # TODO(nexus-prerelease): why are we accessing __cause__ here for OperationError - # and not for HandlerError? - cause = err.__cause__ - if cause is None: - cause = Exception(*err.args).with_traceback(err.__traceback__) return temporalio.api.nexus.v1.UnsuccessfulOperationError( operation_state=err.state.value, - failure=await self._exception_to_failure_proto(cause), + failure=await self._exception_to_failure_proto(err), ) async def _handler_error_to_proto( From a6e9777394f212b23e0c201b9ca8ce676eeaa752 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 16:08:42 -0400 Subject: [PATCH 118/183] RTU: operation_id -> operation_token --- temporalio/worker/_workflow_instance.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 0ab367459..db24de6b7 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -856,11 +856,10 @@ def _apply_resolve_nexus_operation_start( raise RuntimeError( f"Failed to find nexus operation handle for job sequence number {job.seq}" ) - # TODO(dan): change core protos to use operation_token instead of operation_id - if job.HasField("operation_id"): + if job.HasField("operation_token"): # The Nexus operation started asynchronously. A `ResolveNexusOperation` job # will follow in a future activation. - handle._resolve_start_success(job.operation_id) + handle._resolve_start_success(job.operation_token) elif job.HasField("started_sync"): # The Nexus operation 'started' in the sense that it's already resolved. A # `ResolveNexusOperation` job will be in the same activation. From 72d14dfd86e1ed400d4b8c54ea93d4bdb5998104 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Jun 2025 16:43:28 -0400 Subject: [PATCH 119/183] Fix cancellation context bug --- temporalio/nexus/_operation_handlers.py | 46 ++++++++++++------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 436e2e478..5a1335b59 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -26,7 +26,7 @@ from temporalio import client from temporalio.nexus._operation_context import ( - _temporal_start_operation_context, + _temporal_cancel_operation_context, ) from temporalio.nexus._token import WorkflowHandle @@ -105,27 +105,27 @@ async def fetch_result( ) # An implementation is provided for future reference: # TODO: honor `wait` param and Request-Timeout header - try: - nexus_handle = WorkflowHandle[OutputT].from_token(token) - except Exception as err: - raise HandlerError( - "Failed to decode operation token as workflow operation token. " - "Fetching result for non-workflow operations is not supported.", - type=HandlerErrorType.NOT_FOUND, - cause=err, - ) - ctx = _temporal_start_operation_context.get() - try: - client_handle = nexus_handle.to_workflow_handle( - ctx.client, result_type=self._output_type - ) - except Exception as err: - raise HandlerError( - "Failed to construct workflow handle from workflow operation token", - type=HandlerErrorType.NOT_FOUND, - cause=err, - ) - return await client_handle.result() + # try: + # nexus_handle = WorkflowHandle[OutputT].from_token(token) + # except Exception as err: + # raise HandlerError( + # "Failed to decode operation token as workflow operation token. " + # "Fetching result for non-workflow operations is not supported.", + # type=HandlerErrorType.NOT_FOUND, + # cause=err, + # ) + # ctx = _temporal_fetch_operation_context.get() + # try: + # client_handle = nexus_handle.to_workflow_handle( + # ctx.client, result_type=self._output_type + # ) + # except Exception as err: + # raise HandlerError( + # "Failed to construct workflow handle from workflow operation token", + # type=HandlerErrorType.NOT_FOUND, + # cause=err, + # ) + # return await client_handle.result() async def cancel_operation( @@ -148,7 +148,7 @@ async def cancel_operation( cause=err, ) - ctx = _temporal_start_operation_context.get() + ctx = _temporal_cancel_operation_context.get() try: client_workflow_handle = nexus_workflow_handle._to_client_workflow_handle( ctx.client From db0973313a69279a65f3b5c6e1f974bb2b439d4a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 28 Jun 2025 13:14:30 -0400 Subject: [PATCH 120/183] RTU: Use nexusrpc.get_service_definition --- temporalio/workflow.py | 3 +-- .../test_dynamic_creation_of_user_handler_classes.py | 3 ++- tests/nexus/test_handler_operation_definitions.py | 8 +++----- tests/nexus/test_workflow_run_operation.py | 4 +++- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index fd18f0128..4df127665 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5170,8 +5170,7 @@ def __init__( # class. if isinstance(service, str): self._service_name = service - # TODO(preview): make double-underscore attrs private to nexusrpc and expose getters/setters - elif service_defn := getattr(service, "__nexus_service__", None): + elif service_defn := nexusrpc.get_service_definition(service): self._service_name = service_defn.name else: raise ValueError( diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index b15257a45..96c3e711b 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -66,7 +66,8 @@ async def test_dynamic_creation_of_user_handler_classes(client: Client): ) ) - service_name = service_cls.__nexus_service__.name + assert (service_defn := nexusrpc.get_service_definition(service_cls)) + service_name = service_defn.name endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id async with Worker( diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index b0c1f2ac4..bcca554c7 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -88,11 +88,9 @@ async def workflow_run_operation_with_name_override( async def test_collected_operation_names( test_case: Type[_TestCase], ): - service: nexusrpc.ServiceDefinition = getattr( - test_case.Service, "__nexus_service__" - ) - assert isinstance(service, nexusrpc.ServiceDefinition) - assert service.name == "Service" + service_defn = nexusrpc.get_service_definition(test_case.Service) + assert isinstance(service_defn, nexusrpc.ServiceDefinition) + assert service_defn.name == "Service" for method_name, expected_op in test_case.expected_operations.items(): _, actual_op = nexusrpc.handler.get_operation_factory( getattr(test_case.Service, method_name) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 9faa39b3e..8eefa9ac3 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Any, Type +import nexusrpc import pytest from nexusrpc import Operation, service from nexusrpc.handler import ( @@ -103,10 +104,11 @@ async def test_workflow_run_operation( ): task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + assert (service_defn := nexusrpc.get_service_definition(service_handler_cls)) service_client = ServiceClient( server_address=server_address(env), endpoint=endpoint, - service=service_handler_cls.__nexus_service__.name, + service=service_defn.name, ) async with Worker( env.client, From e721f5565a830636891c8560481ec1366b114d7e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 28 Jun 2025 20:37:50 -0400 Subject: [PATCH 121/183] Docstring --- temporalio/worker/_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 7d30b3511..80b70a055 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -162,8 +162,8 @@ def __init__( activities: Activity callables decorated with :py:func:`@activity.defn`. Activities may be async functions or non-async functions. - nexus_service_handlers: Nexus service handler instances decorated with - :py:func:`@nexusrpc.handler.service_handler`. + nexus_service_handlers: Instances of Nexus service handler classes + decorated with :py:func:`@nexusrpc.handler.service_handler`. workflows: Workflow classes decorated with :py:func:`@workflow.defn`. activity_executor: Concurrent executor to use for non-async From ca5f57286f11b7159599c5b75c56db0369e63329 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 28 Jun 2025 21:23:42 -0400 Subject: [PATCH 122/183] RTU get_operation_factory --- temporalio/worker/_interceptor.py | 2 +- tests/nexus/test_dynamic_creation_of_user_handler_classes.py | 2 +- tests/nexus/test_handler_operation_definitions.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 2b20dcb46..5ae9f382c 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -313,7 +313,7 @@ def __post_init__(self) -> None: self._operation_name = self.operation self._input_type = None elif isinstance(self.operation, Callable): - _, op = nexusrpc.handler.get_operation_factory(self.operation) + _, op = nexusrpc.get_operation_factory(self.operation) if isinstance(op, nexusrpc.Operation): self._operation_name = op.name self._input_type = op.input_type diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 96c3e711b..c63e6931d 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -3,8 +3,8 @@ import httpx import nexusrpc.handler import pytest +from nexusrpc import get_operation_factory from nexusrpc.handler import sync_operation -from nexusrpc.handler._util import get_operation_factory from temporalio.client import Client from temporalio.worker import Worker diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index bcca554c7..c3f812d05 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -92,7 +92,7 @@ async def test_collected_operation_names( assert isinstance(service_defn, nexusrpc.ServiceDefinition) assert service_defn.name == "Service" for method_name, expected_op in test_case.expected_operations.items(): - _, actual_op = nexusrpc.handler.get_operation_factory( + _, actual_op = nexusrpc.get_operation_factory( getattr(test_case.Service, method_name) ) assert isinstance(actual_op, nexusrpc.Operation) From e7091aca9336a8e34c10b95c3598da1a711b4a56 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 28 Jun 2025 13:14:37 -0400 Subject: [PATCH 123/183] Workflow OperationError / HandlerError test --- tests/nexus/test_workflow_caller.py | 94 ++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 1 deletion(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 7ca4d004b..925e470f6 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass from enum import IntEnum -from typing import Any, Callable, Union +from typing import Any, Callable, Literal, Union import nexusrpc import nexusrpc.handler @@ -28,6 +28,7 @@ import temporalio.api.nexus.v1 import temporalio.api.operatorservice import temporalio.api.operatorservice.v1 +import temporalio.exceptions from temporalio import nexus, workflow from temporalio.client import ( Client, @@ -1082,3 +1083,94 @@ async def assert_handler_workflow_has_link_to_caller_workflow( # f"{self._result_fut} " # f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" # ) + + +# Handler + +ActionInSyncOp = Literal["raise_handler_error", "raise_operation_error"] + + +@dataclass +class ErrorTestInput: + task_queue: str + action_in_sync_op: ActionInSyncOp + + +@nexusrpc.handler.service_handler +class ErrorTestService: + @sync_operation + async def op(self, ctx: StartOperationContext, input: ErrorTestInput) -> None: + if input.action_in_sync_op == "raise_handler_error": + raise nexusrpc.handler.HandlerError( + "test", + type=nexusrpc.handler.HandlerErrorType.INTERNAL, + ) + elif input.action_in_sync_op == "raise_operation_error": + raise nexusrpc.OperationError( + "test", state=nexusrpc.OperationErrorState.FAILED + ) + else: + raise NotImplementedError( + f"Unhandled action_in_sync_op: {input.action_in_sync_op}" + ) + + +# Caller + + +@workflow.defn(sandboxed=False) +class ErrorTestCallerWorkflow: + @workflow.init + def __init__(self, input: ErrorTestInput): + self.nexus_client = workflow.NexusClient( + service=ErrorTestService, + endpoint=make_nexus_endpoint_name(input.task_queue), + ) + + @workflow.run + async def run(self, input: ErrorTestInput) -> list[str]: + try: + await self.nexus_client.execute_operation( + # TODO(nexus-preview): why wasn't this a type error? + # ErrorTestService.op, ErrorTestCallerWfInput() + ErrorTestService.op, + # TODO(nexus-preview): why wasn't this a type error? + # None + input, + ) + except Exception as err: + return [str(type(err).__name__), str(type(err.__cause__).__name__)] + assert False, "Unreachable" + + +@pytest.mark.parametrize( + "action_in_sync_op", ["raise_handler_error", "raise_operation_error"] +) +async def test_errors_raised_by_nexus_operation( + client: Client, action_in_sync_op: ActionInSyncOp +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[ErrorTestService()], + workflows=[ErrorTestCallerWorkflow], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + result = await client.execute_workflow( + ErrorTestCallerWorkflow.run, + ErrorTestInput( + task_queue=task_queue, + action_in_sync_op=action_in_sync_op, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + if action_in_sync_op == "raise_handler_error": + assert result == ["NexusOperationError", "NexusHandlerError"] + elif action_in_sync_op == "raise_operation_error": + assert result == ["NexusOperationError", "ApplicationError"] + else: + raise NotImplementedError( + f"Unhandled action_in_sync_op: {action_in_sync_op}" + ) From a9bac66287fbccde67e50852647265a4c60abdfc Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 28 Jun 2025 22:50:42 -0400 Subject: [PATCH 124/183] Convert nexus_handler_failure_info as nexusrpc.HandlerError --- temporalio/converter.py | 15 +++++++++++++-- temporalio/exceptions.py | 16 ---------------- tests/nexus/test_workflow_caller.py | 10 +++++----- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index b976eca08..43dbe305b 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -16,6 +16,7 @@ from datetime import datetime from enum import IntEnum from itertools import zip_longest +from logging import getLogger from typing import ( Any, Awaitable, @@ -40,6 +41,7 @@ import google.protobuf.json_format import google.protobuf.message import google.protobuf.symbol_database +import nexusrpc import typing_extensions import temporalio.api.common.v1 @@ -60,6 +62,8 @@ if sys.version_info >= (3, 10): from types import UnionType +logger = getLogger(__name__) + class PayloadConverter(ABC): """Base payload converter to/from multiple payloads/values.""" @@ -1014,9 +1018,16 @@ def from_failure( ) elif failure.HasField("nexus_handler_failure_info"): nexus_handler_failure_info = failure.nexus_handler_failure_info - err = temporalio.exceptions.NexusHandlerError( + try: + _type = nexusrpc.HandlerErrorType[nexus_handler_failure_info.type] + except KeyError: + logger.warning( + f"Unknown Nexus HandlerErrorType: {nexus_handler_failure_info.type}" + ) + _type = nexusrpc.HandlerErrorType.INTERNAL + return nexusrpc.HandlerError( failure.message or "Nexus handler error", - type=nexus_handler_failure_info.type, + type=_type, retryable={ temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: True, temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: False, diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index e687482f6..0a1cd9a1d 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -362,22 +362,6 @@ def retry_state(self) -> Optional[RetryState]: return self._retry_state -class NexusHandlerError(FailureError): - """Error raised on Nexus handler failure.""" - - def __init__( - self, - message: str, - *, - type: str, - retryable: Optional[bool] = None, - ): - """Initialize a Nexus handler error.""" - super().__init__(message) - self._type = type - self._retryable = retryable - - class NexusOperationError(FailureError): """Error raised on Nexus operation failure.""" diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 925e470f6..b5fff54f8 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -38,7 +38,7 @@ WorkflowHandle, ) from temporalio.common import WorkflowIDConflictPolicy -from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError +from temporalio.exceptions import CancelledError, NexusOperationError from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker @@ -479,7 +479,7 @@ async def test_sync_response( e = ei.value assert isinstance(e, WorkflowFailureError) assert isinstance(e.__cause__, NexusOperationError) - assert isinstance(e.__cause__.__cause__, NexusHandlerError) + assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) # ID of first command assert e.__cause__.scheduled_event_id == 5 assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) @@ -532,7 +532,7 @@ async def test_async_response( e = ei.value assert isinstance(e, WorkflowFailureError) assert isinstance(e.__cause__, NexusOperationError) - assert isinstance(e.__cause__.__cause__, NexusHandlerError) + assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) # ID of first command after update accepted assert e.__cause__.scheduled_event_id == 6 assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) @@ -709,7 +709,7 @@ async def test_untyped_caller( e = ei.value assert isinstance(e, WorkflowFailureError) assert isinstance(e.__cause__, NexusOperationError) - assert isinstance(e.__cause__.__cause__, NexusHandlerError) + assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) else: result = await caller_wf_handle.result() assert result.op_output.value == ( @@ -1167,7 +1167,7 @@ async def test_errors_raised_by_nexus_operation( task_queue=task_queue, ) if action_in_sync_op == "raise_handler_error": - assert result == ["NexusOperationError", "NexusHandlerError"] + assert result == ["NexusOperationError", "HandlerError"] elif action_in_sync_op == "raise_operation_error": assert result == ["NexusOperationError", "ApplicationError"] else: From bf2a02d11e0bbbeee55c4f806bfd8706afafcafe Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 28 Jun 2025 22:43:54 -0400 Subject: [PATCH 125/183] RTU: Move HandlerError to root module --- temporalio/nexus/_operation_handlers.py | 4 +- temporalio/worker/_nexus.py | 54 ++++++++++++------------- tests/nexus/test_handler.py | 10 +++-- tests/nexus/test_workflow_caller.py | 4 +- 4 files changed, 38 insertions(+), 34 deletions(-) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 5a1335b59..449e7ecb3 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -9,6 +9,8 @@ ) from nexusrpc import ( + HandlerError, + HandlerErrorType, InputT, OperationInfo, OutputT, @@ -17,8 +19,6 @@ CancelOperationContext, FetchOperationInfoContext, FetchOperationResultContext, - HandlerError, - HandlerErrorType, OperationHandler, StartOperationContext, StartOperationResultAsync, diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 661ea0892..59d3fc25f 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -343,7 +343,7 @@ async def _operation_error_to_proto( ) async def _handler_error_to_proto( - self, err: nexusrpc.handler.HandlerError + self, err: nexusrpc.HandlerError ) -> temporalio.api.nexus.v1.HandlerError: return temporalio.api.nexus.v1.HandlerError( error_type=err.type.value, @@ -378,33 +378,33 @@ async def deserialize( ) return input except Exception as err: - raise nexusrpc.handler.HandlerError( + raise nexusrpc.HandlerError( "Data converter failed to decode Nexus operation input", - type=nexusrpc.handler.HandlerErrorType.BAD_REQUEST, + type=nexusrpc.HandlerErrorType.BAD_REQUEST, cause=err, retryable=False, ) from err # TODO(nexus-prerelease): tests for this function -def _exception_to_handler_error(err: BaseException) -> nexusrpc.handler.HandlerError: +def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: # Based on sdk-typescript's convertKnownErrors: # https://github.com/temporalio/sdk-typescript/blob/nexus/packages/worker/src/nexus.ts - if isinstance(err, nexusrpc.handler.HandlerError): + if isinstance(err, nexusrpc.HandlerError): return err elif isinstance(err, ApplicationError): - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( # TODO(nexus-prerelease): what should message be? err.message, - type=nexusrpc.handler.HandlerErrorType.INTERNAL, + type=nexusrpc.HandlerErrorType.INTERNAL, cause=err, retryable=not err.non_retryable, ) elif isinstance(err, RPCError): if err.status == RPCStatusCode.INVALID_ARGUMENT: - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( err.message, - type=nexusrpc.handler.HandlerErrorType.BAD_REQUEST, + type=nexusrpc.HandlerErrorType.BAD_REQUEST, cause=err, ) elif err.status in [ @@ -412,16 +412,16 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.handler.HandlerE RPCStatusCode.FAILED_PRECONDITION, RPCStatusCode.OUT_OF_RANGE, ]: - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( err.message, - type=nexusrpc.handler.HandlerErrorType.INTERNAL, + type=nexusrpc.HandlerErrorType.INTERNAL, cause=err, retryable=False, ) elif err.status in [RPCStatusCode.ABORTED, RPCStatusCode.UNAVAILABLE]: - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( err.message, - type=nexusrpc.handler.HandlerErrorType.UNAVAILABLE, + type=nexusrpc.HandlerErrorType.UNAVAILABLE, cause=err, ) elif err.status in [ @@ -436,37 +436,37 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.handler.HandlerE # we convert to internal because this is not a client auth error and happens # when the handler fails to auth with Temporal and should be considered # retryable. - return nexusrpc.handler.HandlerError( - err.message, type=nexusrpc.handler.HandlerErrorType.INTERNAL, cause=err + return nexusrpc.HandlerError( + err.message, type=nexusrpc.HandlerErrorType.INTERNAL, cause=err ) elif err.status == RPCStatusCode.NOT_FOUND: - return nexusrpc.handler.HandlerError( - err.message, type=nexusrpc.handler.HandlerErrorType.NOT_FOUND, cause=err + return nexusrpc.HandlerError( + err.message, type=nexusrpc.HandlerErrorType.NOT_FOUND, cause=err ) elif err.status == RPCStatusCode.RESOURCE_EXHAUSTED: - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( err.message, - type=nexusrpc.handler.HandlerErrorType.RESOURCE_EXHAUSTED, + type=nexusrpc.HandlerErrorType.RESOURCE_EXHAUSTED, cause=err, ) elif err.status == RPCStatusCode.UNIMPLEMENTED: - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( err.message, - type=nexusrpc.handler.HandlerErrorType.NOT_IMPLEMENTED, + type=nexusrpc.HandlerErrorType.NOT_IMPLEMENTED, cause=err, ) elif err.status == RPCStatusCode.DEADLINE_EXCEEDED: - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( err.message, - type=nexusrpc.handler.HandlerErrorType.UPSTREAM_TIMEOUT, + type=nexusrpc.HandlerErrorType.UPSTREAM_TIMEOUT, cause=err, ) else: - return nexusrpc.handler.HandlerError( + return nexusrpc.HandlerError( f"Unhandled RPC error status: {err.status}", - type=nexusrpc.handler.HandlerErrorType.INTERNAL, + type=nexusrpc.HandlerErrorType.INTERNAL, cause=err, ) - return nexusrpc.handler.HandlerError( - str(err), type=nexusrpc.handler.HandlerErrorType.INTERNAL, cause=err + return nexusrpc.HandlerError( + str(err), type=nexusrpc.HandlerErrorType.INTERNAL, cause=err ) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 54d16abe3..c5f6c39fa 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -27,13 +27,17 @@ import httpx import nexusrpc import pytest -from nexusrpc import OperationError, OperationErrorState, OperationInfo +from nexusrpc import ( + HandlerError, + HandlerErrorType, + OperationError, + OperationErrorState, + OperationInfo, +) from nexusrpc.handler import ( CancelOperationContext, FetchOperationInfoContext, FetchOperationResultContext, - HandlerError, - HandlerErrorType, OperationHandler, StartOperationContext, service_handler, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index b5fff54f8..f2eb5a2d9 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1101,9 +1101,9 @@ class ErrorTestService: @sync_operation async def op(self, ctx: StartOperationContext, input: ErrorTestInput) -> None: if input.action_in_sync_op == "raise_handler_error": - raise nexusrpc.handler.HandlerError( + raise nexusrpc.HandlerError( "test", - type=nexusrpc.handler.HandlerErrorType.INTERNAL, + type=nexusrpc.HandlerErrorType.INTERNAL, ) elif input.action_in_sync_op == "raise_operation_error": raise nexusrpc.OperationError( From 9b6f836f2ead317c26523afb9e26fd583b719e5a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 29 Jun 2025 08:34:55 -0400 Subject: [PATCH 126/183] RTU: test is fixed by syncio.sync_operation --- tests/nexus/test_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index c5f6c39fa..4a58da028 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -42,6 +42,7 @@ StartOperationContext, service_handler, sync_operation, + syncio, ) from nexusrpc.handler._decorators import operation_handler @@ -864,8 +865,7 @@ class EchoService: @service_handler(service=EchoService) class SyncStartHandler: - # TODO(nexus-prerelease): why is this test passing? start is not `async def` - @sync_operation + @syncio.sync_operation def echo(self, ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) From f53a7831ec789fa53b71406e538165a185838ee5 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 28 Jun 2025 22:58:36 -0400 Subject: [PATCH 127/183] RTU: unskip test --- tests/nexus/test_handler.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 4a58da028..cf5fb418d 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -89,9 +89,9 @@ class NonSerializableOutput: class MyService: echo: nexusrpc.Operation[Input, Output] # TODO(nexus-prerelease): support renamed operations! - # echo_renamed: nexusrpc.Operation[Input, Output] = ( - # nexusrpc.Operation(name="echo-renamed") - # ) + echo_renamed: nexusrpc.Operation[Input, Output] = nexusrpc.Operation( + name="echo-renamed" + ) hang: nexusrpc.Operation[Input, Output] log: nexusrpc.Operation[Input, Output] workflow_run_operation_happy_path: nexusrpc.Operation[Input, Output] @@ -147,6 +147,17 @@ async def echo(self, ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) + # The name override is prsent in the service definition. But the test below submits + # the same operation name in the request whether using a service definition or now. + # The name override here is necessary when the test is not using the service + # definition. It should be permitted when the service definition is in effect, as + # long as the name override is the same as that in the service definition. + # TODO(nexus-prerelease): implement in nexusrpc the check that operation handler + # name overrides must be consistent with service definition overrides. + @sync_operation(name="echo-renamed") + async def echo_renamed(self, ctx: StartOperationContext, input: Input) -> Output: + return await self.echo(ctx, input) + @sync_operation async def hang(self, ctx: StartOperationContext, input: Input) -> Output: await asyncio.Future() @@ -467,13 +478,8 @@ class SyncHandlerHappyPath(_TestCase): ), "Nexus-Link header not echoed correctly." -class SyncHandlerHappyPathRenamed(_TestCase): +class SyncHandlerHappyPathRenamed(SyncHandlerHappyPath): operation = "echo-renamed" - input = Input("hello") - expected = SuccessfulResponse( - status_code=200, - body_json={"value": "from start method on MyServiceHandler: hello"}, - ) class SyncHandlerHappyPathNonAsyncDef(_TestCase): @@ -707,8 +713,7 @@ class NonSerializableOutputFailure(_FailureTestCase): "test_case", [ SyncHandlerHappyPath, - # TODO(nexus-prerelease): support renamed operations! - # SyncHandlerHappyPathRenamed, + SyncHandlerHappyPathRenamed, SyncHandlerHappyPathNonAsyncDef, # TODO(nexus-prerelease): make callable instance work # SyncHandlerHappyPathWithNonAsyncCallableInstance, From e30bba29ef0675713250be40e3b0da5f01001d08 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 29 Jun 2025 15:19:28 -0400 Subject: [PATCH 128/183] RTU: syncio tree --- tests/nexus/test_handler.py | 4 ++-- uv.lock | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index cf5fb418d..a495a1104 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -42,9 +42,9 @@ StartOperationContext, service_handler, sync_operation, - syncio, ) from nexusrpc.handler._decorators import operation_handler +from nexusrpc.syncio.handler import sync_operation as syncio_sync_operation from temporalio import nexus, workflow from temporalio.client import Client @@ -870,7 +870,7 @@ class EchoService: @service_handler(service=EchoService) class SyncStartHandler: - @syncio.sync_operation + @syncio_sync_operation def echo(self, ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) diff --git a/uv.lock b/uv.lock index fdea8c415..68471a43f 100644 --- a/uv.lock +++ b/uv.lock @@ -1055,8 +1055,9 @@ requires-dist = [{ name = "typing-extensions", specifier = ">=4.12.2" }] [package.metadata.requires-dev] dev = [ { name = "mypy", specifier = ">=1.15.0" }, + { name = "poethepoet", specifier = ">=0.35.0" }, { name = "pydoctor", specifier = ">=25.4.0" }, - { name = "pyright", specifier = ">=1.1.400" }, + { name = "pyright", specifier = ">=1.1.402" }, { name = "pytest", specifier = ">=8.3.5" }, { name = "pytest-asyncio", specifier = ">=0.26.0" }, { name = "pytest-cov", specifier = ">=6.1.1" }, From b3de2ef64edb653c5160394bdec5c902b6e0d5d1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 29 Jun 2025 16:50:50 -0400 Subject: [PATCH 129/183] Don't pass cause to HandlerError constructor --- temporalio/nexus/_operation_handlers.py | 12 +++---- temporalio/worker/_nexus.py | 42 +++++++++++-------------- tests/nexus/test_handler.py | 3 +- 3 files changed, 23 insertions(+), 34 deletions(-) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 449e7ecb3..8d1253979 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -112,8 +112,7 @@ async def fetch_result( # "Failed to decode operation token as workflow operation token. " # "Fetching result for non-workflow operations is not supported.", # type=HandlerErrorType.NOT_FOUND, - # cause=err, - # ) + # ) from err # ctx = _temporal_fetch_operation_context.get() # try: # client_handle = nexus_handle.to_workflow_handle( @@ -123,8 +122,7 @@ async def fetch_result( # raise HandlerError( # "Failed to construct workflow handle from workflow operation token", # type=HandlerErrorType.NOT_FOUND, - # cause=err, - # ) + # ) from err # return await client_handle.result() @@ -145,8 +143,7 @@ async def cancel_operation( "Failed to decode operation token as workflow operation token. " "Canceling non-workflow operations is not supported.", type=HandlerErrorType.NOT_FOUND, - cause=err, - ) + ) from err ctx = _temporal_cancel_operation_context.get() try: @@ -157,6 +154,5 @@ async def cancel_operation( raise HandlerError( "Failed to construct workflow handle from workflow operation token", type=HandlerErrorType.NOT_FOUND, - cause=err, - ) + ) from err await client_workflow_handle.cancel(**kwargs) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 59d3fc25f..72e3ef2bb 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -381,7 +381,6 @@ async def deserialize( raise nexusrpc.HandlerError( "Data converter failed to decode Nexus operation input", type=nexusrpc.HandlerErrorType.BAD_REQUEST, - cause=err, retryable=False, ) from err @@ -393,36 +392,32 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: if isinstance(err, nexusrpc.HandlerError): return err elif isinstance(err, ApplicationError): - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( # TODO(nexus-prerelease): what should message be? err.message, type=nexusrpc.HandlerErrorType.INTERNAL, - cause=err, retryable=not err.non_retryable, ) elif isinstance(err, RPCError): if err.status == RPCStatusCode.INVALID_ARGUMENT: - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( err.message, type=nexusrpc.HandlerErrorType.BAD_REQUEST, - cause=err, ) elif err.status in [ RPCStatusCode.ALREADY_EXISTS, RPCStatusCode.FAILED_PRECONDITION, RPCStatusCode.OUT_OF_RANGE, ]: - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( err.message, type=nexusrpc.HandlerErrorType.INTERNAL, - cause=err, retryable=False, ) elif err.status in [RPCStatusCode.ABORTED, RPCStatusCode.UNAVAILABLE]: - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( err.message, type=nexusrpc.HandlerErrorType.UNAVAILABLE, - cause=err, ) elif err.status in [ RPCStatusCode.CANCELLED, @@ -436,37 +431,36 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: # we convert to internal because this is not a client auth error and happens # when the handler fails to auth with Temporal and should be considered # retryable. - return nexusrpc.HandlerError( - err.message, type=nexusrpc.HandlerErrorType.INTERNAL, cause=err + handler_err = nexusrpc.HandlerError( + err.message, type=nexusrpc.HandlerErrorType.INTERNAL ) elif err.status == RPCStatusCode.NOT_FOUND: - return nexusrpc.HandlerError( - err.message, type=nexusrpc.HandlerErrorType.NOT_FOUND, cause=err + handler_err = nexusrpc.HandlerError( + err.message, type=nexusrpc.HandlerErrorType.NOT_FOUND ) elif err.status == RPCStatusCode.RESOURCE_EXHAUSTED: - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( err.message, type=nexusrpc.HandlerErrorType.RESOURCE_EXHAUSTED, - cause=err, ) elif err.status == RPCStatusCode.UNIMPLEMENTED: - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( err.message, type=nexusrpc.HandlerErrorType.NOT_IMPLEMENTED, - cause=err, ) elif err.status == RPCStatusCode.DEADLINE_EXCEEDED: - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( err.message, type=nexusrpc.HandlerErrorType.UPSTREAM_TIMEOUT, - cause=err, ) else: - return nexusrpc.HandlerError( + handler_err = nexusrpc.HandlerError( f"Unhandled RPC error status: {err.status}", type=nexusrpc.HandlerErrorType.INTERNAL, - cause=err, ) - return nexusrpc.HandlerError( - str(err), type=nexusrpc.HandlerErrorType.INTERNAL, cause=err - ) + else: + handler_err = nexusrpc.HandlerError( + str(err), type=nexusrpc.HandlerErrorType.INTERNAL + ) + handler_err.__cause__ = err + return handler_err diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index a495a1104..5e9f18ab6 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -194,8 +194,7 @@ async def handler_error_internal( message="deliberate internal handler error", type=HandlerErrorType.INTERNAL, retryable=False, - cause=RuntimeError("cause message"), - ) + ) from RuntimeError("cause message") @sync_operation async def operation_error_failed( From 54a0a86716e029bb3c2e3c2720f023b6f975dc78 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 29 Jun 2025 20:38:21 -0400 Subject: [PATCH 130/183] RTU: registration time enforcement of syncio/asyncio mistakes --- tests/nexus/test_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 5e9f18ab6..cd542115c 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -919,7 +919,7 @@ class SyncHandlerNoExecutor(_InstantiationCase): handler = SyncStartHandler executor = False exception = RuntimeError - match = "must be an `async def`" + match = "Use nexusrpc.syncio.handler.Handler instead" class DefaultCancel(_InstantiationCase): @@ -932,7 +932,7 @@ class SyncCancel(_InstantiationCase): handler = SyncCancelHandler executor = False exception = RuntimeError - match = "cancel method must be an `async def`" + match = "Use nexusrpc.syncio.handler.Handler instead" @pytest.mark.parametrize( From 332658dac667f1a5f77936e58f977da2c870f395 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sun, 29 Jun 2025 22:21:54 -0400 Subject: [PATCH 131/183] WIP --- tests/nexus/test_workflow_caller.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index f2eb5a2d9..179aade20 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1087,7 +1087,13 @@ async def assert_handler_workflow_has_link_to_caller_workflow( # Handler -ActionInSyncOp = Literal["raise_handler_error", "raise_operation_error"] +ActionInSyncOp = Literal[ + "raise_handler_error", "raise_operation_error", "raise_custom_error" +] + + +class CustomError(Exception): + pass @dataclass @@ -1109,6 +1115,8 @@ async def op(self, ctx: StartOperationContext, input: ErrorTestInput) -> None: raise nexusrpc.OperationError( "test", state=nexusrpc.OperationErrorState.FAILED ) + elif input.action_in_sync_op == "raise_custom_error": + raise CustomError("test") else: raise NotImplementedError( f"Unhandled action_in_sync_op: {input.action_in_sync_op}" @@ -1144,7 +1152,8 @@ async def run(self, input: ErrorTestInput) -> list[str]: @pytest.mark.parametrize( - "action_in_sync_op", ["raise_handler_error", "raise_operation_error"] + "action_in_sync_op", + ["raise_handler_error", "raise_operation_error", "raise_custom_error"], ) async def test_errors_raised_by_nexus_operation( client: Client, action_in_sync_op: ActionInSyncOp @@ -1166,10 +1175,16 @@ async def test_errors_raised_by_nexus_operation( id=str(uuid.uuid4()), task_queue=task_queue, ) + + print(f"\n\n\n{action_in_sync_op}: \n", result, "\n\n\n") + if action_in_sync_op == "raise_handler_error": assert result == ["NexusOperationError", "HandlerError"] elif action_in_sync_op == "raise_operation_error": assert result == ["NexusOperationError", "ApplicationError"] + elif action_in_sync_op == "raise_custom_error": + # assert result == ["NexusOperationError", "CustomError"] + pass else: raise NotImplementedError( f"Unhandled action_in_sync_op: {action_in_sync_op}" From 0bba35e0bce978fbf544ae6f14b251897c5602bb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 15:24:46 -0400 Subject: [PATCH 132/183] RTU: Copy operation factory getter/setter from nexusrpc --- temporalio/nexus/_util.py | 34 +++++++++++++++++++ temporalio/worker/_interceptor.py | 3 +- ...ynamic_creation_of_user_handler_classes.py | 2 +- .../test_handler_operation_definitions.py | 5 ++- 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 8b24383ad..e90ba8fc0 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -125,6 +125,40 @@ def get_callable_name(fn: Callable[..., Any]) -> str: return method_name +# TODO(nexus-preview) Copied from nexusrpc +def get_operation_factory( + obj: Any, +) -> tuple[ + Optional[Callable[[Any], OperationHandler[InputT, OutputT]]], + Optional[nexusrpc.Operation[InputT, OutputT]], +]: + """Return the :py:class:`Operation` for the object along with the factory function. + + ``obj`` should be a decorated operation start method. + """ + op_defn = get_operation_definition(obj) + if op_defn: + factory = obj + else: + if factory := getattr(obj, "__nexus_operation_factory__", None): + op_defn = get_operation_definition(factory) + if not isinstance(op_defn, nexusrpc.Operation): + return None, None + return factory, op_defn + + +# TODO(nexus-preview) Copied from nexusrpc +def set_operation_factory( + obj: Any, + operation_factory: Callable[[Any], OperationHandler[InputT, OutputT]], +) -> None: + """Set the :py:class:`OperationHandler` factory for this object. + + ``obj`` should be an operation start method. + """ + setattr(obj, "__nexus_operation_factory__", operation_factory) + + # Copied from https://github.com/modelcontextprotocol/python-sdk # # Copyright (c) 2024 Anthropic, PBC. diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 5ae9f382c..8499a5136 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -25,6 +25,7 @@ import temporalio.api.common.v1 import temporalio.common import temporalio.nexus +import temporalio.nexus._util import temporalio.workflow from temporalio.workflow import VersioningIntent @@ -313,7 +314,7 @@ def __post_init__(self) -> None: self._operation_name = self.operation self._input_type = None elif isinstance(self.operation, Callable): - _, op = nexusrpc.get_operation_factory(self.operation) + _, op = temporalio.nexus._util.get_operation_factory(self.operation) if isinstance(op, nexusrpc.Operation): self._operation_name = op.name self._input_type = op.input_type diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index c63e6931d..dd0c57017 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -3,10 +3,10 @@ import httpx import nexusrpc.handler import pytest -from nexusrpc import get_operation_factory from nexusrpc.handler import sync_operation from temporalio.client import Client +from temporalio.nexus._util import get_operation_factory from temporalio.worker import Worker from tests.helpers.nexus import create_nexus_endpoint diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index c3f812d05..ce124a8b0 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -11,6 +11,7 @@ from temporalio import nexus from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus._util import get_operation_factory @dataclass @@ -92,9 +93,7 @@ async def test_collected_operation_names( assert isinstance(service_defn, nexusrpc.ServiceDefinition) assert service_defn.name == "Service" for method_name, expected_op in test_case.expected_operations.items(): - _, actual_op = nexusrpc.get_operation_factory( - getattr(test_case.Service, method_name) - ) + _, actual_op = get_operation_factory(getattr(test_case.Service, method_name)) assert isinstance(actual_op, nexusrpc.Operation) assert actual_op.name == expected_op.name assert actual_op.input_type == expected_op.input_type From 3ee29e75cdcfb852a4cda96375e4f5cad281be52 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 15:25:21 -0400 Subject: [PATCH 133/183] Use getters/setters --- temporalio/nexus/_decorators.py | 17 ++++++++++------- temporalio/nexus/_util.py | 11 ++++++----- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index b40fb7634..b1a30f93c 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -28,6 +28,7 @@ from temporalio.nexus._util import ( get_callable_name, get_workflow_run_start_method_input_and_output_type_annotations, + set_operation_factory, ) ServiceHandlerT = TypeVar("ServiceHandlerT") @@ -124,15 +125,17 @@ async def _start( return WorkflowRunOperationHandler(_start, input_type, output_type) method_name = get_callable_name(start) - # TODO(nexus-preview): make double-underscore attrs private to nexusrpc and expose getters/setters - operation_handler_factory.__nexus_operation__ = nexusrpc.Operation( - name=name or method_name, - method_name=method_name, - input_type=input_type, - output_type=output_type, + nexusrpc.set_operation_definition( + operation_handler_factory, + nexusrpc.Operation( + name=name or method_name, + method_name=method_name, + input_type=input_type, + output_type=output_type, + ), ) - start.__nexus_operation_factory__ = operation_handler_factory + set_operation_factory(start, operation_handler_factory) return start if start is None: diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index e90ba8fc0..3c2cc9fe4 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -14,6 +14,7 @@ Union, ) +import nexusrpc from nexusrpc import ( InputT, OutputT, @@ -129,19 +130,19 @@ def get_callable_name(fn: Callable[..., Any]) -> str: def get_operation_factory( obj: Any, ) -> tuple[ - Optional[Callable[[Any], OperationHandler[InputT, OutputT]]], - Optional[nexusrpc.Operation[InputT, OutputT]], + Optional[Callable[[Any], Any]], + Optional[nexusrpc.Operation[Any, Any]], ]: """Return the :py:class:`Operation` for the object along with the factory function. ``obj`` should be a decorated operation start method. """ - op_defn = get_operation_definition(obj) + op_defn = nexusrpc.get_operation_definition(obj) if op_defn: factory = obj else: if factory := getattr(obj, "__nexus_operation_factory__", None): - op_defn = get_operation_definition(factory) + op_defn = nexusrpc.get_operation_definition(factory) if not isinstance(op_defn, nexusrpc.Operation): return None, None return factory, op_defn @@ -150,7 +151,7 @@ def get_operation_factory( # TODO(nexus-preview) Copied from nexusrpc def set_operation_factory( obj: Any, - operation_factory: Callable[[Any], OperationHandler[InputT, OutputT]], + operation_factory: Callable[[Any], Any], ) -> None: """Set the :py:class:`OperationHandler` factory for this object. From 5be55aa34d2f5d4344069eb52d199c2f3a5f102c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 19:12:22 -0400 Subject: [PATCH 134/183] Move no-type-annotations test to invalid usage test --- tests/nexus/test_workflow_run_operation.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 8eefa9ac3..0a00e32b6 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -1,3 +1,4 @@ +import re import uuid from dataclasses import dataclass from typing import Any, Type @@ -71,15 +72,6 @@ class Service: op: Operation[Input, str] -@service_handler -class SubclassingNoInputOutputTypeAnnotationsWithoutServiceDefinition: - @operation_handler - def op(self) -> OperationHandler: - return MyOperation() - - __expected__error__ = 500, "'dict' object has no attribute 'value'" - - @service_handler(service=Service) class SubclassingNoInputOutputTypeAnnotationsWithServiceDefinition: # Despite the lack of annotations on the service impl, the service definition @@ -94,7 +86,6 @@ def op(self) -> OperationHandler: "service_handler_cls", [ SubclassingHappyPath, - SubclassingNoInputOutputTypeAnnotationsWithoutServiceDefinition, SubclassingNoInputOutputTypeAnnotationsWithServiceDefinition, ], ) @@ -123,7 +114,7 @@ async def test_workflow_run_operation( status_code, message = service_handler_cls.__expected__error__ assert resp.status_code == status_code failure = Failure(**resp.json()) - assert failure.message == message + assert re.search(message, failure.message) else: assert resp.status_code == 201 From 109c97690d04056764baf154ffada72718905b03 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 21:03:49 -0400 Subject: [PATCH 135/183] Remove operations without type annotations --- tests/nexus/test_handler.py | 94 +++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index cd542115c..68ebc2c4f 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -95,8 +95,6 @@ class MyService: hang: nexusrpc.Operation[Input, Output] log: nexusrpc.Operation[Input, Output] workflow_run_operation_happy_path: nexusrpc.Operation[Input, Output] - workflow_run_operation_without_type_annotations: nexusrpc.Operation[Input, Output] - sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_with_non_async_def: nexusrpc.Operation[Input, Output] # TODO(nexus-prerelease): fix tests of callable instances # sync_operation_with_non_async_callable_instance: nexusrpc.Operation[Input, Output] @@ -263,22 +261,6 @@ def __call__( sync_operation_with_non_async_callable_instance, ) - @sync_operation - async def sync_operation_without_type_annotations(self, ctx, input): - # Despite the lack of type annotations, the input type from the op definition in - # the service definition is used to deserialize the input. - return Output( - value=f"from start method on {self.__class__.__name__} without type annotations: {input}" - ) - - @workflow_run_operation - async def workflow_run_operation_without_type_annotations(self, ctx, input): - return await ctx.start_workflow( - WorkflowWithoutTypeAnnotations.run, - input, - id=str(uuid.uuid4()), - ) - @workflow_run_operation async def workflow_run_op_link_test( self, ctx: WorkflowRunOperationContext, input: Input @@ -500,23 +482,6 @@ class SyncHandlerHappyPathWithNonAsyncCallableInstance(_TestCase): skip = "TODO(nexus-prerelease): fix tests of callable instances" -class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): - operation = "sync_operation_without_type_annotations" - input = Input("hello") - expected = SuccessfulResponse( - status_code=200, - body_json={ - "value": "from start method on MyServiceHandler without type annotations: Input(value='hello')" - }, - ) - expected_without_service_definition = SuccessfulResponse( - status_code=200, - body_json={ - "value": "from start method on MyServiceHandler without type annotations: {'value': 'hello'}" - }, - ) - - class AsyncHandlerHappyPath(_TestCase): operation = "workflow_run_operation_happy_path" input = Input("hello") @@ -526,14 +491,6 @@ class AsyncHandlerHappyPath(_TestCase): ) -class AsyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): - operation = "workflow_run_operation_without_type_annotations" - input = Input("hello") - expected = SuccessfulResponse( - status_code=201, - ) - - class WorkflowRunOpLinkTestHappyPath(_TestCase): # TODO(nexus-prerelease): fix this test skip = "Yields invalid link" @@ -716,9 +673,7 @@ class NonSerializableOutputFailure(_FailureTestCase): SyncHandlerHappyPathNonAsyncDef, # TODO(nexus-prerelease): make callable instance work # SyncHandlerHappyPathWithNonAsyncCallableInstance, - SyncHandlerHappyPathWithoutTypeAnnotations, AsyncHandlerHappyPath, - AsyncHandlerHappyPathWithoutTypeAnnotations, WorkflowRunOpLinkTestHappyPath, ], ) @@ -807,6 +762,55 @@ async def _test_start_operation( assert not any(warnings), [w.message for w in warnings] +@nexusrpc.service +class MyServiceWithOperationsWithoutTypeAnnotations(MyService): + workflow_run_operation_without_type_annotations: nexusrpc.Operation[Input, Output] + sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] + + +class MyServiceHandlerWithOperationsWithoutTypeAnnotations(MyServiceHandler): + @sync_operation + async def sync_operation_without_type_annotations(self, ctx, input): + # Despite the lack of type annotations, the input type from the op definition in + # the service definition is used to deserialize the input. + return Output( + value=f"from start method on {self.__class__.__name__} without type annotations: {input}" + ) + + @workflow_run_operation + async def workflow_run_operation_without_type_annotations(self, ctx, input): + return await ctx.start_workflow( + WorkflowWithoutTypeAnnotations.run, + input, + id=str(uuid.uuid4()), + ) + + +class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): + operation = "sync_operation_without_type_annotations" + input = Input("hello") + expected = SuccessfulResponse( + status_code=200, + body_json={ + "value": "from start method on MyServiceHandler without type annotations: Input(value='hello')" + }, + ) + expected_without_service_definition = SuccessfulResponse( + status_code=200, + body_json={ + "value": "from start method on MyServiceHandler without type annotations: {'value': 'hello'}" + }, + ) + + +class AsyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): + operation = "workflow_run_operation_without_type_annotations" + input = Input("hello") + expected = SuccessfulResponse( + status_code=201, + ) + + async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: Any): task_queue = str(uuid.uuid4()) service_name = MyService.__name__ From ad42e6783f8518a71839d140660fba60c79959c7 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 21:31:11 -0400 Subject: [PATCH 136/183] Split test --- tests/nexus/test_handler.py | 60 +++++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 68ebc2c4f..ce6367ceb 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -683,7 +683,10 @@ async def test_start_operation_happy_path( with_service_definition: bool, env: WorkflowEnvironment, ): - await _test_start_operation(test_case, with_service_definition, env) + if with_service_definition: + await _test_start_operation_with_service_definition(test_case, env) + else: + await _test_start_operation_without_service_definition(test_case, env) @pytest.mark.parametrize( @@ -702,7 +705,7 @@ async def test_start_operation_happy_path( async def test_start_operation_protocol_level_failures( test_case: Type[_TestCase], env: WorkflowEnvironment ): - await _test_start_operation(test_case, True, env) + await _test_start_operation_with_service_definition(test_case, env) @pytest.mark.parametrize( @@ -716,12 +719,11 @@ async def test_start_operation_protocol_level_failures( async def test_start_operation_operation_failures( test_case: Type[_TestCase], env: WorkflowEnvironment ): - await _test_start_operation(test_case, True, env) + await _test_start_operation_with_service_definition(test_case, env) -async def _test_start_operation( +async def _test_start_operation_with_service_definition( test_case: Type[_TestCase], - with_service_definition: bool, env: WorkflowEnvironment, ): if test_case.skip: @@ -731,19 +733,45 @@ async def _test_start_operation( service_client = ServiceClient( server_address=server_address(env), endpoint=endpoint, - service=( - test_case.service_defn - if with_service_definition - else MyServiceHandler.__name__ - ), + service=(test_case.service_defn), ) with pytest.WarningsRecorder() as warnings: - decorator = ( - service_handler(service=MyService) - if with_service_definition - else service_handler - ) + decorator = service_handler(service=MyService) + user_service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[user_service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + test_case.operation, + dataclass_as_dict(test_case.input), + test_case.headers, + ) + test_case.check_response(response, with_service_definition=True) + + assert not any(warnings), [w.message for w in warnings] + + +async def _test_start_operation_without_service_definition( + test_case: Type[_TestCase], + env: WorkflowEnvironment, +): + if test_case.skip: + pytest.skip(test_case.skip) + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyServiceHandler.__name__, + ) + + with pytest.WarningsRecorder() as warnings: + decorator = service_handler user_service_handler = decorator(MyServiceHandler)() async with Worker( @@ -757,7 +785,7 @@ async def _test_start_operation( dataclass_as_dict(test_case.input), test_case.headers, ) - test_case.check_response(response, with_service_definition) + test_case.check_response(response, with_service_definition=False) assert not any(warnings), [w.message for w in warnings] From 5a5d9c6a901c59bd67c9a7ea709cd44c90c8f7b4 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 21:37:13 -0400 Subject: [PATCH 137/183] Test operations without type annotations --- tests/nexus/test_handler.py | 66 ++++++++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index ce6367ceb..6aaeac1d3 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -791,12 +791,12 @@ async def _test_start_operation_without_service_definition( @nexusrpc.service -class MyServiceWithOperationsWithoutTypeAnnotations(MyService): +class MyServiceWithOperationsWithoutTypeAnnotations: workflow_run_operation_without_type_annotations: nexusrpc.Operation[Input, Output] sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] -class MyServiceHandlerWithOperationsWithoutTypeAnnotations(MyServiceHandler): +class MyServiceHandlerWithOperationsWithoutTypeAnnotations: @sync_operation async def sync_operation_without_type_annotations(self, ctx, input): # Despite the lack of type annotations, the input type from the op definition in @@ -820,13 +820,7 @@ class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): expected = SuccessfulResponse( status_code=200, body_json={ - "value": "from start method on MyServiceHandler without type annotations: Input(value='hello')" - }, - ) - expected_without_service_definition = SuccessfulResponse( - status_code=200, - body_json={ - "value": "from start method on MyServiceHandler without type annotations: {'value': 'hello'}" + "value": "from start method on MyServiceHandlerWithOperationsWithoutTypeAnnotations without type annotations: Input(value='hello')" }, ) @@ -839,6 +833,60 @@ class AsyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): ) +# Attempting to use the service_handler decorator on a class containing an operation +# without type annotations is a validation error (test coverage in nexusrpc) +@pytest.mark.parametrize( + "test_case", + [ + SyncHandlerHappyPathWithoutTypeAnnotations, + AsyncHandlerHappyPathWithoutTypeAnnotations, + ], +) +async def test_start_operation_without_type_annotations( + test_case: Type[_TestCase], env: WorkflowEnvironment +): + if test_case.skip: + pytest.skip(test_case.skip) + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyServiceWithOperationsWithoutTypeAnnotations.__name__, + ) + + with pytest.WarningsRecorder() as warnings: + decorator = service_handler( + service=MyServiceWithOperationsWithoutTypeAnnotations + ) + user_service_handler = decorator( + MyServiceHandlerWithOperationsWithoutTypeAnnotations + )() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[user_service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + test_case.operation, + dataclass_as_dict(test_case.input), + test_case.headers, + ) + test_case.check_response(response, with_service_definition=True) + + assert not any(warnings), [w.message for w in warnings] + + +def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): + with pytest.raises( + ValueError, + match=r"has no input type.+has no output type", + ): + service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) + + async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: Any): task_queue = str(uuid.uuid4()) service_name = MyService.__name__ From 86a9a61378f75b85416a10b582632ad726117eeb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 21:43:06 -0400 Subject: [PATCH 138/183] Delete redundant test --- tests/nexus/test_handler.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 6aaeac1d3..0bf4b9c4b 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -879,14 +879,6 @@ async def test_start_operation_without_type_annotations( assert not any(warnings), [w.message for w in warnings] -def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): - with pytest.raises( - ValueError, - match=r"has no input type.+has no output type", - ): - service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) - - async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: Any): task_queue = str(uuid.uuid4()) service_name = MyService.__name__ From caef9305940991061967ce5368b27cdd949c80a8 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 21:47:47 -0400 Subject: [PATCH 139/183] Delete failing callable instance test This is nexusrpc responsibility and it has a broken test like this --- tests/nexus/test_handler.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 0bf4b9c4b..5c7e25fa4 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -96,8 +96,6 @@ class MyService: log: nexusrpc.Operation[Input, Output] workflow_run_operation_happy_path: nexusrpc.Operation[Input, Output] sync_operation_with_non_async_def: nexusrpc.Operation[Input, Output] - # TODO(nexus-prerelease): fix tests of callable instances - # sync_operation_with_non_async_callable_instance: nexusrpc.Operation[Input, Output] operation_returning_unwrapped_result_at_runtime_error: nexusrpc.Operation[ Input, Output ] @@ -238,29 +236,6 @@ async def sync_operation_with_non_async_def( value=f"from start method on {self.__class__.__name__}: {input.value}" ) - if False: - # TODO(nexus-prerelease): fix tests of callable instances - def sync_operation_with_non_async_callable_instance( - self, - ) -> OperationHandler[Input, Output]: - class start: - def __call__( - self, - ctx: StartOperationContext, - input: Input, - ) -> Output: - return Output( - value=f"from start method on {self.__class__.__name__}: {input.value}" - ) - - return sync_operation(start()) - - _sync_operation_with_non_async_callable_instance = operation_handler( - name="sync_operation_with_non_async_callable_instance", - )( - sync_operation_with_non_async_callable_instance, - ) - @workflow_run_operation async def workflow_run_op_link_test( self, ctx: WorkflowRunOperationContext, input: Input @@ -472,16 +447,6 @@ class SyncHandlerHappyPathNonAsyncDef(_TestCase): ) -class SyncHandlerHappyPathWithNonAsyncCallableInstance(_TestCase): - operation = "sync_operation_with_non_async_callable_instance" - input = Input("hello") - expected = SuccessfulResponse( - status_code=200, - body_json={"value": "from start method on MyServiceHandler: hello"}, - ) - skip = "TODO(nexus-prerelease): fix tests of callable instances" - - class AsyncHandlerHappyPath(_TestCase): operation = "workflow_run_operation_happy_path" input = Input("hello") @@ -671,8 +636,6 @@ class NonSerializableOutputFailure(_FailureTestCase): SyncHandlerHappyPath, SyncHandlerHappyPathRenamed, SyncHandlerHappyPathNonAsyncDef, - # TODO(nexus-prerelease): make callable instance work - # SyncHandlerHappyPathWithNonAsyncCallableInstance, AsyncHandlerHappyPath, WorkflowRunOpLinkTestHappyPath, ], From b288eaa782a47d475a6e44116b8978b9f442221a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 1 Jul 2025 09:45:19 -0400 Subject: [PATCH 140/183] Test error conversion --- tests/nexus/test_workflow_caller.py | 215 ++++++++++++++++++++++++---- 1 file changed, 191 insertions(+), 24 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 179aade20..d9905fe71 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -38,7 +38,7 @@ WorkflowHandle, ) from temporalio.common import WorkflowIDConflictPolicy -from temporalio.exceptions import CancelledError, NexusOperationError +from temporalio.exceptions import ApplicationError, CancelledError, NexusOperationError from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker @@ -1087,8 +1087,134 @@ async def assert_handler_workflow_has_link_to_caller_workflow( # Handler +# @OperationImpl +# public OperationHandler testError() { +# return OperationHandler.sync( +# (ctx, details, input) -> { +# switch (input.getAction()) { +# case RAISE_APPLICATION_ERROR: +# throw ApplicationFailure.newNonRetryableFailure( +# "application error 1", "APPLICATION_ERROR"); +# case RAISE_CUSTOM_ERROR: +# throw new MyCustomException("Custom error 1"); +# case RAISE_CUSTOM_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# // ** THIS DOESN'T WORK **: CHAINED CUSTOM EXCEPTIONS DON'T SERIALIZE +# MyCustomException customError = new MyCustomException("Custom error 1"); +# customError.initCause(new MyCustomException("Custom error 2")); +# throw customError; +# case RAISE_APPLICATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# throw ApplicationFailure.newNonRetryableFailureWithCause( +# "application error 1", +# "APPLICATION_ERROR", +# new MyCustomException("Custom error 2")); +# case RAISE_NEXUS_HANDLER_ERROR: +# throw new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); +# case RAISE_NEXUS_HANDLER_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# // ** THIS DOESN'T WORK ** +# // Can't overwrite cause with +# // io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException: Custom error +# // 2 +# HandlerException handlerErr = +# new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); +# handlerErr.initCause(new MyCustomException("Custom error 2")); +# throw handlerErr; +# case RAISE_NEXUS_OPERATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# throw OperationException.failure( +# ApplicationFailure.newNonRetryableFailureWithCause( +# "application error 1", +# "APPLICATION_ERROR", +# new MyCustomException("Custom error 2"))); +# } +# return new NexusService.ErrorTestOutput("Unreachable"); +# }); +# } + +# 🌈 RAISE_APPLICATION_ERROR: +# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) +# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) +# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) + + +# 🌈 RAISE_CUSTOM_ERROR: +# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) +# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='Custom error wrapped: custom error 1', type='CUSTOM_ERROR', nonRetryable=true) +# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error wrapped: custom error 1', type='CUSTOM_ERROR', nonRetryable=true) + + +# 🌈 RAISE_APPLICATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) +# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) +# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) +# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false) + + +# 🌈 RAISE_NEXUS_HANDLER_ERROR: +# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) +# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false) +# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false) + + +# 🌈 RAISE_NEXUS_HANDLER_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) +# io.temporal.failure.TimeoutFailure(type=no-type-attr, message=message='operation timed out', timeoutType=TIMEOUT_TYPE_SCHEDULE_TO_CLOSE) + + +# 🌈 RAISE_NEXUS_OPERATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) +# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) +# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false) + +# @OperationImpl +# public OperationHandler testError() { +# return OperationHandler.sync( +# (ctx, details, input) -> { +# switch (input.getAction()) { +# case RAISE_APPLICATION_ERROR: +# throw ApplicationFailure.newNonRetryableFailure( +# "application error 1", "APPLICATION_ERROR"); +# case RAISE_CUSTOM_ERROR: +# throw new MyCustomException("Custom error 1"); +# case RAISE_CUSTOM_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# // ** THIS DOESN'T WORK **: CHAINED CUSTOM EXCEPTIONS DON'T SERIALIZE +# MyCustomException customError = new MyCustomException("Custom error 1"); +# customError.initCause(new MyCustomException("Custom error 2")); +# throw customError; +# case RAISE_APPLICATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# throw ApplicationFailure.newNonRetryableFailureWithCause( +# "application error 1", +# "APPLICATION_ERROR", +# new MyCustomException("Custom error 2")); +# case RAISE_NEXUS_HANDLER_ERROR: +# throw new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); +# case RAISE_NEXUS_HANDLER_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# // ** THIS DOESN'T WORK ** +# // Can't overwrite cause with +# // io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException: Custom error +# // 2 +# HandlerException handlerErr = +# new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); +# handlerErr.initCause(new MyCustomException("Custom error 2")); +# throw handlerErr; +# case RAISE_NEXUS_OPERATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# throw OperationException.failure( +# ApplicationFailure.newNonRetryableFailureWithCause( +# "application error 1", +# "APPLICATION_ERROR", +# new MyCustomException("Custom error 2"))); +# } +# return new NexusService.ErrorTestOutput("Unreachable"); +# }); +# } + + ActionInSyncOp = Literal[ - "raise_handler_error", "raise_operation_error", "raise_custom_error" + "application_error_non_retryable", + "custom_error", + "custom_error_from_custom_error", + "application_error_non_retryable_from_custom_error", + "nexus_handler_error_not_found", + "nexus_handler_error_not_found_from_custom_error", + "nexus_operation_error_from_application_error_non_retryable_from_custom_error", ] @@ -1106,17 +1232,46 @@ class ErrorTestInput: class ErrorTestService: @sync_operation async def op(self, ctx: StartOperationContext, input: ErrorTestInput) -> None: - if input.action_in_sync_op == "raise_handler_error": + if input.action_in_sync_op == "application_error_non_retryable": + raise ApplicationError("application error in nexus op", non_retryable=True) + elif input.action_in_sync_op == "custom_error": + raise CustomError("custom error in nexus op") + elif input.action_in_sync_op == "custom_error_from_custom_error": + raise CustomError("custom error 1 in nexus op") from CustomError( + "custom error 2 in nexus op" + ) + elif ( + input.action_in_sync_op + == "application_error_non_retryable_from_custom_error" + ): + raise ApplicationError( + "application error in nexus op", non_retryable=True + ) from CustomError("custom error in nexus op") + elif input.action_in_sync_op == "nexus_handler_error_not_found": raise nexusrpc.HandlerError( "test", - type=nexusrpc.HandlerErrorType.INTERNAL, + type=nexusrpc.HandlerErrorType.NOT_FOUND, ) - elif input.action_in_sync_op == "raise_operation_error": - raise nexusrpc.OperationError( - "test", state=nexusrpc.OperationErrorState.FAILED - ) - elif input.action_in_sync_op == "raise_custom_error": - raise CustomError("test") + elif ( + input.action_in_sync_op == "nexus_handler_error_not_found_from_custom_error" + ): + raise nexusrpc.HandlerError( + "test", + type=nexusrpc.HandlerErrorType.NOT_FOUND, + ) from CustomError("custom error in nexus op") + elif ( + input.action_in_sync_op + == "nexus_operation_error_from_application_error_non_retryable_from_custom_error" + ): + try: + raise ApplicationError( + "application error in nexus op", non_retryable=True + ) from CustomError("custom error in nexus op") + except ApplicationError as err: + raise nexusrpc.OperationError( + "operation error in nexus op", + state=nexusrpc.OperationErrorState.FAILED, + ) from err else: raise NotImplementedError( f"Unhandled action_in_sync_op: {input.action_in_sync_op}" @@ -1146,14 +1301,26 @@ async def run(self, input: ErrorTestInput) -> list[str]: # None input, ) - except Exception as err: - return [str(type(err).__name__), str(type(err.__cause__).__name__)] + except BaseException as err: + errs = [err] + while err.__cause__: + errs.append(err.__cause__) + err = err.__cause__ + return [type(err).__name__ for err in errs] assert False, "Unreachable" @pytest.mark.parametrize( "action_in_sync_op", - ["raise_handler_error", "raise_operation_error", "raise_custom_error"], + [ + "application_error_non_retryable", + "custom_error", + "custom_error_from_custom_error", + "application_error_non_retryable_from_custom_error", + "nexus_handler_error_not_found", + "nexus_handler_error_not_found_from_custom_error", + "nexus_operation_error_from_application_error_non_retryable_from_custom_error", + ], ) async def test_errors_raised_by_nexus_operation( client: Client, action_in_sync_op: ActionInSyncOp @@ -1178,14 +1345,14 @@ async def test_errors_raised_by_nexus_operation( print(f"\n\n\n{action_in_sync_op}: \n", result, "\n\n\n") - if action_in_sync_op == "raise_handler_error": - assert result == ["NexusOperationError", "HandlerError"] - elif action_in_sync_op == "raise_operation_error": - assert result == ["NexusOperationError", "ApplicationError"] - elif action_in_sync_op == "raise_custom_error": - # assert result == ["NexusOperationError", "CustomError"] - pass - else: - raise NotImplementedError( - f"Unhandled action_in_sync_op: {action_in_sync_op}" - ) + # if action_in_sync_op == "handler_error": + # assert result == ["NexusOperationError", "HandlerError"] + # elif action_in_sync_op == "operation_error": + # assert result == ["NexusOperationError", "ApplicationError"] + # elif action_in_sync_op == "custom_error": + # # assert result == ["NexusOperationError", "CustomError"] + # pass + # else: + # raise NotImplementedError( + # f"Unhandled action_in_sync_op: {action_in_sync_op}" + # ) From 6c08c80aeea178e7ca11e109d76f1bd051fd603e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 1 Jul 2025 10:23:53 -0400 Subject: [PATCH 141/183] Translating Java assertions --- tests/nexus/test_workflow_caller.py | 258 ++++++++++++++++++++++------ 1 file changed, 205 insertions(+), 53 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index d9905fe71..53fbb3ed6 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1164,47 +1164,211 @@ async def assert_handler_workflow_has_link_to_caller_workflow( # io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) # io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false) -# @OperationImpl -# public OperationHandler testError() { -# return OperationHandler.sync( -# (ctx, details, input) -> { -# switch (input.getAction()) { -# case RAISE_APPLICATION_ERROR: -# throw ApplicationFailure.newNonRetryableFailure( -# "application error 1", "APPLICATION_ERROR"); -# case RAISE_CUSTOM_ERROR: -# throw new MyCustomException("Custom error 1"); -# case RAISE_CUSTOM_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: -# // ** THIS DOESN'T WORK **: CHAINED CUSTOM EXCEPTIONS DON'T SERIALIZE -# MyCustomException customError = new MyCustomException("Custom error 1"); -# customError.initCause(new MyCustomException("Custom error 2")); -# throw customError; -# case RAISE_APPLICATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: -# throw ApplicationFailure.newNonRetryableFailureWithCause( -# "application error 1", -# "APPLICATION_ERROR", -# new MyCustomException("Custom error 2")); -# case RAISE_NEXUS_HANDLER_ERROR: -# throw new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); -# case RAISE_NEXUS_HANDLER_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: -# // ** THIS DOESN'T WORK ** -# // Can't overwrite cause with -# // io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException: Custom error -# // 2 -# HandlerException handlerErr = -# new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); -# handlerErr.initCause(new MyCustomException("Custom error 2")); -# throw handlerErr; -# case RAISE_NEXUS_OPERATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: -# throw OperationException.failure( -# ApplicationFailure.newNonRetryableFailureWithCause( -# "application error 1", -# "APPLICATION_ERROR", -# new MyCustomException("Custom error 2"))); -# } -# return new NexusService.ErrorTestOutput("Unreachable"); -# }); -# } + +@dataclass +class ErrorConversionTestCase: + name: str + java_behavior: list[tuple[type[Exception], dict[str, Any]]] + + +error_conversion_test_cases = [] + + +# application_error_non_retryable: +_ = ["NexusOperationError", "HandlerError"] +# Java +_ = [ + "NexusOperationError", + "HandlerError('handler error: application error', type='APPLICATION_ERROR', non_retryable=True)", + "ApplicationError('application error', type='APPLICATION_ERROR', non_retryable=True)", +] + +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="application_error_non_retryable", + java_behavior=[ + (NexusOperationError, {}), + ( + nexusrpc.HandlerError, + { + "message": "application error", + "type": "APPLICATION_ERROR", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "application error", + "type": "APPLICATION_ERROR", + "non_retryable": True, + }, + ), + ], + ) +) + + +# custom_error: +_ = ["NexusOperationError", "HandlerError"] +# Java +_ = [ + "NexusOperationError", + "HandlerError('Custom error wrapped: custom error', type='CUSTOM_ERROR', non_retryable=True)", + "ApplicationError('Custom error wrapped: custom error', type='CUSTOM_ERROR', non_retryable=True)", +] +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="custom_error", + java_behavior=[ + (NexusOperationError, {}), + ( + nexusrpc.HandlerError, + { + "message": "Custom error wrapped: custom error", + "type": "CUSTOM_ERROR", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "Custom error wrapped: custom error", + "type": "CUSTOM_ERROR", + "non_retryable": True, + }, + ), + ], + ) +) + + +# custom_error_from_custom_error: +_ = ["NexusOperationError", "HandlerError"] +# Java +# [Not possible] + +# application_error_non_retryable_from_custom_error: +_ = ["NexusOperationError", "HandlerError"] +# Java +_ = [ + "NexusOperationError", + "HandlerError('handler error: application error', type='APPLICATION_ERROR', non_retryable=True)", + "ApplicationError('application error', type='APPLICATION_ERROR', non_retryable=True)", + "ApplicationError('custom error', type='MyCustomException', non_retryable=False)", +] + +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="application_error_non_retryable_from_custom_error", + java_behavior=[ + (NexusOperationError, {}), + ( + nexusrpc.HandlerError, + { + "message": "handler error: application error", + "type": "APPLICATION_ERROR", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "application error", + "type": "APPLICATION_ERROR", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "custom error", + "type": "MyCustomException", + "non_retryable": False, + }, + ), + ], + ) +) + +# nexus_handler_error_not_found: +_ = ["NexusOperationError", "HandlerError"] +# Java +_ = [ + "NexusOperationError", + "HandlerError('handler error: handler error', type='RuntimeException', non_retryable=False)", + "ApplicationError('handler error', type='RuntimeException', non_retryable=False)", +] + +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="application_error_non_retryable_from_custom_error", + java_behavior=[ + (NexusOperationError, {}), + ( + nexusrpc.HandlerError, + { + "message": "handler error: handler error", + "type": "RuntimeException", + "non_retryable": False, + }, + ), + ( + ApplicationError, + { + "message": "handler error", + "type": "RuntimeException", + "non_retryable": False, + }, + ), + ], + ) +) + +# nexus_handler_error_not_found_from_custom_error: +_ = ["NexusOperationError", "HandlerError"] +# Java +# [Not possible] +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="nexus_handler_error_not_found", + java_behavior=[], # [Not possible] + ) +) + + +# nexus_operation_error_from_application_error_non_retryable_from_custom_error: +_ = ["NexusOperationError", "ApplicationError", "ApplicationError", "ApplicationError"] +# Java +_ = [ + "NexusOperationError", + "ApplicationError('application error', type='APPLICATION_ERROR', non_retryable=True)", + "ApplicationError('custom error', type='MyCustomException', non_retryable=False)", +] +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="nexus_operation_error_from_application_error_non_retryable_from_custom_error", + java_behavior=[ + (NexusOperationError, {}), + ( + ApplicationError, + { + "message": "application error", + "type": "APPLICATION_ERROR", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "custom error", + "type": "MyCustomException", + "non_retryable": False, + }, + ), + ], + ) +) ActionInSyncOp = Literal[ @@ -1344,15 +1508,3 @@ async def test_errors_raised_by_nexus_operation( ) print(f"\n\n\n{action_in_sync_op}: \n", result, "\n\n\n") - - # if action_in_sync_op == "handler_error": - # assert result == ["NexusOperationError", "HandlerError"] - # elif action_in_sync_op == "operation_error": - # assert result == ["NexusOperationError", "ApplicationError"] - # elif action_in_sync_op == "custom_error": - # # assert result == ["NexusOperationError", "CustomError"] - # pass - # else: - # raise NotImplementedError( - # f"Unhandled action_in_sync_op: {action_in_sync_op}" - # ) From d5043d7b495f35684606f277bcaa5e92a655fd38 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 1 Jul 2025 12:24:30 -0400 Subject: [PATCH 142/183] Update test --- tests/nexus/test_workflow_caller.py | 75 +++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 53fbb3ed6..8bcf66b6a 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -2,6 +2,7 @@ import uuid from dataclasses import dataclass from enum import IntEnum +from itertools import zip_longest from typing import Any, Callable, Literal, Union import nexusrpc @@ -1164,12 +1165,34 @@ async def assert_handler_workflow_has_link_to_caller_workflow( # io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) # io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false) +ActionInSyncOp = Literal[ + "application_error_non_retryable", + "custom_error", + "custom_error_from_custom_error", + "application_error_non_retryable_from_custom_error", + "nexus_handler_error_not_found", + "nexus_handler_error_not_found_from_custom_error", + "nexus_operation_error_from_application_error_non_retryable_from_custom_error", +] + @dataclass class ErrorConversionTestCase: - name: str + name: ActionInSyncOp java_behavior: list[tuple[type[Exception], dict[str, Any]]] + @staticmethod + def parse_exception( + exception: BaseException, + ) -> tuple[type[BaseException], dict[str, Any]]: + if isinstance(exception, NexusOperationError): + return NexusOperationError, {} + return type(exception), { + "message": getattr(exception, "message", None), + "type": getattr(exception, "type", None), + "non_retryable": getattr(exception, "non_retryable", None), + } + error_conversion_test_cases = [] @@ -1247,6 +1270,13 @@ class ErrorConversionTestCase: _ = ["NexusOperationError", "HandlerError"] # Java # [Not possible] +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="custom_error_from_custom_error", + java_behavior=[], # [Not possible] + ) +) + # application_error_non_retryable_from_custom_error: _ = ["NexusOperationError", "HandlerError"] @@ -1302,7 +1332,7 @@ class ErrorConversionTestCase: error_conversion_test_cases.append( ErrorConversionTestCase( - name="application_error_non_retryable_from_custom_error", + name="nexus_handler_error_not_found", java_behavior=[ (NexusOperationError, {}), ( @@ -1331,7 +1361,7 @@ class ErrorConversionTestCase: # [Not possible] error_conversion_test_cases.append( ErrorConversionTestCase( - name="nexus_handler_error_not_found", + name="nexus_handler_error_not_found_from_custom_error", java_behavior=[], # [Not possible] ) ) @@ -1371,17 +1401,6 @@ class ErrorConversionTestCase: ) -ActionInSyncOp = Literal[ - "application_error_non_retryable", - "custom_error", - "custom_error_from_custom_error", - "application_error_non_retryable_from_custom_error", - "nexus_handler_error_not_found", - "nexus_handler_error_not_found_from_custom_error", - "nexus_operation_error_from_application_error_non_retryable_from_custom_error", -] - - class CustomError(Exception): pass @@ -1453,9 +1472,10 @@ def __init__(self, input: ErrorTestInput): service=ErrorTestService, endpoint=make_nexus_endpoint_name(input.task_queue), ) + self.test_cases = {t.name: t for t in error_conversion_test_cases} @workflow.run - async def run(self, input: ErrorTestInput) -> list[str]: + async def run(self, input: ErrorTestInput) -> None: try: await self.nexus_client.execute_operation( # TODO(nexus-preview): why wasn't this a type error? @@ -1470,7 +1490,26 @@ async def run(self, input: ErrorTestInput) -> list[str]: while err.__cause__: errs.append(err.__cause__) err = err.__cause__ - return [type(err).__name__ for err in errs] + actual = [ErrorConversionTestCase.parse_exception(err) for err in errs] + results = list( + zip_longest( + self.test_cases[input.action_in_sync_op].java_behavior, + actual, + fillvalue=None, + ) + ) + print(f""" + +{input.action_in_sync_op} +{'-' * 80} +""") + for java_behavior, actual in results: + print(f"Java: {java_behavior}") + print(f"Python: {actual}") + print() + print("-" * 80) + return None + assert False, "Unreachable" @@ -1497,7 +1536,7 @@ async def test_errors_raised_by_nexus_operation( task_queue=task_queue, ): await create_nexus_endpoint(task_queue, client) - result = await client.execute_workflow( + await client.execute_workflow( ErrorTestCallerWorkflow.run, ErrorTestInput( task_queue=task_queue, @@ -1506,5 +1545,3 @@ async def test_errors_raised_by_nexus_operation( id=str(uuid.uuid4()), task_queue=task_queue, ) - - print(f"\n\n\n{action_in_sync_op}: \n", result, "\n\n\n") From e4141da6b84bd576f61e661a431af35353162f0b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 1 Jul 2025 21:52:04 -0400 Subject: [PATCH 143/183] Corrected Java output --- tests/nexus/test_workflow_caller.py | 36 ++++++++++++----------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 8bcf66b6a..96378f776 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1131,39 +1131,33 @@ async def assert_handler_workflow_has_link_to_caller_workflow( # } # 🌈 RAISE_APPLICATION_ERROR: -# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) -# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) -# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.nexusrpc.handler.HandlerException(message="handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", type="INTERNAL", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="application error 1", type="my-application-error-type", nonRetryable=true) # 🌈 RAISE_CUSTOM_ERROR: -# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) -# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='Custom error wrapped: custom error 1', type='CUSTOM_ERROR', nonRetryable=true) -# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error wrapped: custom error 1', type='CUSTOM_ERROR', nonRetryable=true) +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.temporal.failure.TimeoutFailure(message="message='operation timed out', timeoutType=TIMEOUT_TYPE_SCHEDULE_TO_CLOSE") # 🌈 RAISE_APPLICATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: -# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) -# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) -# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) -# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false) +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.nexusrpc.handler.HandlerException(message="handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", type="INTERNAL", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="application error 1", type="my-application-error-type", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="Custom error 2", type="io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", nonRetryable=false) # 🌈 RAISE_NEXUS_HANDLER_ERROR: -# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) -# io.nexusrpc.handler.HandlerException(type=no-type-attr, message=handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false) -# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false) - - -# 🌈 RAISE_NEXUS_HANDLER_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: -# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) -# io.temporal.failure.TimeoutFailure(type=no-type-attr, message=message='operation timed out', timeoutType=TIMEOUT_TYPE_SCHEDULE_TO_CLOSE) +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.nexusrpc.handler.HandlerException(message="handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false", type="NOT_FOUND", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="Handler error 1", type="java.lang.RuntimeException", nonRetryable=false) # 🌈 RAISE_NEXUS_OPERATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: -# io.temporal.failure.NexusOperationFailure(type=no-type-attr, message=Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=) -# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='application error 1', type='APPLICATION_ERROR', nonRetryable=true) -# io.temporal.failure.ApplicationFailure(type=no-type-attr, message=message='Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false) +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.temporal.failure.ApplicationFailure(message="application error 1", type="my-application-error-type", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="Custom error 2", type="io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", nonRetryable=false) ActionInSyncOp = Literal[ "application_error_non_retryable", From b4fcd07e72fda7d19eac879dd1674ff0931dc7b0 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 1 Jul 2025 22:21:34 -0400 Subject: [PATCH 144/183] Update test assertions --- tests/nexus/test_workflow_caller.py | 92 +++++++++-------------------- 1 file changed, 29 insertions(+), 63 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 96378f776..2a6c58efb 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1196,8 +1196,8 @@ def parse_exception( # Java _ = [ "NexusOperationError", - "HandlerError('handler error: application error', type='APPLICATION_ERROR', non_retryable=True)", - "ApplicationError('application error', type='APPLICATION_ERROR', non_retryable=True)", + "HandlerError('handler error: message='application error 1', type='my-application-error-type', nonRetryable=true', type='INTERNAL', nonRetryable=true)", + "ApplicationError('application error 1', type='my-application-error-type', nonRetryable=true)", ] error_conversion_test_cases.append( @@ -1208,50 +1208,16 @@ def parse_exception( ( nexusrpc.HandlerError, { - "message": "application error", - "type": "APPLICATION_ERROR", + "message": "handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", + "type": "INTERNAL", "non_retryable": True, }, ), ( ApplicationError, { - "message": "application error", - "type": "APPLICATION_ERROR", - "non_retryable": True, - }, - ), - ], - ) -) - - -# custom_error: -_ = ["NexusOperationError", "HandlerError"] -# Java -_ = [ - "NexusOperationError", - "HandlerError('Custom error wrapped: custom error', type='CUSTOM_ERROR', non_retryable=True)", - "ApplicationError('Custom error wrapped: custom error', type='CUSTOM_ERROR', non_retryable=True)", -] -error_conversion_test_cases.append( - ErrorConversionTestCase( - name="custom_error", - java_behavior=[ - (NexusOperationError, {}), - ( - nexusrpc.HandlerError, - { - "message": "Custom error wrapped: custom error", - "type": "CUSTOM_ERROR", - "non_retryable": True, - }, - ), - ( - ApplicationError, - { - "message": "Custom error wrapped: custom error", - "type": "CUSTOM_ERROR", + "message": "application error 1", + "type": "my-application-error-type", "non_retryable": True, }, ), @@ -1277,9 +1243,9 @@ def parse_exception( # Java _ = [ "NexusOperationError", - "HandlerError('handler error: application error', type='APPLICATION_ERROR', non_retryable=True)", - "ApplicationError('application error', type='APPLICATION_ERROR', non_retryable=True)", - "ApplicationError('custom error', type='MyCustomException', non_retryable=False)", + "HandlerError('handler error: message='application error 1', type='my-application-error-type', nonRetryable=true', type='INTERNAL', nonRetryable=true)", + "ApplicationError('application error 1', type='my-application-error-type', nonRetryable=true)", + "ApplicationError('Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false)", ] error_conversion_test_cases.append( @@ -1290,24 +1256,24 @@ def parse_exception( ( nexusrpc.HandlerError, { - "message": "handler error: application error", - "type": "APPLICATION_ERROR", + "message": "handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", + "type": "INTERNAL", "non_retryable": True, }, ), ( ApplicationError, { - "message": "application error", - "type": "APPLICATION_ERROR", + "message": "application error 1", + "type": "my-application-error-type", "non_retryable": True, }, ), ( ApplicationError, { - "message": "custom error", - "type": "MyCustomException", + "message": "Custom error 2", + "type": "io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", "non_retryable": False, }, ), @@ -1320,8 +1286,8 @@ def parse_exception( # Java _ = [ "NexusOperationError", - "HandlerError('handler error: handler error', type='RuntimeException', non_retryable=False)", - "ApplicationError('handler error', type='RuntimeException', non_retryable=False)", + "HandlerError('handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false', type='NOT_FOUND', nonRetryable=true)", + "ApplicationError('Handler error 1', type='java.lang.RuntimeException', nonRetryable=false)", ] error_conversion_test_cases.append( @@ -1332,16 +1298,16 @@ def parse_exception( ( nexusrpc.HandlerError, { - "message": "handler error: handler error", - "type": "RuntimeException", - "non_retryable": False, + "message": "handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false", + "type": "NOT_FOUND", + "non_retryable": True, }, ), ( ApplicationError, { - "message": "handler error", - "type": "RuntimeException", + "message": "Handler error 1", + "type": "java.lang.RuntimeException", "non_retryable": False, }, ), @@ -1362,12 +1328,12 @@ def parse_exception( # nexus_operation_error_from_application_error_non_retryable_from_custom_error: -_ = ["NexusOperationError", "ApplicationError", "ApplicationError", "ApplicationError"] +_ = ["NexusOperationError", "ApplicationError", "ApplicationError"] # Java _ = [ "NexusOperationError", - "ApplicationError('application error', type='APPLICATION_ERROR', non_retryable=True)", - "ApplicationError('custom error', type='MyCustomException', non_retryable=False)", + "ApplicationError('application error 1', type='my-application-error-type', nonRetryable=true)", + "ApplicationError('Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false)", ] error_conversion_test_cases.append( ErrorConversionTestCase( @@ -1377,16 +1343,16 @@ def parse_exception( ( ApplicationError, { - "message": "application error", - "type": "APPLICATION_ERROR", + "message": "application error 1", + "type": "my-application-error-type", "non_retryable": True, }, ), ( ApplicationError, { - "message": "custom error", - "type": "MyCustomException", + "message": "Custom error 2", + "type": "io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", "non_retryable": False, }, ), From b51063b6cbfee1862c6f9be4cd5fde79dff08652 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 1 Jul 2025 23:31:59 -0400 Subject: [PATCH 145/183] Add timeout test --- tests/nexus/test_workflow_caller.py | 57 ++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 2a6c58efb..410ccf321 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1,6 +1,7 @@ import asyncio import uuid from dataclasses import dataclass +from datetime import timedelta from enum import IntEnum from itertools import zip_longest from typing import Any, Callable, Literal, Union @@ -39,7 +40,12 @@ WorkflowHandle, ) from temporalio.common import WorkflowIDConflictPolicy -from temporalio.exceptions import ApplicationError, CancelledError, NexusOperationError +from temporalio.exceptions import ( + ApplicationError, + CancelledError, + NexusOperationError, + TimeoutError, +) from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker @@ -1505,3 +1511,52 @@ async def test_errors_raised_by_nexus_operation( id=str(uuid.uuid4()), task_queue=task_queue, ) + + +# Timeout test +@service_handler +class TimeoutTestService: + @sync_operation + async def op_handler_that_never_returns( + self, ctx: StartOperationContext, input: None + ) -> None: + await asyncio.Future() + + +@workflow.defn +class TimeoutTestCallerWorkflow: + @workflow.init + def __init__(self): + self.nexus_client = workflow.NexusClient( + service=TimeoutTestService, + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + ) + + @workflow.run + async def run(self) -> None: + await self.nexus_client.execute_operation( + TimeoutTestService.op_handler_that_never_returns, + None, + schedule_to_close_timeout=timedelta(seconds=0.1), + ) + + +async def test_timeout_error_raised_by_nexus_operation(client: Client): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[TimeoutTestService()], + workflows=[TimeoutTestCallerWorkflow], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + try: + await client.execute_workflow( + TimeoutTestCallerWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) From 7c5f1076c3fb04f442e1694050fcca90f212dc0c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 12:59:10 -0400 Subject: [PATCH 146/183] Install the Nexus SDK from GitHub --- pyproject.toml | 2 +- uv.lock | 22 +++------------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1faed3fd3..072ce19c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,4 +218,4 @@ exclude = [ package = false [tool.uv.sources] -nexus-rpc = { path = "../nexus-sdk-python", editable = true } +nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python" } diff --git a/uv.lock b/uv.lock index 68471a43f..08dd46baf 100644 --- a/uv.lock +++ b/uv.lock @@ -1043,28 +1043,12 @@ wheels = [ [[package]] name = "nexus-rpc" -version = "0.1.0" -source = { editable = "../nexus-sdk-python" } +version = "1.1.0" +source = { git = "https://github.com/nexus-rpc/sdk-python#94a1267cb5baabf2d3609aedb7f6cf81587be6df" } dependencies = [ { name = "typing-extensions" }, ] -[package.metadata] -requires-dist = [{ name = "typing-extensions", specifier = ">=4.12.2" }] - -[package.metadata.requires-dev] -dev = [ - { name = "mypy", specifier = ">=1.15.0" }, - { name = "poethepoet", specifier = ">=0.35.0" }, - { name = "pydoctor", specifier = ">=25.4.0" }, - { name = "pyright", specifier = ">=1.1.402" }, - { name = "pytest", specifier = ">=8.3.5" }, - { name = "pytest-asyncio", specifier = ">=0.26.0" }, - { name = "pytest-cov", specifier = ">=6.1.1" }, - { name = "pytest-pretty", specifier = ">=1.3.0" }, - { name = "ruff", specifier = ">=0.12.0" }, -] - [[package]] name = "nh3" version = "0.2.21" @@ -1773,7 +1757,7 @@ dev = [ requires-dist = [ { name = "eval-type-backport", marker = "python_full_version < '3.10' and extra == 'openai-agents'", specifier = ">=0.2.2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, - { name = "nexus-rpc", editable = "../nexus-sdk-python" }, + { name = "nexus-rpc", git = "https://github.com/nexus-rpc/sdk-python" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.0.19,<0.1" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, From 638a121449797558803e5e3c24fafd877a801bab Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 13:09:41 -0400 Subject: [PATCH 147/183] Update error tests --- tests/nexus/test_workflow_caller.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 410ccf321..be5af19c1 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1483,7 +1483,6 @@ async def run(self, input: ErrorTestInput) -> None: "action_in_sync_op", [ "application_error_non_retryable", - "custom_error", "custom_error_from_custom_error", "application_error_non_retryable_from_custom_error", "nexus_handler_error_not_found", From c6dbb523cbd038999d2fe71c624cbf6c97f2c674 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 17:54:25 -0400 Subject: [PATCH 148/183] Edit TODOs --- temporalio/nexus/_util.py | 3 -- temporalio/worker/_workflow_instance.py | 11 ++++--- temporalio/workflow.py | 2 +- tests/nexus/test_handler.py | 3 -- tests/nexus/test_workflow_caller.py | 38 ++++++++++++------------- 5 files changed, 25 insertions(+), 32 deletions(-) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 3c2cc9fe4..c0a1b8464 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -85,7 +85,6 @@ def _get_start_method_input_and_output_type_annotations( try: type_annotations = typing.get_type_hints(start) except TypeError: - # TODO(preview): stacklevel warnings.warn( f"Expected decorated start method {start} to have type annotations" ) @@ -93,7 +92,6 @@ def _get_start_method_input_and_output_type_annotations( output_type = type_annotations.pop("return", None) if len(type_annotations) != 2: - # TODO(preview): stacklevel suffix = f": {type_annotations}" if type_annotations else "" warnings.warn( f"Expected decorated start method {start} to have exactly 2 " @@ -104,7 +102,6 @@ def _get_start_method_input_and_output_type_annotations( else: ctx_type, input_type = type_annotations.values() if not issubclass(ctx_type, WorkflowRunOperationContext): - # TODO(preview): stacklevel warnings.warn( f"Expected first parameter of {start} to be an instance of " f"WorkflowRunOperationContext, but is {ctx_type}." diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index db24de6b7..dd8123a02 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -869,7 +869,7 @@ def _apply_resolve_nexus_operation_start( # Note that core will still send a `ResolveNexusOperation` job in the same # activation, so there does not need to be an exceptional case for this in # lang. - # TODO(dan): confirm appropriate to take no action here + # TODO(nexus-prerelease): confirm appropriate to take no action here pass else: raise ValueError(f"Unknown Nexus operation start status: {job}") @@ -893,7 +893,6 @@ def _apply_resolve_nexus_operation( ) handle._resolve_success(output) elif result.HasField("failed"): - # TODO(dan): test failure converter handle._resolve_failure( self._failure_converter.from_failure( result.failed, self._payload_converter @@ -2989,9 +2988,9 @@ async def cancel(self) -> None: await self._instance._cancel_external_workflow(command) -# TODO(dan): are we sure we don't want to inherit from asyncio.Task as ActivityHandle and -# ChildWorkflowHandle do? I worry that we should provide .done(), .result(), .exception() -# etc for consistency. +# TODO(nexus-preview): are we sure we don't want to inherit from asyncio.Task as +# ActivityHandle and ChildWorkflowHandle do? I worry that we should provide .done(), +# .result(), .exception() etc for consistency. class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[OutputT]): def __init__( self, @@ -3009,7 +3008,7 @@ def __init__( @property def operation_token(self) -> Optional[str]: - # TODO(dan): How should this behave? + # TODO(nexus-prerelease): How should this behave? # Java has a separate class that only exists if the operation token exists: # https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26 # And Go similar: diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 4df127665..305b9c049 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1983,7 +1983,7 @@ class _AsyncioTask(asyncio.Task[AnyType]): pass else: - # TODO(dan): inherited classes should be other way around? + # TODO: inherited classes should be other way around? class _AsyncioTask(Generic[AnyType], asyncio.Task): pass diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 5c7e25fa4..0c69b4bf1 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -130,9 +130,6 @@ async def run(self, input: Input) -> Output: return Output(value=f"from link test workflow: {input.value}") -# TODO: implement some of these ops as explicit OperationHandler classes to provide coverage for that? - - # The service_handler decorator is applied by the test class MyServiceHandler: @sync_operation diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index be5af19c1..f6b948cd6 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -51,9 +51,9 @@ from temporalio.worker import Worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name -# TODO(dan): test availability of Temporal client etc in async context set by worker -# TODO(dan): test worker shutdown, wait_all_completed, drain etc -# TODO(dan): test worker op handling failure +# TODO(nexus-prerelease): test availability of Temporal client etc in async context set by worker +# TODO(nexus-prerelease): test worker shutdown, wait_all_completed, drain etc +# TODO(nexus-prerelease): test worker op handling failure # ----------------------------------------------------------------------------- # Test definition @@ -145,7 +145,7 @@ async def run( ) -# TODO: make types pass pyright strict mode +# TODO(nexus-prerelease): check type-checking passing in CI class SyncOrAsyncOperation(OperationHandler[OpInput, OpOutput]): @@ -156,7 +156,7 @@ async def start( StartOperationResultAsync, ]: if input.response_type.exception_in_operation_start: - # TODO(dan): don't think RPCError should be used here + # TODO(nexus-prerelease): don't think RPCError should be used here raise RPCError( "RPCError INVALID_ARGUMENT in Nexus operation", RPCStatusCode.INVALID_ARGUMENT, @@ -381,7 +381,7 @@ class UntypedCallerWorkflow: def __init__( self, input: CallerWfInput, request_cancel: bool, task_queue: str ) -> None: - # TODO(dan): untyped caller cannot reference name of implementation. I think this is as it should be. + # TODO(nexus-prerelease): untyped caller cannot reference name of implementation. I think this is as it should be. service_name = "ServiceInterface" self.nexus_client = workflow.NexusClient( service=service_name, @@ -427,9 +427,9 @@ async def run( # -# TODO(dan): cross-namespace tests -# TODO(dan): nexus endpoint pytest fixture? -# TODO(dan): test headers +# TODO(nexus-prerelease): cross-namespace tests +# TODO(nexus-prerelease): nexus endpoint pytest fixture? +# TODO(nexus-prerelease): test headers @pytest.mark.parametrize("exception_in_operation_start", [False, True]) @pytest.mark.parametrize("request_cancel", [False, True]) @pytest.mark.parametrize( @@ -476,7 +476,7 @@ async def test_sync_response( task_queue=task_queue, ) - # TODO(dan): check bidi links for sync operation + # TODO(nexus-prerelease): check bidi links for sync operation # The operation result is returned even when request_cancel=True, because the # response was synchronous and it could not be cancelled. See explanation below. @@ -551,7 +551,7 @@ async def test_async_response( ) return - # TODO(dan): race here? How do we know it hasn't been canceled already? + # TODO(nexus-prerelease): race here? How do we know it hasn't been canceled already? handler_wf_info = await handler_wf_handle.describe() assert handler_wf_info.status in [ WorkflowExecutionStatus.RUNNING, @@ -736,8 +736,8 @@ class ServiceClassNameOutput: name: str -# TODO(dan): test interface op types not matching -# TODO(dan): async and non-async cancel methods +# TODO(nexus-prerelease): test interface op types not matching +# TODO(nexus-prerelease): async and non-async cancel methods @nexusrpc.service @@ -822,12 +822,12 @@ async def run( endpoint=make_nexus_endpoint_name(task_queue), ) - # TODO(dan): maybe not surprising that this doesn't type check given complexity of + # TODO(nexus-prerelease): maybe not surprising that this doesn't type check given complexity of # the union? return await nexus_client.execute_operation(service_cls.op, None) # type: ignore -# TODO(dan): check missing decorator behavior +# TODO(nexus-prerelease): check missing decorator behavior async def test_service_interface_and_implementation_names(client: Client): @@ -979,8 +979,8 @@ async def test_workflow_run_operation_can_execute_workflow_before_starting_backi assert result == "result-1-result-2" -# TODO(dan): test invalid service interface implementations -# TODO(dan): test caller passing output_type +# TODO(nexus-prerelease): test invalid service interface implementations +# TODO(nexus-prerelease): test caller passing output_type async def assert_caller_workflow_has_link_to_handler_workflow( @@ -1444,10 +1444,10 @@ def __init__(self, input: ErrorTestInput): async def run(self, input: ErrorTestInput) -> None: try: await self.nexus_client.execute_operation( - # TODO(nexus-preview): why wasn't this a type error? + # TODO(nexus-prerelease): why wasn't this a type error? # ErrorTestService.op, ErrorTestCallerWfInput() ErrorTestService.op, - # TODO(nexus-preview): why wasn't this a type error? + # TODO(nexus-prerelease): why wasn't this a type error? # None input, ) From 0fc88c0d8545914619e4445d3925a530dcfd470b Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 19:10:21 -0400 Subject: [PATCH 149/183] Make a pass through prerelease TODOs --- temporalio/nexus/__init__.py | 5 ----- temporalio/nexus/_operation_context.py | 2 +- temporalio/types.py | 2 +- temporalio/worker/_activity.py | 5 +++-- temporalio/worker/_interceptor.py | 1 - temporalio/worker/_nexus.py | 2 +- temporalio/worker/_worker.py | 2 +- temporalio/worker/_workflow_instance.py | 4 ++-- temporalio/workflow.py | 11 +---------- tests/conftest.py | 2 +- ...ynamic_creation_of_user_handler_classes.py | 7 ------- tests/nexus/test_handler.py | 19 +++++-------------- tests/nexus/test_workflow_caller.py | 13 ++++--------- tests/nexus/test_workflow_run_operation.py | 3 --- 14 files changed, 20 insertions(+), 58 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 5573df4a6..dd9935b05 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -14,8 +14,3 @@ from ._operation_context import logger as logger from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowHandle as WorkflowHandle - - -# TODO(nexus-prerelease) WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' -# 2025-06-25T12:58:05.749589Z WARN temporal_sdk_core::worker::nexus: Failed to parse nexus timeout header value '9.155416ms' -# 2025-06-25T12:58:05.763052Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 79afd5a91..3f0df0a6d 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -193,7 +193,7 @@ def from_start_operation_context( ) # Overload for single-param workflow - # TODO(nexus-prerelease): bring over other overloads + # TODO(nexus-prerelease)*: bring over other overloads async def start_workflow( self, workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], diff --git a/temporalio/types.py b/temporalio/types.py index f29d42e1e..01b3c9d58 100644 --- a/temporalio/types.py +++ b/temporalio/types.py @@ -81,7 +81,7 @@ class MethodAsyncSingleParam( ): """Generic callable type.""" - # TODO(nexus-prerelease): review changes to signatures in this file + # TODO(nexus-prerelease)*: review changes to signatures in this file def __call__( self, __self: ProtocolSelfType, __arg: ProtocolParamType ) -> Awaitable[ProtocolReturnType]: diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index c9f71834c..9bc373022 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -201,8 +201,9 @@ async def drain_poll_queue(self) -> None: # Only call this after run()/drain_poll_queue() have returned. This will not # raise an exception. - # TODO(nexus-prerelease): based on the comment above it looks like the intention may have been to use - # return_exceptions=True + # TODO(nexus-preview): based on the comment above it looks like the intention may have been to use + # return_exceptions=True. Change this for nexus and activity and change call sites to consume entire + # stream and then raise first exception async def wait_all_completed(self) -> None: running_tasks = [v.task for v in self._running_activities.values() if v.task] if running_tasks: diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 8499a5136..667848f16 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -304,7 +304,6 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): _operation_name: str = field(init=False, repr=False) _input_type: Optional[Type[InputT]] = field(init=False, repr=False) - # TODO(nexus-prerelease): update this logic to handle service impl start methods def __post_init__(self) -> None: if isinstance(self.operation, nexusrpc.Operation): self._operation_name = self.operation.name diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 72e3ef2bb..67e3a3f81 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -191,7 +191,7 @@ async def _handle_cancel_operation_task( ), ) else: - # TODO(nexus-prerelease): when do we use ack_cancel? + # TODO(nexus-preview): ack_cancel completions? completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( task_token=task_token, completed=temporalio.api.nexus.v1.Response( diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 80b70a055..188d80080 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -311,7 +311,7 @@ def __init__( nexus_task_poller_behavior: Specify the behavior of Nexus task polling. Defaults to a 5-poller maximum. """ - # TODO(nexus-prerelease): max_concurrent_nexus_tasks / tuner support + # TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support if not (activities or nexus_service_handlers or workflows): raise ValueError( "At least one activity, Nexus service, or workflow must be specified" diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index dd8123a02..78b53c589 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -869,7 +869,7 @@ def _apply_resolve_nexus_operation_start( # Note that core will still send a `ResolveNexusOperation` job in the same # activation, so there does not need to be an exceptional case for this in # lang. - # TODO(nexus-prerelease): confirm appropriate to take no action here + # TODO(nexus-preview): confirm appropriate to take no action here pass else: raise ValueError(f"Unknown Nexus operation start status: {job}") @@ -3008,7 +3008,7 @@ def __init__( @property def operation_token(self) -> Optional[str]: - # TODO(nexus-prerelease): How should this behave? + # TODO(nexus-preview): How should this behave? # Java has a separate class that only exists if the operation token exists: # https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26 # And Go similar: diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 305b9c049..c51cfea05 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4384,25 +4384,16 @@ async def execute_child_workflow( return await handle -# TODO(nexus-prerelease): ABC / inherit from asyncio.Task? class NexusOperationHandle(Generic[OutputT]): def cancel(self) -> bool: - # TODO(nexus-prerelease): docstring """ - Call task.cancel() on the asyncio task that is backing this handle. - - From asyncio docs: - - Cancel the future and schedule callbacks. - - If the future is already done or cancelled, return False. Otherwise, change the future's state to cancelled, schedule the callbacks and return True. + Request cancellation of the operation. """ raise NotImplementedError def __await__(self) -> Generator[Any, Any, OutputT]: raise NotImplementedError - # TODO(nexus-prerelease): check SDK-wide consistency for @property vs nullary accessor methods. @property def operation_token(self) -> Optional[str]: raise NotImplementedError diff --git a/tests/conftest.py b/tests/conftest.py index 48df7285e..7d9f0157d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,13 +123,13 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) + # TODO(nexus-preview): expose this in a more principled way env._http_port = http_port # type: ignore elif env_type == "time-skipping": env = await WorkflowEnvironment.start_time_skipping() else: env = WorkflowEnvironment.from_client(await Client.connect(env_type)) - # TODO(nexus-prerelease): expose this in a principled way yield env await env.shutdown() diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index dd0c57017..1ab153b06 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -13,13 +13,6 @@ HTTP_PORT = 7243 -# TODO(nexus-prerelease): test programmatic creation from ServiceHandler -def make_incrementer_service_from_service_handler( - op_names: list[str], -) -> tuple[str, type]: - pass - - def make_incrementer_user_service_definition_and_service_handler_classes( op_names: list[str], ) -> tuple[type, type]: diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 0c69b4bf1..025a442bb 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -80,15 +80,14 @@ class NonSerializableOutput: # TODO(nexus-prelease): Test attaching multiple callers to the same operation. # TODO(nexus-preview): type check nexus implementation under mypy -# TODO(nexus-prerelease): test dynamic creation of a service from unsugared definition -# TODO(nexus-prerelease): test malformed inbound_links and outbound_links -# TODO(nexus-prerelease): test good error message on forgetting to add decorators etc +# TODO(nexus-preview): test malformed inbound_links and outbound_links + +# TODO(nexus-prerelease): 2025-07-02T23:29:20.000489Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } @nexusrpc.service class MyService: echo: nexusrpc.Operation[Input, Output] - # TODO(nexus-prerelease): support renamed operations! echo_renamed: nexusrpc.Operation[Input, Output] = nexusrpc.Operation( name="echo-renamed" ) @@ -140,13 +139,11 @@ async def echo(self, ctx: StartOperationContext, input: Input) -> Output: value=f"from start method on {self.__class__.__name__}: {input.value}" ) - # The name override is prsent in the service definition. But the test below submits + # The name override is present in the service definition. But the test below submits # the same operation name in the request whether using a service definition or now. # The name override here is necessary when the test is not using the service # definition. It should be permitted when the service definition is in effect, as # long as the name override is the same as that in the service definition. - # TODO(nexus-prerelease): implement in nexusrpc the check that operation handler - # name overrides must be consistent with service definition overrides. @sync_operation(name="echo-renamed") async def echo_renamed(self, ctx: StartOperationContext, input: Input) -> Output: return await self.echo(ctx, input) @@ -163,7 +160,7 @@ async def non_retryable_application_error( raise ApplicationError( "non-retryable application error", "details arg", - # TODO(nexus-prerelease): what values of `type` should be tested? + # TODO(nexus-preview): what values of `type` should be tested? type="TestFailureType", non_retryable=True, ) @@ -329,7 +326,6 @@ class UnsuccessfulResponse: failure_details: bool = True # Expected value of inverse of non_retryable attribute of exception. retryable_exception: bool = True - # TODO(nexus-prerelease): the body of a successful response need not be JSON; test non-JSON-parseable string body_json: Optional[Callable[[dict[str, Any]], bool]] = None headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS @@ -478,11 +474,6 @@ def check_response( ), f"nexus-link header {nexus_link} does not start with None: - # TODO(nexus-prerelease): untyped caller cannot reference name of implementation. I think this is as it should be. + # TODO(nexus-preview): untyped caller cannot reference name of implementation. I think this is as it should be. service_name = "ServiceInterface" self.nexus_client = workflow.NexusClient( service=service_name, @@ -427,8 +426,8 @@ async def run( # -# TODO(nexus-prerelease): cross-namespace tests -# TODO(nexus-prerelease): nexus endpoint pytest fixture? +# TODO(nexus-preview): cross-namespace tests +# TODO(nexus-preview): nexus endpoint pytest fixture? # TODO(nexus-prerelease): test headers @pytest.mark.parametrize("exception_in_operation_start", [False, True]) @pytest.mark.parametrize("request_cancel", [False, True]) @@ -736,7 +735,6 @@ class ServiceClassNameOutput: name: str -# TODO(nexus-prerelease): test interface op types not matching # TODO(nexus-prerelease): async and non-async cancel methods @@ -822,8 +820,6 @@ async def run( endpoint=make_nexus_endpoint_name(task_queue), ) - # TODO(nexus-prerelease): maybe not surprising that this doesn't type check given complexity of - # the union? return await nexus_client.execute_operation(service_cls.op, None) # type: ignore @@ -947,7 +943,6 @@ async def run(self, input: str, task_queue: str) -> str: service=ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow, endpoint=make_nexus_endpoint_name(task_queue), ) - # TODO(nexus-prerelease): update StartNexusOperationInput.__post_init__ return await nexus_client.execute_operation( ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow.my_workflow_run_operation, None, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 0a00e32b6..217316412 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -41,9 +41,6 @@ async def run(self, input: str) -> str: return input -# TODO(nexus-prerelease): this test dates from a point at which we were encouraging -# subclassing WorkflowRunOperationHandler as part of the public API. Leaving it in for -# now. class MyOperation(WorkflowRunOperationHandler): def __init__(self): pass From 2ca30ff4077283cbd6ea653a51c8e855283767d0 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 19:24:51 -0400 Subject: [PATCH 150/183] Revert change to callable types --- temporalio/types.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/temporalio/types.py b/temporalio/types.py index 01b3c9d58..331c9596e 100644 --- a/temporalio/types.py +++ b/temporalio/types.py @@ -81,9 +81,8 @@ class MethodAsyncSingleParam( ): """Generic callable type.""" - # TODO(nexus-prerelease)*: review changes to signatures in this file def __call__( - self, __self: ProtocolSelfType, __arg: ProtocolParamType + self, __self: ProtocolSelfType, __arg: ProtocolParamType, / ) -> Awaitable[ProtocolReturnType]: """Generic callable type callback.""" ... @@ -95,7 +94,7 @@ class MethodSyncSingleParam( """Generic callable type.""" def __call__( - self, __self: ProtocolSelfType, __arg: ProtocolParamType + self, __self: ProtocolSelfType, __arg: ProtocolParamType, / ) -> ProtocolReturnType: """Generic callable type callback.""" ... @@ -117,7 +116,7 @@ class MethodSyncOrAsyncSingleParam( """Generic callable type.""" def __call__( - self, __self: ProtocolSelfType, __param: ProtocolParamType + self, __self: ProtocolSelfType, __param: ProtocolParamType, / ) -> Union[ProtocolReturnType, Awaitable[ProtocolReturnType]]: """Generic callable type callback.""" ... From b3146ea46246f8e52b0848dafac50b6c50c025aa Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 19:51:07 -0400 Subject: [PATCH 151/183] Test start_workflow overloads --- tests/nexus/test_workflow_caller.py | 149 +++++++++++++++++++++++++++- 1 file changed, 148 insertions(+), 1 deletion(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 91020a211..4bc5e0919 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -4,7 +4,7 @@ from datetime import timedelta from enum import IntEnum from itertools import zip_longest -from typing import Any, Callable, Literal, Union +from typing import Any, Awaitable, Callable, Literal, Union import nexusrpc import nexusrpc.handler @@ -31,6 +31,7 @@ import temporalio.api.operatorservice import temporalio.api.operatorservice.v1 import temporalio.exceptions +import temporalio.nexus from temporalio import nexus, workflow from temporalio.client import ( Client, @@ -1554,3 +1555,149 @@ async def test_timeout_error_raised_by_nexus_operation(client: Client): assert isinstance(err, WorkflowFailureError) assert isinstance(err.__cause__, NexusOperationError) assert isinstance(err.__cause__.__cause__, TimeoutError) + + +# Test overloads + + +@workflow.defn +class OverloadTestHandlerWorkflow: + @workflow.run + async def run(self, input: int) -> int: + return input * 2 + + +@workflow.defn +class OverloadTestHandlerWorkflowNoParam: + @workflow.run + async def run(self) -> int: + return 0 + + +@nexusrpc.handler.service_handler +class OverloadTestServiceHandler: + @workflow_run_operation + async def no_param( + self, + ctx: WorkflowRunOperationContext, + _: int, + ) -> nexus.WorkflowHandle[int]: + return await ctx.start_workflow( + OverloadTestHandlerWorkflowNoParam.run, + id=str(uuid.uuid4()), + ) + + @workflow_run_operation + async def single_param( + self, ctx: WorkflowRunOperationContext, input: int + ) -> nexus.WorkflowHandle[int]: + return await ctx.start_workflow( + OverloadTestHandlerWorkflow.run, + input, + id=str(uuid.uuid4()), + ) + + @workflow_run_operation + async def multi_param( + self, ctx: WorkflowRunOperationContext, input: int + ) -> nexus.WorkflowHandle[int]: + return await ctx.start_workflow( + OverloadTestHandlerWorkflow.run, + args=[input], + id=str(uuid.uuid4()), + ) + + @workflow_run_operation + async def by_name( + self, ctx: WorkflowRunOperationContext, input: int + ) -> nexus.WorkflowHandle[int]: + return await ctx.start_workflow( + "OverloadTestHandlerWorkflow", + input, + id=str(uuid.uuid4()), + result_type=OverloadTestValue, + ) + + @workflow_run_operation + async def by_name_multi_param( + self, ctx: WorkflowRunOperationContext, input: int + ) -> nexus.WorkflowHandle[int]: + return await ctx.start_workflow( + "OverloadTestHandlerWorkflow", + args=[input], + id=str(uuid.uuid4()), + ) + + +@dataclass +class OverloadTestInput: + op: Callable[ + [Any, WorkflowRunOperationContext, Any], + Awaitable[temporalio.nexus.WorkflowHandle[Any]], + ] + input: Any + output: Any + + +@workflow.defn +class OverloadTestCallerWorkflow: + @workflow.run + async def run(self, op: str, input: int) -> int: + nexus_client = workflow.NexusClient( + service=OverloadTestServiceHandler, + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + ) + if op == "no_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.no_param, input + ) + elif op == "single_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.single_param, input + ) + elif op == "multi_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.multi_param, input + ) + elif op == "by_name": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.by_name, input + ) + elif op == "by_name_multi_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.by_name_multi_param, input + ) + else: + raise ValueError(f"Unknown op: {op}") + + +@pytest.mark.parametrize( + "op", + [ + "no_param", + "single_param", + "multi_param", + "by_name", + "by_name_multi_param", + ], +) +async def test_workflow_run_operation_overloads(client: Client, op: str): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + task_queue=task_queue, + workflows=[ + OverloadTestCallerWorkflow, + OverloadTestHandlerWorkflow, + OverloadTestHandlerWorkflowNoParam, + ], + nexus_service_handlers=[OverloadTestServiceHandler()], + ): + await create_nexus_endpoint(task_queue, client) + res = await client.execute_workflow( + OverloadTestCallerWorkflow.run, + args=[op, 2], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert res == (4 if op != "no_param" else 0) From ac3c96e2d648c96df093b9005397889ddb27c6ea Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 20:15:24 -0400 Subject: [PATCH 152/183] Add additional overloads --- temporalio/nexus/_operation_context.py | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 3f0df0a6d..74bf2a691 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -9,12 +9,15 @@ from datetime import timedelta from typing import ( Any, + Awaitable, Callable, Mapping, MutableMapping, Optional, Sequence, + Type, Union, + overload, ) import nexusrpc.handler @@ -194,6 +197,7 @@ def from_start_operation_context( # Overload for single-param workflow # TODO(nexus-prerelease)*: bring over other overloads + @overload async def start_workflow( self, workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], @@ -225,6 +229,41 @@ async def start_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: ... + + async def start_workflow( + self, + workflow: Union[str, Callable[..., Awaitable[ReturnType]]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: Optional[str] = None, + result_type: Optional[Type] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, ) -> WorkflowHandle[ReturnType]: """Start a workflow that will deliver the result of the Nexus operation. @@ -266,8 +305,10 @@ async def start_workflow( wf_handle = await self.temporal_context.client.start_workflow( # type: ignore workflow=workflow, arg=arg, + args=args, id=id, task_queue=task_queue or self.temporal_context.info().task_queue, + result_type=result_type, execution_timeout=execution_timeout, run_timeout=run_timeout, task_timeout=task_timeout, From de2aa482d8ce042a101043937b9a2129f5aa611d Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 21:06:57 -0400 Subject: [PATCH 153/183] string name workflow --- temporalio/nexus/_operation_context.py | 37 ++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 74bf2a691..9b5880429 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -231,6 +231,43 @@ async def start_workflow( versioning_override: Optional[temporalio.common.VersioningOverride] = None, ) -> WorkflowHandle[ReturnType]: ... + # Overload for string-name workflow + @overload + async def start_workflow( + self, + workflow: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: Optional[str] = None, + result_type: Optional[Type] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[Any]: ... + async def start_workflow( self, workflow: Union[str, Callable[..., Awaitable[ReturnType]]], From 915fde145eeb5a9e240a47a0dc1062dad941af1c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 21:08:44 -0400 Subject: [PATCH 154/183] Use a dataclass --- tests/nexus/test_workflow_caller.py | 43 +++++++++++++++++------------ 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 4bc5e0919..8e9998e5e 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1560,18 +1560,23 @@ async def test_timeout_error_raised_by_nexus_operation(client: Client): # Test overloads +@dataclass +class OverloadTestValue: + value: int + + @workflow.defn class OverloadTestHandlerWorkflow: @workflow.run - async def run(self, input: int) -> int: - return input * 2 + async def run(self, input: OverloadTestValue) -> OverloadTestValue: + return OverloadTestValue(value=input.value * 2) @workflow.defn class OverloadTestHandlerWorkflowNoParam: @workflow.run - async def run(self) -> int: - return 0 + async def run(self) -> OverloadTestValue: + return OverloadTestValue(value=0) @nexusrpc.handler.service_handler @@ -1580,8 +1585,8 @@ class OverloadTestServiceHandler: async def no_param( self, ctx: WorkflowRunOperationContext, - _: int, - ) -> nexus.WorkflowHandle[int]: + _: OverloadTestValue, + ) -> nexus.WorkflowHandle[OverloadTestValue]: return await ctx.start_workflow( OverloadTestHandlerWorkflowNoParam.run, id=str(uuid.uuid4()), @@ -1589,8 +1594,8 @@ async def no_param( @workflow_run_operation async def single_param( - self, ctx: WorkflowRunOperationContext, input: int - ) -> nexus.WorkflowHandle[int]: + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: return await ctx.start_workflow( OverloadTestHandlerWorkflow.run, input, @@ -1599,8 +1604,8 @@ async def single_param( @workflow_run_operation async def multi_param( - self, ctx: WorkflowRunOperationContext, input: int - ) -> nexus.WorkflowHandle[int]: + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: return await ctx.start_workflow( OverloadTestHandlerWorkflow.run, args=[input], @@ -1609,8 +1614,8 @@ async def multi_param( @workflow_run_operation async def by_name( - self, ctx: WorkflowRunOperationContext, input: int - ) -> nexus.WorkflowHandle[int]: + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: return await ctx.start_workflow( "OverloadTestHandlerWorkflow", input, @@ -1620,8 +1625,8 @@ async def by_name( @workflow_run_operation async def by_name_multi_param( - self, ctx: WorkflowRunOperationContext, input: int - ) -> nexus.WorkflowHandle[int]: + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: return await ctx.start_workflow( "OverloadTestHandlerWorkflow", args=[input], @@ -1642,7 +1647,7 @@ class OverloadTestInput: @workflow.defn class OverloadTestCallerWorkflow: @workflow.run - async def run(self, op: str, input: int) -> int: + async def run(self, op: str, input: OverloadTestValue) -> OverloadTestValue: nexus_client = workflow.NexusClient( service=OverloadTestServiceHandler, endpoint=make_nexus_endpoint_name(workflow.info().task_queue), @@ -1696,8 +1701,12 @@ async def test_workflow_run_operation_overloads(client: Client, op: str): await create_nexus_endpoint(task_queue, client) res = await client.execute_workflow( OverloadTestCallerWorkflow.run, - args=[op, 2], + args=[op, OverloadTestValue(value=2)], id=str(uuid.uuid4()), task_queue=task_queue, ) - assert res == (4 if op != "no_param" else 0) + assert res == ( + OverloadTestValue(value=4) + if op != "no_param" + else OverloadTestValue(value=0) + ) From 4462e3a4063fa52884915c57e5b1260c0841bf77 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 21:14:44 -0400 Subject: [PATCH 155/183] More overloads --- temporalio/nexus/_operation_context.py | 79 +++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 9b5880429..6179bc910 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -22,6 +22,7 @@ import nexusrpc.handler from nexusrpc.handler import CancelOperationContext, StartOperationContext +from typing_extensions import Concatenate import temporalio.api.common.v1 import temporalio.api.enums.v1 @@ -29,7 +30,9 @@ import temporalio.common from temporalio.nexus._token import WorkflowHandle from temporalio.types import ( + MethodAsyncNoParam, MethodAsyncSingleParam, + MultiParamSpec, ParamType, ReturnType, SelfType, @@ -195,8 +198,41 @@ def from_start_operation_context( **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, ) + # Overload for no-param workflow + @overload + async def start_workflow( + self, + workflow: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: ... + # Overload for single-param workflow - # TODO(nexus-prerelease)*: bring over other overloads @overload async def start_workflow( self, @@ -231,6 +267,43 @@ async def start_workflow( versioning_override: Optional[temporalio.common.VersioningOverride] = None, ) -> WorkflowHandle[ReturnType]: ... + # Overload for multi-param workflow + @overload + async def start_workflow( + self, + workflow: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: ... + # Overload for string-name workflow @overload async def start_workflow( @@ -241,7 +314,7 @@ async def start_workflow( args: Sequence[Any] = [], id: str, task_queue: Optional[str] = None, - result_type: Optional[Type] = None, + result_type: Optional[Type[ReturnType]] = None, execution_timeout: Optional[timedelta] = None, run_timeout: Optional[timedelta] = None, task_timeout: Optional[timedelta] = None, @@ -266,7 +339,7 @@ async def start_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, - ) -> WorkflowHandle[Any]: ... + ) -> WorkflowHandle[ReturnType]: ... async def start_workflow( self, From e028a713b02a13d764c7f2d4d30fa6591634f887 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 30 Jun 2025 23:26:29 -0400 Subject: [PATCH 156/183] Initial version of nexus_operation_as_tool --- .../contrib/openai_agents/temporal_tools.py | 103 +++++++++++++++++- tests/contrib/openai_agents/test_openai.py | 64 ++++++++++- 2 files changed, 160 insertions(+), 7 deletions(-) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py index e2ba8ed39..7083e0bf2 100644 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ b/temporalio/contrib/openai_agents/temporal_tools.py @@ -1,12 +1,14 @@ """Support for using Temporal activities as OpenAI agents tools.""" import json +import typing from datetime import timedelta -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Type from temporalio import activity, workflow from temporalio.common import Priority, RetryPolicy from temporalio.exceptions import ApplicationError, TemporalError +from temporalio.nexus._util import get_operation_factory from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe with unsafe.imports_passed_through(): @@ -115,3 +117,102 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: on_invoke_tool=run_activity, strict_json_schema=True, ) + + +def nexus_operation_as_tool( + fn: Callable, + *, + service: Type[Any], + endpoint: str, + schedule_to_close_timeout: Optional[timedelta] = None, +) -> Tool: + """Convert a Nexus operation into an OpenAI agent tool. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + This function takes a Nexus operation and converts it into an + OpenAI agent tool that can be used by the agent to execute the operation + during workflow execution. The tool will automatically handle the conversion + of inputs and outputs between the agent and the operation. + + Args: + fn: A Nexus operation to convert into a tool. + service: The Nexus service class that contains the operation. + endpoint: The Nexus endpoint to use for the operation. + + Returns: + An OpenAI agent tool that wraps the provided operation. + + Raises: + ApplicationError: If the operation is not properly decorated as a Nexus operation. + + Example: + >>> @service_handler + >>> class WeatherServiceHandler: + ... @sync_operation + ... async def get_weather_object(self, ctx: StartOperationContext, input: WeatherInput) -> Weather: + ... return Weather( + ... city=input.city, temperature_range="14-20C", conditions="Sunny with wind." + ... ) + >>> + >>> # Create tool with custom activity options + >>> tool = nexus_operation_as_tool( + ... WeatherServiceHandler.get_weather_object, + ... service=WeatherServiceHandler, + ... endpoint="weather-service", + ... ) + >>> # Use tool with an OpenAI agent + """ + if not get_operation_factory(fn): + raise ApplicationError( + "Function is not a Nexus operation", + "invalid_tool", + ) + + schema = function_schema(adapt_nexus_operation_function_schema(fn)) + + async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: + try: + json_data = json.loads(input) + except Exception as e: + raise ApplicationError( + f"Invalid JSON input for tool {schema.name}: {input}" + ) from e + + nexus_client = workflow.NexusClient(service=service, endpoint=endpoint) + args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data)) + assert len(args) == 1, "Nexus operations must have exactly one argument" + [arg] = args + result = await nexus_client.execute_operation( + fn, + arg, + schedule_to_close_timeout=schedule_to_close_timeout, + ) + try: + return str(result) + except Exception as e: + raise ToolSerializationError( + "You must return a string representation of the tool output, or something we can call str() on" + ) from e + + return FunctionTool( + name=schema.name, + description=schema.description or "", + params_json_schema=schema.params_json_schema, + on_invoke_tool=run_operation, + strict_json_schema=True, + ) + + +def adapt_nexus_operation_function_schema(fn: Callable[..., Any]) -> Callable[..., Any]: + # Nexus operation start methods look like + # async def operation(self, ctx: StartOperationContext, input: InputType) -> OutputType + _, inputT, retT = typing.get_type_hints(fn).values() + + def adapted(input: inputT) -> retT: # type: ignore + pass + + adapted.__name__ = fn.__name__ + return adapted diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 8d989c154..a87853471 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Union, no_type_check import pytest +from nexusrpc.handler import StartOperationContext, service_handler, sync_operation from pydantic import ConfigDict, Field from temporalio import activity, workflow @@ -19,12 +20,16 @@ from temporalio.contrib.openai_agents.temporal_openai_agents import ( set_open_ai_agent_temporal_overrides, ) -from temporalio.contrib.openai_agents.temporal_tools import activity_as_tool +from temporalio.contrib.openai_agents.temporal_tools import ( + activity_as_tool, + nexus_operation_as_tool, +) from temporalio.contrib.openai_agents.trace_interceptor import ( OpenAIAgentsTracingInterceptor, ) from temporalio.exceptions import CancelledError from tests.helpers import new_worker +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name with workflow.unsafe.imports_passed_through(): from agents import ( @@ -223,6 +228,17 @@ async def get_weather_object(input: WeatherInput) -> Weather: ) +@service_handler +class WeatherServiceHandler: + @sync_operation + async def get_weather_object_nexus_operation( + self, ctx: StartOperationContext, input: WeatherInput + ) -> Weather: + return Weather( + city=input.city, temperature_range="14-20C", conditions="Sunny with wind." + ) + + class TestWeatherModel(TestModel): responses = [ ModelResponse( @@ -253,6 +269,20 @@ class TestWeatherModel(TestModel): usage=Usage(), response_id=None, ), + ModelResponse( + output=[ + ResponseFunctionToolCall( + arguments='{"input":{"city":"Tokyo"}}', + call_id="call", + name="get_weather_object_nexus_operation", + type="function_call", + id="id", + status="completed", + ) + ], + usage=Usage(), + response_id=None, + ), ModelResponse( output=[ ResponseFunctionToolCall( @@ -306,6 +336,12 @@ async def run(self, question: str) -> str: activity_as_tool( get_weather_country, start_to_close_timeout=timedelta(seconds=10) ), + nexus_operation_as_tool( + WeatherServiceHandler.get_weather_object_nexus_operation, + service=WeatherServiceHandler, + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + schedule_to_close_timeout=timedelta(seconds=10), + ), ], ) # type: Agent result = await Runner.run(starting_agent=agent, input=question) @@ -340,8 +376,11 @@ async def test_tool_workflow(client: Client, use_local_model: bool): get_weather_object, get_weather_country, ], + nexus_service_handlers=[WeatherServiceHandler()], interceptors=[OpenAIAgentsTracingInterceptor()], ) as worker: + await create_nexus_endpoint(worker.task_queue, client) + workflow_handle = await client.start_workflow( ToolsWorkflow.run, "What is the weather in Tokio?", @@ -353,13 +392,14 @@ async def test_tool_workflow(client: Client, use_local_model: bool): if use_local_model: assert result == "Test weather result" - events = [] async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): + if e.HasField( + "activity_task_completed_event_attributes" + ) or e.HasField("nexus_operation_completed_event_attributes"): events.append(e) - assert len(events) == 7 + assert len(events) == 9 assert ( "function_call" in events[0] @@ -392,13 +432,25 @@ async def test_tool_workflow(client: Client, use_local_model: bool): ) assert ( "Sunny with wind" - in events[5] + in events[ + 5 + ].nexus_operation_completed_event_attributes.result.data.decode() + ) + assert ( + "function_call" + in events[6] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[7] .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) assert ( "Test weather result" - in events[6] + in events[8] .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) From 4de6b4896a02ac7ab702c12835708195fc23999c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 2 Jul 2025 21:59:58 -0400 Subject: [PATCH 157/183] Fix mypy failures --- temporalio/nexus/_token.py | 2 +- temporalio/worker/_interceptor.py | 2 +- temporalio/worker/_nexus.py | 25 ++++++------- temporalio/worker/_workflow_instance.py | 4 +-- ...ynamic_creation_of_user_handler_classes.py | 5 ++- tests/nexus/test_handler.py | 35 +++++++++++-------- tests/nexus/test_workflow_caller.py | 11 +++--- 7 files changed, 46 insertions(+), 38 deletions(-) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index a6290111c..18bf0dba8 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -9,8 +9,8 @@ from temporalio import client -OPERATION_TOKEN_TYPE_WORKFLOW = 1 OperationTokenType = Literal[1] +OPERATION_TOKEN_TYPE_WORKFLOW: OperationTokenType = 1 @dataclass(frozen=True) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 667848f16..692721ad1 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -312,7 +312,7 @@ def __post_init__(self) -> None: elif isinstance(self.operation, str): self._operation_name = self.operation self._input_type = None - elif isinstance(self.operation, Callable): + elif callable(self.operation): _, op = temporalio.nexus._util.get_operation_factory(self.operation) if isinstance(op, nexusrpc.Operation): self._operation_name = op.name diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 67e3a3f81..65973e4a1 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -89,10 +89,10 @@ async def raise_from_exception_queue() -> NoReturn: if exception_task.done(): poll_task.cancel() await exception_task - task = await poll_task + nexus_task = await poll_task - if task.HasField("task"): - task = task.task + if nexus_task.HasField("task"): + task = nexus_task.task if task.request.HasField("start_operation"): self._running_tasks[task.task_token] = asyncio.create_task( self._handle_start_operation_task( @@ -115,18 +115,19 @@ async def raise_from_exception_queue() -> NoReturn: raise NotImplementedError( f"Invalid Nexus task request: {task.request}" ) - elif task.HasField("cancel_task"): - task = task.cancel_task - if _task := self._running_tasks.get(task.task_token): + elif nexus_task.HasField("cancel_task"): + if running_task := self._running_tasks.get( + nexus_task.cancel_task.task_token + ): # TODO(nexus-prerelease): when do we remove the entry from _running_operations? - _task.cancel() + running_task.cancel() else: logger.debug( f"Received cancel_task but no running task exists for " - f"task token: {task.task_token}" + f"task token: {nexus_task.cancel_task.task_token.decode()}" ) else: - raise NotImplementedError(f"Invalid Nexus task: {task}") + raise NotImplementedError(f"Invalid Nexus task: {nexus_task}") except temporalio.bridge.worker.PollShutdownError: exception_task.cancel() @@ -321,11 +322,11 @@ async def _exception_to_failure_proto( try: api_failure = temporalio.api.failure.v1.Failure() await self._data_converter.encode_failure(err, api_failure) - api_failure = google.protobuf.json_format.MessageToDict(api_failure) + _api_failure = google.protobuf.json_format.MessageToDict(api_failure) return temporalio.api.nexus.v1.Failure( - message=api_failure.pop("message", ""), + message=_api_failure.pop("message", ""), metadata={"type": "temporal.api.failure.v1.Failure"}, - details=json.dumps(api_failure).encode("utf-8"), + details=json.dumps(_api_failure).encode("utf-8"), ) except BaseException as err: return temporalio.api.nexus.v1.Failure( diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 78b53c589..a7fe73d67 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1842,7 +1842,7 @@ async def _outbound_start_nexus_operation( async def operation_handle_fn() -> OutputT: while True: try: - return await asyncio.shield(handle._result_fut) + return cast(OutputT, await asyncio.shield(handle._result_fut)) except asyncio.CancelledError: cancel_command = self._add_command() handle._apply_cancel_command(cancel_command) @@ -3038,7 +3038,7 @@ def _resolve_start_success(self, operation_token: Optional[str]) -> None: # We intentionally let this error if already done self._start_fut.set_result(operation_token) - def _resolve_success(self, result: OutputT) -> None: + def _resolve_success(self, result: Any) -> None: # We intentionally let this error if already done self._result_fut.set_result(result) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 1ab153b06..f2d1ec84e 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -21,7 +21,7 @@ def make_incrementer_user_service_definition_and_service_handler_classes( # ops = {name: nexusrpc.Operation[int, int] for name in op_names} - service_cls = nexusrpc.service(type("ServiceContract", (), ops)) + service_cls: type = nexusrpc.service(type("ServiceContract", (), ops)) # # service handler @@ -40,7 +40,7 @@ async def _increment_op( assert op_handler_factory op_handler_factories[name] = op_handler_factory - handler_cls = nexusrpc.handler.service_handler(service=service_cls)( + handler_cls: type = nexusrpc.handler.service_handler(service=service_cls)( type("ServiceImpl", (), op_handler_factories) ) @@ -72,7 +72,6 @@ async def test_dynamic_creation_of_user_handler_classes(client: Client): response = await http_client.post( f"http://127.0.0.1:{HTTP_PORT}/nexus/endpoints/{endpoint}/services/{service_name}/increment", json=1, - headers={}, ) assert response.status_code == 200 assert response.json() == 2 diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 025a442bb..83ddbbbca 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -40,6 +40,7 @@ FetchOperationResultContext, OperationHandler, StartOperationContext, + StartOperationResultSync, service_handler, sync_operation, ) @@ -50,7 +51,10 @@ from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy from temporalio.exceptions import ApplicationError -from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus import ( + WorkflowRunOperationContext, + workflow_run_operation, +) from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ( @@ -256,10 +260,10 @@ async def start( input: Input, # This return type is a type error, but VSCode doesn't flag it unless # "python.analysis.typeCheckingMode" is set to "strict" - ) -> Output: + ) -> StartOperationResultSync[Output]: # Invalid: start method must wrap result as StartOperationResultSync # or StartOperationResultAsync - return Output(value="unwrapped result error") # type: ignore + return StartOperationResultSync(Output(value="unwrapped result error")) # type: ignore async def fetch_info( self, ctx: FetchOperationInfoContext, token: str @@ -365,7 +369,7 @@ def check_response( class _FailureTestCase(_TestCase): - expected: UnsuccessfulResponse + expected: UnsuccessfulResponse # type: ignore[assignment] @classmethod def check_response( @@ -398,10 +402,9 @@ def check_response( exception_from_failure_details.type == "HandlerError" and exception_from_failure_details.__cause__ ): - exception_from_failure_details = ( - exception_from_failure_details.__cause__ - ) - assert isinstance(exception_from_failure_details, ApplicationError) + cause = exception_from_failure_details.__cause__ + assert isinstance(cause, ApplicationError) + exception_from_failure_details = cause assert exception_from_failure_details.non_retryable == ( not cls.expected.retryable_exception @@ -534,6 +537,8 @@ class BadRequest(_FailureTestCase): class _ApplicationErrorTestCase(_FailureTestCase): """Test cases in which the operation raises an ApplicationError.""" + expected: UnsuccessfulResponse # type: ignore[assignment] + @classmethod def check_response( cls, response: httpx.Response, with_service_definition: bool @@ -919,18 +924,20 @@ async def start( input: Input, # This return type is a type error, but VSCode doesn't flag it unless # "python.analysis.typeCheckingMode" is set to "strict" - ) -> Output: + ) -> StartOperationResultSync[Output]: # Invalid: start method must wrap result as StartOperationResultSync # or StartOperationResultAsync - return Output(value="Hello") # type: ignore + return StartOperationResultSync(Output(value="Hello")) # type: ignore - def cancel(self, ctx: CancelOperationContext, token: str) -> Output: - return Output(value="Hello") # type: ignore + def cancel(self, ctx: CancelOperationContext, token: str) -> None: + return None # type: ignore - def fetch_info(self, ctx: FetchOperationInfoContext) -> OperationInfo: + def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: raise NotImplementedError - def fetch_result(self, ctx: FetchOperationResultContext) -> Output: + def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: raise NotImplementedError @operation_handler diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 8e9998e5e..857ef8acb 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -149,7 +149,7 @@ async def run( class SyncOrAsyncOperation(OperationHandler[OpInput, OpOutput]): - async def start( + async def start( # type: ignore[override] self, ctx: StartOperationContext, input: OpInput ) -> Union[ StartOperationResultSync[OpOutput], @@ -312,7 +312,7 @@ def _get_operation( nexusrpc.Operation[OpInput, OpOutput], Callable[[Any], OperationHandler[OpInput, OpOutput]], ]: - return { + return { # type: ignore[return-value] ( SyncResponse, OpDefinitionType.SHORTHAND, @@ -383,7 +383,7 @@ def __init__( ) -> None: # TODO(nexus-preview): untyped caller cannot reference name of implementation. I think this is as it should be. service_name = "ServiceInterface" - self.nexus_client = workflow.NexusClient( + self.nexus_client: workflow.NexusClient[Any] = workflow.NexusClient( service=service_name, endpoint=make_nexus_endpoint_name(task_queue), ) @@ -800,6 +800,7 @@ async def run( task_queue: str, ) -> ServiceClassNameOutput: C, N = CallerReference, NameOverride + service_cls: type if (caller_reference, name_override) == (C.INTERFACE, N.YES): service_cls = ServiceInterfaceWithNameOverride elif (caller_reference, name_override) == (C.INTERFACE, N.NO): @@ -1190,7 +1191,7 @@ def parse_exception( } -error_conversion_test_cases = [] +error_conversion_test_cases: list[ErrorConversionTestCase] = [] # application_error_non_retryable: @@ -1465,7 +1466,7 @@ async def run(self, input: ErrorTestInput) -> None: {input.action_in_sync_op} {'-' * 80} """) - for java_behavior, actual in results: + for java_behavior, actual in results: # type: ignore[assignment] print(f"Java: {java_behavior}") print(f"Python: {actual}") print() From 38a50faab6c2959f3dc951cb455ddf669c3847c3 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 06:41:44 -0400 Subject: [PATCH 158/183] Revert "Convert nexus_handler_failure_info as nexusrpc.HandlerError" This reverts commit a9bac66287fbccde67e50852647265a4c60abdfc. --- temporalio/converter.py | 15 ++----------- temporalio/exceptions.py | 35 ++++++++++++++++++++++++++++- tests/nexus/test_workflow_caller.py | 7 +++--- 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 43dbe305b..b976eca08 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -16,7 +16,6 @@ from datetime import datetime from enum import IntEnum from itertools import zip_longest -from logging import getLogger from typing import ( Any, Awaitable, @@ -41,7 +40,6 @@ import google.protobuf.json_format import google.protobuf.message import google.protobuf.symbol_database -import nexusrpc import typing_extensions import temporalio.api.common.v1 @@ -62,8 +60,6 @@ if sys.version_info >= (3, 10): from types import UnionType -logger = getLogger(__name__) - class PayloadConverter(ABC): """Base payload converter to/from multiple payloads/values.""" @@ -1018,16 +1014,9 @@ def from_failure( ) elif failure.HasField("nexus_handler_failure_info"): nexus_handler_failure_info = failure.nexus_handler_failure_info - try: - _type = nexusrpc.HandlerErrorType[nexus_handler_failure_info.type] - except KeyError: - logger.warning( - f"Unknown Nexus HandlerErrorType: {nexus_handler_failure_info.type}" - ) - _type = nexusrpc.HandlerErrorType.INTERNAL - return nexusrpc.HandlerError( + err = temporalio.exceptions.NexusHandlerError( failure.message or "Nexus handler error", - type=_type, + type=nexus_handler_failure_info.type, retryable={ temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: True, temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: False, diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index 0a1cd9a1d..c088614e9 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -375,7 +375,15 @@ def __init__( operation: str, operation_token: str, ): - """Initialize a Nexus operation error.""" + """ + Args: + message: The error message. + scheduled_event_id: The NexusOperationScheduled event ID for the failed operation. + endpoint: The endpoint name for the failed operation. + service: The service name for the failed operation. + operation: The name of the failed operation. + operation_token: The operation token returned by the failed operation. + """ super().__init__(message) self._scheduled_event_id = scheduled_event_id self._endpoint = endpoint @@ -409,6 +417,31 @@ def operation_token(self) -> str: return self._operation_token +class NexusHandlerError(FailureError): + """ + Error raised on Nexus handler failure. + + This is a Temporal serialized form of nexusrpc.HandlerError. + """ + + def __init__( + self, + message: str, + *, + type: str, + retryable: Optional[bool] = None, + ): + """ + Args: + message: The error message. + type: String representation of the nexusrpc.HandlerErrorType. + retryable: Whether the error was marked as retryable by the code that raised it. + """ + super().__init__(message) + self.type = type + self.retryable = retryable + + def is_cancelled_exception(exception: BaseException) -> bool: """Check whether the given exception is considered a cancellation exception according to Temporal. diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 857ef8acb..efcf4e3ad 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -44,6 +44,7 @@ from temporalio.exceptions import ( ApplicationError, CancelledError, + NexusHandlerError, NexusOperationError, TimeoutError, ) @@ -486,7 +487,7 @@ async def test_sync_response( e = ei.value assert isinstance(e, WorkflowFailureError) assert isinstance(e.__cause__, NexusOperationError) - assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) + assert isinstance(e.__cause__.__cause__, NexusHandlerError) # ID of first command assert e.__cause__.scheduled_event_id == 5 assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) @@ -539,7 +540,7 @@ async def test_async_response( e = ei.value assert isinstance(e, WorkflowFailureError) assert isinstance(e.__cause__, NexusOperationError) - assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) + assert isinstance(e.__cause__.__cause__, NexusHandlerError) # ID of first command after update accepted assert e.__cause__.scheduled_event_id == 6 assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) @@ -716,7 +717,7 @@ async def test_untyped_caller( e = ei.value assert isinstance(e, WorkflowFailureError) assert isinstance(e.__cause__, NexusOperationError) - assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) + assert isinstance(e.__cause__.__cause__, NexusHandlerError) else: result = await caller_wf_handle.result() assert result.op_output.value == ( From 0fb7837575fca848ba747d0a76e89b0b9a74aaa5 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 07:33:05 -0400 Subject: [PATCH 159/183] Cleanup error test --- tests/nexus/test_workflow_caller.py | 47 ++++++----------------------- 1 file changed, 9 insertions(+), 38 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index efcf4e3ad..1d8faf612 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1196,14 +1196,6 @@ def parse_exception( # application_error_non_retryable: -_ = ["NexusOperationError", "HandlerError"] -# Java -_ = [ - "NexusOperationError", - "HandlerError('handler error: message='application error 1', type='my-application-error-type', nonRetryable=true', type='INTERNAL', nonRetryable=true)", - "ApplicationError('application error 1', type='my-application-error-type', nonRetryable=true)", -] - error_conversion_test_cases.append( ErrorConversionTestCase( name="application_error_non_retryable", @@ -1229,11 +1221,16 @@ def parse_exception( ) ) +# custom_error: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="custom_error", + java_behavior=[], # [Not possible] + ) +) + # custom_error_from_custom_error: -_ = ["NexusOperationError", "HandlerError"] -# Java -# [Not possible] error_conversion_test_cases.append( ErrorConversionTestCase( name="custom_error_from_custom_error", @@ -1243,15 +1240,6 @@ def parse_exception( # application_error_non_retryable_from_custom_error: -_ = ["NexusOperationError", "HandlerError"] -# Java -_ = [ - "NexusOperationError", - "HandlerError('handler error: message='application error 1', type='my-application-error-type', nonRetryable=true', type='INTERNAL', nonRetryable=true)", - "ApplicationError('application error 1', type='my-application-error-type', nonRetryable=true)", - "ApplicationError('Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false)", -] - error_conversion_test_cases.append( ErrorConversionTestCase( name="application_error_non_retryable_from_custom_error", @@ -1286,14 +1274,6 @@ def parse_exception( ) # nexus_handler_error_not_found: -_ = ["NexusOperationError", "HandlerError"] -# Java -_ = [ - "NexusOperationError", - "HandlerError('handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false', type='NOT_FOUND', nonRetryable=true)", - "ApplicationError('Handler error 1', type='java.lang.RuntimeException', nonRetryable=false)", -] - error_conversion_test_cases.append( ErrorConversionTestCase( name="nexus_handler_error_not_found", @@ -1320,9 +1300,6 @@ def parse_exception( ) # nexus_handler_error_not_found_from_custom_error: -_ = ["NexusOperationError", "HandlerError"] -# Java -# [Not possible] error_conversion_test_cases.append( ErrorConversionTestCase( name="nexus_handler_error_not_found_from_custom_error", @@ -1332,13 +1309,6 @@ def parse_exception( # nexus_operation_error_from_application_error_non_retryable_from_custom_error: -_ = ["NexusOperationError", "ApplicationError", "ApplicationError"] -# Java -_ = [ - "NexusOperationError", - "ApplicationError('application error 1', type='my-application-error-type', nonRetryable=true)", - "ApplicationError('Custom error 2', type='io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException', nonRetryable=false)", -] error_conversion_test_cases.append( ErrorConversionTestCase( name="nexus_operation_error_from_application_error_non_retryable_from_custom_error", @@ -1481,6 +1451,7 @@ async def run(self, input: ErrorTestInput) -> None: "action_in_sync_op", [ "application_error_non_retryable", + "custom_error", "custom_error_from_custom_error", "application_error_non_retryable_from_custom_error", "nexus_handler_error_not_found", From 7656a0201743e37eeed773638e3158dbe02c77a5 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 08:06:04 -0400 Subject: [PATCH 160/183] Revert "Delete redundant test" This reverts commit 86a9a61378f75b85416a10b582632ad726117eeb. --- tests/nexus/test_handler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 83ddbbbca..2d2bd5996 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -835,6 +835,14 @@ async def test_start_operation_without_type_annotations( assert not any(warnings), [w.message for w in warnings] +def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): + with pytest.raises( + ValueError, + match=r"has no input type.+has no output type", + ): + service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) + + async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: Any): task_queue = str(uuid.uuid4()) service_name = MyService.__name__ From 075ec7c815bbaba6034338c0a9582afefc4338f1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 08:32:04 -0400 Subject: [PATCH 161/183] Evolve context API --- temporalio/nexus/_operation_context.py | 16 ++++++---------- tests/nexus/test_handler.py | 9 +++------ 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 6179bc910..866335417 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -120,7 +120,7 @@ def get(cls) -> _TemporalStartOperationContext: def set(self) -> None: _temporal_start_operation_context.set(self) - def get_completion_callbacks( + def _get_completion_callbacks( self, ) -> list[temporalio.client.NexusCompletionCallback]: ctx = self.nexus_context @@ -140,7 +140,7 @@ def get_completion_callbacks( else [] ) - def get_workflow_event_links( + def _get_workflow_event_links( self, ) -> list[temporalio.api.common.v1.Link.WorkflowEvent]: event_links = [] @@ -149,7 +149,7 @@ def get_workflow_event_links( event_links.append(link) return event_links - def add_outbound_links( + def _add_outbound_links( self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] ): try: @@ -185,10 +185,6 @@ def temporal_context(self) -> _TemporalStartOperationContext: raise RuntimeError("Temporal context not set") return self._temporal_context - @property - def nexus_context(self) -> StartOperationContext: - return self.temporal_context.nexus_context - @classmethod def from_start_operation_context( cls, ctx: StartOperationContext @@ -438,12 +434,12 @@ async def start_workflow( request_eager_start=request_eager_start, priority=priority, versioning_override=versioning_override, - nexus_completion_callbacks=self.temporal_context.get_completion_callbacks(), - workflow_event_links=self.temporal_context.get_workflow_event_links(), + nexus_completion_callbacks=self.temporal_context._get_completion_callbacks(), + workflow_event_links=self.temporal_context._get_workflow_event_links(), request_id=self.temporal_context.nexus_context.request_id, ) - self.temporal_context.add_outbound_links(wf_handle) + self.temporal_context._add_outbound_links(wf_handle) return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 2d2bd5996..a9b116c73 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -239,13 +239,10 @@ async def workflow_run_op_link_test( self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: assert any( - link.url == "http://inbound-link/" - for link in ctx.nexus_context.inbound_links + link.url == "http://inbound-link/" for link in ctx.inbound_links ), "Inbound link not found" - assert ( - ctx.nexus_context.request_id == "test-request-id-123" - ), "Request ID mismatch" - ctx.nexus_context.outbound_links.extend(ctx.nexus_context.inbound_links) + assert ctx.request_id == "test-request-id-123", "Request ID mismatch" + ctx.outbound_links.extend(ctx.inbound_links) return await ctx.start_workflow( MyLinkTestWorkflow.run, From 6c1e8946cf2b4d380d0d25b71cb4454e458d3f78 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 08:39:38 -0400 Subject: [PATCH 162/183] Rename as temporalio.nexus.cancel_workflow --- temporalio/nexus/__init__.py | 1 - temporalio/nexus/_operation_handlers.py | 17 +++++++++++------ tests/nexus/test_workflow_caller.py | 4 ++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index dd9935b05..f25c4ac8c 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -12,5 +12,4 @@ from ._operation_context import client as client from ._operation_context import info as info from ._operation_context import logger as logger -from ._operation_handlers import cancel_operation as cancel_operation from ._token import WorkflowHandle as WorkflowHandle diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 8d1253979..ecc286719 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -88,7 +88,7 @@ async def start( async def cancel(self, ctx: CancelOperationContext, token: str) -> None: """Cancel the operation, by cancelling the workflow.""" - await cancel_operation(token) + await _cancel_workflow(token) async def fetch_info( self, ctx: FetchOperationInfoContext, token: str @@ -126,21 +126,26 @@ async def fetch_result( # return await client_handle.result() -async def cancel_operation( +async def _cancel_workflow( token: str, **kwargs: Any, ) -> None: - """Cancel a Nexus operation. + """ + Cancel a workflow that is backing a Nexus operation. + + This function is used by the Nexus worker to cancel a workflow that is backing a + Nexus operation, i.e. started by a + :py:func:`temporalio.nexus.workflow_run_operation`-decorated method. Args: - token: The token of the operation to cancel. - client: The client to use to cancel the operation. + token: The token of the workflow to cancel. kwargs: Additional keyword arguments + to pass to the workflow cancel method. """ try: nexus_workflow_handle = WorkflowHandle[Any].from_token(token) except Exception as err: raise HandlerError( - "Failed to decode operation token as workflow operation token. " + "Failed to decode operation token as a workflow operation token. " "Canceling non-workflow operations is not supported.", type=HandlerErrorType.NOT_FOUND, ) from err diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 1d8faf612..496aeeb47 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -31,7 +31,7 @@ import temporalio.api.operatorservice import temporalio.api.operatorservice.v1 import temporalio.exceptions -import temporalio.nexus +import temporalio.nexus._operation_handlers from temporalio import nexus, workflow from temporalio.client import ( Client, @@ -180,7 +180,7 @@ async def start( # type: ignore[override] raise TypeError async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - return await nexus.cancel_operation(token) + return await temporalio.nexus._operation_handlers._cancel_workflow(token) async def fetch_info( self, ctx: FetchOperationInfoContext, token: str From d5fce2a0e7343efb209a6d70bfe742ebc046fcfd Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 08:44:32 -0400 Subject: [PATCH 163/183] Fix test --- tests/nexus/test_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index a9b116c73..3413f93ad 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -257,10 +257,10 @@ async def start( input: Input, # This return type is a type error, but VSCode doesn't flag it unless # "python.analysis.typeCheckingMode" is set to "strict" - ) -> StartOperationResultSync[Output]: + ) -> Output: # Invalid: start method must wrap result as StartOperationResultSync # or StartOperationResultAsync - return StartOperationResultSync(Output(value="unwrapped result error")) # type: ignore + return Output(value="unwrapped result error") async def fetch_info( self, ctx: FetchOperationInfoContext, token: str From 8f3681b0967b6096883f4e694b5d5d699ee8d5e8 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 09:00:18 -0400 Subject: [PATCH 164/183] Cleanup --- temporalio/nexus/_token.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index 18bf0dba8..e69ff07e4 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -120,13 +120,15 @@ def _base64url_encode_no_padding(data: bytes) -> str: return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") +_base64_url_alphabet = set( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-" +) + + def _base64url_decode_no_padding(s: str) -> bytes: - if not all( - c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-" - for c in s - ): + if invalid_chars := set(s) - _base64_url_alphabet: raise ValueError( - "invalid base64URL encoded string: contains invalid characters" + f"invalid base64URL encoded string: contains invalid characters: {invalid_chars}" ) padding = "=" * (-len(s) % 4) return base64.urlsafe_b64decode(s + padding) From 8117ea15f2e86dd4bb5fea9d54f2ecec7dce8f6c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 09:07:00 -0400 Subject: [PATCH 165/183] Impprovements from code review comments --- temporalio/nexus/__init__.py | 1 + temporalio/nexus/_operation_context.py | 7 +++++++ temporalio/worker/_nexus.py | 6 +++--- tests/nexus/test_handler.py | 2 ++ 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index f25c4ac8c..75fa5fba2 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -10,6 +10,7 @@ _TemporalStartOperationContext as _TemporalStartOperationContext, ) from ._operation_context import client as client +from ._operation_context import in_operation as in_operation from ._operation_context import info as info from ._operation_context import logger as logger from ._token import WorkflowHandle as WorkflowHandle diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 866335417..331684f79 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -62,6 +62,13 @@ class Info: """The task queue of the worker handling this Nexus operation.""" +def in_operation() -> bool: + """ + Whether the current code is inside a Nexus operation. + """ + return _try_temporal_context() is not None + + def info() -> Info: """ Get the current Nexus operation information. diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 65973e4a1..168e1ef6b 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -183,8 +183,8 @@ async def _handle_cancel_operation_task( ).set() try: await self._handler.cancel_operation(ctx, request.operation_token) - except Exception as err: - logger.exception("Failed to execute Nexus cancel operation method") + except BaseException as err: + logger.warning("Failed to execute Nexus cancel operation method") completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( task_token=task_token, error=await self._handler_error_to_proto( @@ -220,7 +220,7 @@ async def _handle_start_operation_task( try: start_response = await self._start_operation(start_request, headers) except BaseException as err: - logger.exception("Failed to execute Nexus start operation method") + logger.warning("Failed to execute Nexus start operation method") handler_err = _exception_to_handler_error(err) completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( task_token=task_token, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 3413f93ad..75227d745 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -139,6 +139,7 @@ class MyServiceHandler: async def echo(self, ctx: StartOperationContext, input: Input) -> Output: assert ctx.headers["test-header-key"] == "test-header-value" ctx.outbound_links.extend(ctx.inbound_links) + assert nexus.in_operation() return Output( value=f"from start method on {self.__class__.__name__}: {input.value}" ) @@ -219,6 +220,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output: async def workflow_run_operation_happy_path( self, ctx: WorkflowRunOperationContext, input: Input ) -> nexus.WorkflowHandle[Output]: + assert nexus.in_operation() return await ctx.start_workflow( MyWorkflow.run, input, From 7b629550c1790e27e9ec172e2fc57ff996e2de8a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 10:27:35 -0400 Subject: [PATCH 166/183] Expose nexus.LoggerAdapter --- temporalio/nexus/__init__.py | 1 + temporalio/nexus/_operation_context.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 75fa5fba2..217e37565 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -1,5 +1,6 @@ from ._decorators import workflow_run_operation as workflow_run_operation from ._operation_context import Info as Info +from ._operation_context import LoggerAdapter as LoggerAdapter from ._operation_context import ( WorkflowRunOperationContext as WorkflowRunOperationContext, ) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 331684f79..47425ebbb 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -563,7 +563,7 @@ def _nexus_link_to_workflow_event( ) -class _LoggerAdapter(logging.LoggerAdapter): +class LoggerAdapter(logging.LoggerAdapter): def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): super().__init__(logger, extra or {}) @@ -579,5 +579,5 @@ def process( return msg, kwargs -logger = _LoggerAdapter(logging.getLogger("temporalio.nexus"), None) +logger = LoggerAdapter(logging.getLogger("temporalio.nexus"), None) """Logger that emits additional data describing the current Nexus operation.""" From 4e34462ae4dfd3eb191c45cc17ff05c404c1eaa1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 10:34:37 -0400 Subject: [PATCH 167/183] Add outbound links for sync responses also --- temporalio/worker/_nexus.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 168e1ef6b..1fb597af3 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -285,21 +285,23 @@ async def _start_operation( ) try: result = await self._handler.start_operation(ctx, input) + links = [ + temporalio.api.nexus.v1.Link(url=link.url, type=link.type) + for link in ctx.outbound_links + ] if isinstance(result, nexusrpc.handler.StartOperationResultAsync): return temporalio.api.nexus.v1.StartOperationResponse( async_success=temporalio.api.nexus.v1.StartOperationResponse.Async( operation_token=result.token, - links=[ - temporalio.api.nexus.v1.Link(url=link.url, type=link.type) - for link in ctx.outbound_links - ], + links=links, ) ) elif isinstance(result, nexusrpc.handler.StartOperationResultSync): [payload] = await self._data_converter.encode([result.value]) return temporalio.api.nexus.v1.StartOperationResponse( sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync( - payload=payload + payload=payload, + links=links, ) ) else: From 8889dde3417019ace217ccaa488ac71199fbdff5 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 10:48:42 -0400 Subject: [PATCH 168/183] Don't expose separate workflow.start_nexus_operation --- temporalio/workflow.py | 71 ++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 44 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index c51cfea05..2eecfb3ed 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4399,44 +4399,6 @@ def operation_token(self) -> Optional[str]: raise NotImplementedError -async def start_nexus_operation( - endpoint: str, - service: str, - operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], - input: Any, - *, - output_type: Optional[Type[OutputT]] = None, - schedule_to_close_timeout: Optional[timedelta] = None, - headers: Optional[Mapping[str, str]] = None, -) -> NexusOperationHandle[OutputT]: - """Start a Nexus operation and return its handle. - - Args: - endpoint: The Nexus endpoint. - service: The Nexus service. - operation: The Nexus operation. - input: The Nexus operation input. - output_type: The Nexus operation output type. - schedule_to_close_timeout: Timeout for the entire operation attempt. - headers: Headers to send with the Nexus HTTP request. - - Returns: - A handle to the Nexus operation. The result can be obtained as - ```python - await handle.result() - ``` - """ - return await _Runtime.current().workflow_start_nexus_operation( - endpoint=endpoint, - service=service, - operation=operation, - input=input, - output_type=output_type, - schedule_to_close_timeout=schedule_to_close_timeout, - headers=headers, - ) - - class ExternalWorkflowHandle(Generic[SelfType]): """Handle for interacting with an external workflow. @@ -5157,19 +5119,25 @@ def __init__( *, endpoint: str, ) -> None: + """Create a Nexus client. + + Args: + service: The Nexus service. + endpoint: The Nexus endpoint. + """ # If service is not a str, then it must be a service interface or implementation # class. if isinstance(service, str): - self._service_name = service + self.service_name = service elif service_defn := nexusrpc.get_service_definition(service): - self._service_name = service_defn.name + self.service_name = service_defn.name else: raise ValueError( f"`service` may be a name (str), or a class decorated with either " f"@nexusrpc.handler.service_handler or @nexusrpc.service. " f"Invalid service type: {type(service)}" ) - self._endpoint = endpoint + self.endpoint = endpoint # TODO(nexus-prerelease): overloads: no-input, ret type # TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied? @@ -5182,9 +5150,24 @@ async def start_operation( schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: - return await temporalio.workflow.start_nexus_operation( - endpoint=self._endpoint, - service=self._service_name, + """Start a Nexus operation and return its handle. + + Args: + operation: The Nexus operation. + input: The Nexus operation input. + output_type: The Nexus operation output type. + schedule_to_close_timeout: Timeout for the entire operation attempt. + headers: Headers to send with the Nexus HTTP request. + + Returns: + A handle to the Nexus operation. The result can be obtained as + ```python + await handle.result() + ``` + """ + return await _Runtime.current().workflow_start_nexus_operation( + endpoint=self.endpoint, + service=self.service_name, operation=operation, input=input, output_type=output_type, From 41b9709d1b064421c7e848b492ac9723f5173a92 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 10:50:45 -0400 Subject: [PATCH 169/183] Remove unnecessary type hint --- temporalio/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 2eecfb3ed..04c2a63d3 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5185,7 +5185,7 @@ async def execute_operation( schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: - handle: NexusOperationHandle[OutputT] = await self.start_operation( + handle = await self.start_operation( operation, input, output_type=output_type, From c2d482530f29f43f0bfa8941cc338a6c9863970d Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 11:21:44 -0400 Subject: [PATCH 170/183] New Nexus client constructor --- temporalio/workflow.py | 112 +++++++++++++++++++++------- tests/nexus/test_workflow_caller.py | 28 +++---- 2 files changed, 100 insertions(+), 40 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 04c2a63d3..34cc4a55c 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5112,12 +5112,73 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType ServiceT = TypeVar("ServiceT") -class NexusClient(Generic[ServiceT]): +class NexusClient(ABC, Generic[ServiceT]): + """ + A client for invoking Nexus operations. + + example: + ```python + nexus_client = workflow.create_nexus_client( + endpoint=my_nexus_endpoint, + service=MyService, + ) + handle = await nexus_client.start_operation( + operation=MyService.my_operation, + input=MyOperationInput(value="hello"), + schedule_to_close_timeout=timedelta(seconds=10), + ) + result = await handle.result() + ``` + """ + + # TODO(nexus-prerelease): overloads: no-input, ret type + # TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied? + @abstractmethod + async def start_operation( + self, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> NexusOperationHandle[OutputT]: + """Start a Nexus operation and return its handle. + + Args: + operation: The Nexus operation. + input: The Nexus operation input. + output_type: The Nexus operation output type. + schedule_to_close_timeout: Timeout for the entire operation attempt. + headers: Headers to send with the Nexus HTTP request. + + Returns: + A handle to the Nexus operation. The result can be obtained as + ```python + await handle.result() + ``` + """ + ... + + # TODO(nexus-prerelease): overloads: no-input, ret type + @abstractmethod + async def execute_operation( + self, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> OutputT: ... + + +class _NexusClient(NexusClient[ServiceT]): def __init__( self, - service: Union[Type[ServiceT], str], *, endpoint: str, + service: Union[Type[ServiceT], str], ) -> None: """Create a Nexus client. @@ -5149,30 +5210,17 @@ async def start_operation( output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, - ) -> NexusOperationHandle[OutputT]: - """Start a Nexus operation and return its handle. - - Args: - operation: The Nexus operation. - input: The Nexus operation input. - output_type: The Nexus operation output type. - schedule_to_close_timeout: Timeout for the entire operation attempt. - headers: Headers to send with the Nexus HTTP request. - - Returns: - A handle to the Nexus operation. The result can be obtained as - ```python - await handle.result() - ``` - """ - return await _Runtime.current().workflow_start_nexus_operation( - endpoint=self.endpoint, - service=self.service_name, - operation=operation, - input=input, - output_type=output_type, - schedule_to_close_timeout=schedule_to_close_timeout, - headers=headers, + ) -> temporalio.workflow.NexusOperationHandle[OutputT]: + return ( + await temporalio.workflow._Runtime.current().workflow_start_nexus_operation( + endpoint=self.endpoint, + service=self.service_name, + operation=operation, + input=input, + output_type=output_type, + schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) ) # TODO(nexus-prerelease): overloads: no-input, ret type @@ -5193,3 +5241,15 @@ async def execute_operation( headers=headers, ) return await handle + + +def create_nexus_client( + endpoint: str, service: Union[Type[ServiceT], str] +) -> NexusClient[ServiceT]: + """Create a Nexus client. + + Args: + endpoint: The Nexus endpoint. + service: The Nexus service. + """ + return _NexusClient(endpoint=endpoint, service=service) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 496aeeb47..f0e227c8e 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -261,12 +261,12 @@ def __init__( request_cancel: bool, task_queue: str, ) -> None: - self.nexus_client = workflow.NexusClient( + self.nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(task_queue), service={ CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, CallerReference.INTERFACE: ServiceInterface, }[input.op_input.caller_reference], - endpoint=make_nexus_endpoint_name(task_queue), ) self._nexus_operation_started = False self._proceed = False @@ -384,9 +384,9 @@ def __init__( ) -> None: # TODO(nexus-preview): untyped caller cannot reference name of implementation. I think this is as it should be. service_name = "ServiceInterface" - self.nexus_client: workflow.NexusClient[Any] = workflow.NexusClient( - service=service_name, + self.nexus_client: workflow.NexusClient[Any] = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(task_queue), + service=service_name, ) @workflow.run @@ -818,9 +818,9 @@ async def run( f"Invalid combination of caller_reference ({caller_reference}) and name_override ({name_override})" ) - nexus_client = workflow.NexusClient( - service=service_cls, + nexus_client = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(task_queue), + service=service_cls, ) return await nexus_client.execute_operation(service_cls.op, None) # type: ignore @@ -942,9 +942,9 @@ async def my_workflow_run_operation( class WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow: @workflow.run async def run(self, input: str, task_queue: str) -> str: - nexus_client = workflow.NexusClient( - service=ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow, + nexus_client = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(task_queue), + service=ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow, ) return await nexus_client.execute_operation( ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow.my_workflow_run_operation, @@ -1402,9 +1402,9 @@ async def op(self, ctx: StartOperationContext, input: ErrorTestInput) -> None: class ErrorTestCallerWorkflow: @workflow.init def __init__(self, input: ErrorTestInput): - self.nexus_client = workflow.NexusClient( - service=ErrorTestService, + self.nexus_client = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(input.task_queue), + service=ErrorTestService, ) self.test_cases = {t.name: t for t in error_conversion_test_cases} @@ -1495,9 +1495,9 @@ async def op_handler_that_never_returns( class TimeoutTestCallerWorkflow: @workflow.init def __init__(self): - self.nexus_client = workflow.NexusClient( - service=TimeoutTestService, + self.nexus_client = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + service=TimeoutTestService, ) @workflow.run @@ -1621,9 +1621,9 @@ class OverloadTestInput: class OverloadTestCallerWorkflow: @workflow.run async def run(self, op: str, input: OverloadTestValue) -> OverloadTestValue: - nexus_client = workflow.NexusClient( - service=OverloadTestServiceHandler, + nexus_client = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + service=OverloadTestServiceHandler, ) if op == "no_param": return await nexus_client.execute_operation( From a4fd205263776dc5542e5b7ae811b992c6b5d306 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 11:35:21 -0400 Subject: [PATCH 171/183] Remove unused test helper methods --- tests/helpers/nexus.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index 46460d77c..4452944da 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -63,30 +63,6 @@ async def start_operation( headers=headers, ) - async def fetch_operation_info( - self, - operation: str, - token: str, - ) -> httpx.Response: - async with httpx.AsyncClient() as http_client: - return await http_client.get( - f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}", - # Token can also be sent as "Nexus-Operation-Token" header - params={"token": token}, - ) - - async def fetch_operation_result( - self, - operation: str, - token: str, - ) -> httpx.Response: - async with httpx.AsyncClient() as http_client: - return await http_client.get( - f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/result", - # Token can also be sent as "Nexus-Operation-Token" header - params={"token": token}, - ) - async def cancel_operation( self, operation: str, From a5f67d2b6cde24ba381a2a407b03fec45f6d6906 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 11:37:42 -0400 Subject: [PATCH 172/183] Clean up token type --- temporalio/nexus/_token.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index e69ff07e4..480a404b1 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -19,7 +19,6 @@ class WorkflowHandle(Generic[OutputT]): namespace: str workflow_id: str - _type: OperationTokenType = OPERATION_TOKEN_TYPE_WORKFLOW # Version of the token. Treated as v1 if missing. This field is not included in the # serialized token; it's only used to reject newer token versions on load. version: Optional[int] = None @@ -56,7 +55,7 @@ def to_token(self) -> str: return _base64url_encode_no_padding( json.dumps( { - "t": self._type, + "t": OPERATION_TOKEN_TYPE_WORKFLOW, "ns": self.namespace, "wid": self.workflow_id, }, @@ -83,10 +82,10 @@ def from_token(cls, token: str) -> WorkflowHandle[OutputT]: f"invalid workflow token: expected dict, got {type(workflow_operation_token)}" ) - _type = workflow_operation_token.get("t") - if _type != OPERATION_TOKEN_TYPE_WORKFLOW: + token_type = workflow_operation_token.get("t") + if token_type != OPERATION_TOKEN_TYPE_WORKFLOW: raise TypeError( - f"invalid workflow token type: {_type}, expected: {OPERATION_TOKEN_TYPE_WORKFLOW}" + f"invalid workflow token type: {token_type}, expected: {OPERATION_TOKEN_TYPE_WORKFLOW}" ) version = workflow_operation_token.get("v") @@ -109,7 +108,6 @@ def from_token(cls, token: str) -> WorkflowHandle[OutputT]: ) return cls( - _type=_type, namespace=namespace, workflow_id=workflow_id, version=version, From f7f8a51e1c7236d154c61eb00673ec0bc98c6ceb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 11:52:03 -0400 Subject: [PATCH 173/183] Refactor start timeout test --- tests/nexus/test_workflow_caller.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index f0e227c8e..83f7c7d4b 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1481,9 +1481,9 @@ async def test_errors_raised_by_nexus_operation( ) -# Timeout test +# Start timeout test @service_handler -class TimeoutTestService: +class StartTimeoutTestService: @sync_operation async def op_handler_that_never_returns( self, ctx: StartOperationContext, input: None @@ -1492,35 +1492,45 @@ async def op_handler_that_never_returns( @workflow.defn -class TimeoutTestCallerWorkflow: +class StartTimeoutTestCallerWorkflow: @workflow.init def __init__(self): self.nexus_client = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(workflow.info().task_queue), - service=TimeoutTestService, + service=StartTimeoutTestService, ) @workflow.run async def run(self) -> None: await self.nexus_client.execute_operation( - TimeoutTestService.op_handler_that_never_returns, + StartTimeoutTestService.op_handler_that_never_returns, None, schedule_to_close_timeout=timedelta(seconds=0.1), ) -async def test_timeout_error_raised_by_nexus_operation(client: Client): +async def test_error_raised_by_timeout_of_nexus_start_operation(client: Client): task_queue = str(uuid.uuid4()) async with Worker( client, - nexus_service_handlers=[TimeoutTestService()], - workflows=[TimeoutTestCallerWorkflow], + nexus_service_handlers=[StartTimeoutTestService()], + workflows=[StartTimeoutTestCallerWorkflow], task_queue=task_queue, ): await create_nexus_endpoint(task_queue, client) try: await client.execute_workflow( - TimeoutTestCallerWorkflow.run, + StartTimeoutTestCallerWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail("Expected exception due to timeout of nexus start operation") + id=str(uuid.uuid4()), task_queue=task_queue, ) From 01960ed11fa934efc981091382eb1636492b595e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 11:52:20 -0400 Subject: [PATCH 174/183] Cancellation timeout test --- tests/nexus/test_workflow_caller.py | 66 +++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 83f7c7d4b..5b87a73e2 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1531,6 +1531,70 @@ async def test_error_raised_by_timeout_of_nexus_start_operation(client: Client): else: pytest.fail("Expected exception due to timeout of nexus start operation") + +# Cancellation timeout test + + +class OperationWithCancelMethodThatNeverReturns(OperationHandler[None, None]): + async def start( + self, ctx: StartOperationContext, input: None + ) -> StartOperationResultAsync: + return StartOperationResultAsync("fake-token") + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + await asyncio.Future() + + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> nexusrpc.OperationInfo: + raise NotImplementedError("Not implemented") + + async def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> None: + raise NotImplementedError("Not implemented") + + +@service_handler +class CancellationTimeoutTestService: + @nexusrpc.handler._decorators.operation_handler + def op_with_cancel_method_that_never_returns( + self, + ) -> OperationHandler[None, None]: + return OperationWithCancelMethodThatNeverReturns() + + +@workflow.defn +class CancellationTimeoutTestCallerWorkflow: + @workflow.init + def __init__(self): + self.nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + service=CancellationTimeoutTestService, + ) + + @workflow.run + async def run(self) -> None: + op_handle = await self.nexus_client.start_operation( + CancellationTimeoutTestService.op_with_cancel_method_that_never_returns, + None, + schedule_to_close_timeout=timedelta(seconds=0.1), + ) + op_handle.cancel() + await op_handle + + +async def test_error_raised_by_timeout_of_nexus_cancel_operation(client: Client): + pytest.skip("TODO(nexus-prerelease): finish writing this test") + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[CancellationTimeoutTestService()], + workflows=[CancellationTimeoutTestCallerWorkflow], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + try: + await client.execute_workflow( + CancellationTimeoutTestCallerWorkflow.run, id=str(uuid.uuid4()), task_queue=task_queue, ) @@ -1538,6 +1602,8 @@ async def test_error_raised_by_timeout_of_nexus_start_operation(client: Client): assert isinstance(err, WorkflowFailureError) assert isinstance(err.__cause__, NexusOperationError) assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail("Expected exception due to timeout of nexus cancel operation") # Test overloads From 5103f807e5c2fafc8466f28352d6b1a22e3152c9 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 12:49:54 -0400 Subject: [PATCH 175/183] Create running_task for cancellation op handler --- temporalio/worker/_nexus.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 1fb597af3..1adb67b3e 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -102,9 +102,7 @@ async def raise_from_exception_queue() -> NoReturn: ) ) elif task.request.HasField("cancel_operation"): - # TODO(nexus-prerelease): do we need to track cancel operation - # tasks as we do start operation tasks? - asyncio.create_task( + self._running_tasks[task.task_token] = asyncio.create_task( self._handle_cancel_operation_task( task.task_token, task.request.cancel_operation, From 03513b597e3d54d005b66b55ac1ae34abd9f2413 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 13:20:55 -0400 Subject: [PATCH 176/183] Test creation of worker from ServiceHandler instances --- ...ynamic_creation_of_user_handler_classes.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index f2d1ec84e..64f9484ee 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -5,6 +5,7 @@ import pytest from nexusrpc.handler import sync_operation +from temporalio import nexus, workflow from temporalio.client import Client from temporalio.nexus._util import get_operation_factory from temporalio.worker import Worker @@ -13,6 +14,55 @@ HTTP_PORT = 7243 +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input: int) -> int: + return input + 1 + + +@nexusrpc.handler.service_handler +class MyServiceHandlerWithWorkflowRunOperation: + @nexus.workflow_run_operation + async def increment( + self, + ctx: nexus.WorkflowRunOperationContext, + input: int, + ) -> nexus.WorkflowHandle[int]: + return await ctx.start_workflow(MyWorkflow.run, input, id=str(uuid.uuid4())) + + +async def test_run_nexus_service_from_programmatically_created_service_handler( + client: Client, +): + task_queue = str(uuid.uuid4()) + + user_service_handler_instance = MyServiceHandlerWithWorkflowRunOperation() + service_handler = nexusrpc.handler._core.ServiceHandler.from_user_instance( + user_service_handler_instance + ) + + assert ( + service_defn := nexusrpc.get_service_definition( + user_service_handler_instance.__class__ + ) + ) + service_name = service_defn.name + + endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id + async with Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + ): + async with httpx.AsyncClient() as http_client: + response = await http_client.post( + f"http://127.0.0.1:{HTTP_PORT}/nexus/endpoints/{endpoint}/services/{service_name}/increment", + json=1, + ) + assert response.status_code == 201 + + def make_incrementer_user_service_definition_and_service_handler_classes( op_names: list[str], ) -> tuple[type, type]: From b6bcd6cdd33444184f24a94b60e3145032e4ff90 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 15:58:00 -0400 Subject: [PATCH 177/183] Test creation of worker from programmatically-created ServiceHandler --- ...ynamic_creation_of_user_handler_classes.py | 74 +++++++++++++++---- 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 64f9484ee..26f94c122 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -21,15 +21,50 @@ async def run(self, input: int) -> int: return input + 1 -@nexusrpc.handler.service_handler -class MyServiceHandlerWithWorkflowRunOperation: - @nexus.workflow_run_operation - async def increment( +@nexusrpc.service +class MyService: + increment: nexusrpc.Operation[int, int] + + +class MyIncrementOperationHandler(nexusrpc.handler.OperationHandler[int, int]): + async def start( self, - ctx: nexus.WorkflowRunOperationContext, + ctx: nexusrpc.handler.StartOperationContext, input: int, - ) -> nexus.WorkflowHandle[int]: - return await ctx.start_workflow(MyWorkflow.run, input, id=str(uuid.uuid4())) + ) -> nexusrpc.handler.StartOperationResultAsync: + wrctx = nexus.WorkflowRunOperationContext.from_start_operation_context(ctx) + wf_handle = await wrctx.start_workflow( + MyWorkflow.run, input, id=str(uuid.uuid4()) + ) + return nexusrpc.handler.StartOperationResultAsync(token=wf_handle.to_token()) + + async def cancel( + self, + ctx: nexusrpc.handler.CancelOperationContext, + token: str, + ) -> None: + raise NotImplementedError + + async def fetch_info( + self, + ctx: nexusrpc.handler.FetchOperationInfoContext, + token: str, + ) -> nexusrpc.OperationInfo: + raise NotImplementedError + + async def fetch_result( + self, + ctx: nexusrpc.handler.FetchOperationResultContext, + token: str, + ) -> int: + raise NotImplementedError + + +@nexusrpc.handler.service_handler +class MyServiceHandlerWithWorkflowRunOperation: + @nexusrpc.handler._decorators.operation_handler + def increment(self) -> nexusrpc.handler.OperationHandler[int, int]: + return MyIncrementOperationHandler() async def test_run_nexus_service_from_programmatically_created_service_handler( @@ -37,17 +72,24 @@ async def test_run_nexus_service_from_programmatically_created_service_handler( ): task_queue = str(uuid.uuid4()) - user_service_handler_instance = MyServiceHandlerWithWorkflowRunOperation() - service_handler = nexusrpc.handler._core.ServiceHandler.from_user_instance( - user_service_handler_instance + service_handler = nexusrpc.handler._core.ServiceHandler( + service=nexusrpc.ServiceDefinition( + name="MyService", + operations={ + "increment": nexusrpc.Operation[int, int]( + name="increment", + method_name="increment", + input_type=int, + output_type=int, + ), + }, + ), + operation_handlers={ + "increment": MyIncrementOperationHandler(), + }, ) - assert ( - service_defn := nexusrpc.get_service_definition( - user_service_handler_instance.__class__ - ) - ) - service_name = service_defn.name + service_name = service_handler.service.name endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id async with Worker( From b5abf2b7be29145704b7515fef1f25d0a4812b93 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 3 Jul 2025 16:01:58 -0400 Subject: [PATCH 178/183] Respond to upstream: new nexus client constructor --- temporalio/contrib/openai_agents/temporal_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py index 7083e0bf2..d8cac9ad5 100644 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ b/temporalio/contrib/openai_agents/temporal_tools.py @@ -181,7 +181,7 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: f"Invalid JSON input for tool {schema.name}: {input}" ) from e - nexus_client = workflow.NexusClient(service=service, endpoint=endpoint) + nexus_client = workflow.create_nexus_client(endpoint=endpoint, service=service) args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data)) assert len(args) == 1, "Nexus operations must have exactly one argument" [arg] = args From f25e108b265cad33d0d687f376353d5d5976d60d Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 8 Jul 2025 09:19:41 -0400 Subject: [PATCH 179/183] function schema adaptation workaround vendor function_schema and use supplied global namespace when resolving type hints --- .../contrib/openai_agents/temporal_tools.py | 192 +++++++++++++++++- 1 file changed, 184 insertions(+), 8 deletions(-) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py index d8cac9ad5..c654f0e84 100644 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ b/temporalio/contrib/openai_agents/temporal_tools.py @@ -1,9 +1,13 @@ """Support for using Temporal activities as OpenAI agents tools.""" +from __future__ import annotations + +import inspect import json -import typing from datetime import timedelta -from typing import Any, Callable, Optional, Type +from typing import Any, Callable, Optional, Type, get_args, get_origin, get_type_hints + +from pydantic import BaseModel, Field, create_model from temporalio import activity, workflow from temporalio.common import Priority, RetryPolicy @@ -13,7 +17,12 @@ with unsafe.imports_passed_through(): from agents import FunctionTool, RunContextWrapper, Tool - from agents.function_schema import function_schema + from agents.function_schema import ( + FuncSchema, + ToolContext, + ensure_strict_json_schema, + generate_func_documentation, + ) class ToolSerializationError(TemporalError): @@ -120,11 +129,12 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: def nexus_operation_as_tool( - fn: Callable, + fn: Callable[..., Any], *, service: Type[Any], endpoint: str, schedule_to_close_timeout: Optional[timedelta] = None, + function_schema_globalns: Optional[dict[str, Any]] = None, ) -> Tool: """Convert a Nexus operation into an OpenAI agent tool. @@ -171,7 +181,10 @@ def nexus_operation_as_tool( "invalid_tool", ) - schema = function_schema(adapt_nexus_operation_function_schema(fn)) + schema = function_schema( + adapt_nexus_operation_function_schema(fn), + globalns=function_schema_globalns, + ) async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: try: @@ -209,10 +222,173 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: def adapt_nexus_operation_function_schema(fn: Callable[..., Any]) -> Callable[..., Any]: # Nexus operation start methods look like # async def operation(self, ctx: StartOperationContext, input: InputType) -> OutputType - _, inputT, retT = typing.get_type_hints(fn).values() + _, input_type, ret_type = fn.__annotations__.values() - def adapted(input: inputT) -> retT: # type: ignore + def adapted(input): pass - adapted.__name__ = fn.__name__ + adapted.__annotations__ = {"input": input_type, "return": ret_type} return adapted + + +def function_schema( + func: Callable[..., Any], + docstring_style: DocstringStyle | None = None, + name_override: str | None = None, + description_override: str | None = None, + use_docstring_info: bool = True, + strict_json_schema: bool = True, + globalns: Optional[dict[str, Any]] = None, +) -> FuncSchema: + """ + Given a python function, extracts a `FuncSchema` from it, capturing the name, description, + parameter descriptions, and other metadata. + + Args: + func: The function to extract the schema from. + docstring_style: The style of the docstring to use for parsing. If not provided, we will + attempt to auto-detect the style. + name_override: If provided, use this name instead of the function's `__name__`. + description_override: If provided, use this description instead of the one derived from the + docstring. + use_docstring_info: If True, uses the docstring to generate the description and parameter + descriptions. + strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that + the schema adheres to the "strict" standard the OpenAI API expects. We **strongly** + recommend setting this to True, as it increases the likelihood of the LLM providing + correct JSON input. + + Returns: + A `FuncSchema` object containing the function's name, description, parameter descriptions, + and other metadata. + """ + + # 1. Grab docstring info + if use_docstring_info: + doc_info = generate_func_documentation(func, docstring_style) + param_descs = doc_info.param_descriptions or {} + else: + doc_info = None + param_descs = {} + + # Ensure name_override takes precedence even if docstring info is disabled. + func_name = name_override or (doc_info.name if doc_info else func.__name__) + + # 2. Inspect function signature and get type hints + sig = inspect.signature(func) + type_hints = get_type_hints(func, globalns=globalns) + params = list(sig.parameters.items()) + takes_context = False + filtered_params = [] + + if params: + first_name, first_param = params[0] + # Prefer the evaluated type hint if available + ann = type_hints.get(first_name, first_param.annotation) + if ann != inspect._empty: + origin = get_origin(ann) or ann + if origin is RunContextWrapper or origin is ToolContext: + takes_context = True # Mark that the function takes context + else: + filtered_params.append((first_name, first_param)) + else: + filtered_params.append((first_name, first_param)) + + # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext. + for name, param in params[1:]: + ann = type_hints.get(name, param.annotation) + if ann != inspect._empty: + origin = get_origin(ann) or ann + if origin is RunContextWrapper or origin is ToolContext: + raise UserError( + f"RunContextWrapper/ToolContext param found at non-first position in function" + f" {func.__name__}" + ) + filtered_params.append((name, param)) + + # We will collect field definitions for create_model as a dict: + # field_name -> (type_annotation, default_value_or_Field(...)) + fields: dict[str, Any] = {} + + for name, param in filtered_params: + ann = type_hints.get(name, param.annotation) + default = param.default + + # If there's no type hint, assume `Any` + if ann == inspect._empty: + ann = Any + + # If a docstring param description exists, use it + field_description = param_descs.get(name, None) + + # Handle different parameter kinds + if param.kind == param.VAR_POSITIONAL: + # e.g. *args: extend positional args + if get_origin(ann) is tuple: + # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int] + args_of_tuple = get_args(ann) + if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis: + ann = list[args_of_tuple[0]] # type: ignore + else: + ann = list[Any] + else: + # If user wrote *args: int, treat as List[int] + ann = list[ann] # type: ignore + + # Default factory to empty list + fields[name] = ( + ann, + Field(default_factory=list, description=field_description), # type: ignore + ) + + elif param.kind == param.VAR_KEYWORD: + # **kwargs handling + if get_origin(ann) is dict: + # e.g. def foo(**kwargs: dict[str, int]) + dict_args = get_args(ann) + if len(dict_args) == 2: + ann = dict[dict_args[0], dict_args[1]] # type: ignore + else: + ann = dict[str, Any] + else: + # e.g. def foo(**kwargs: int) -> Dict[str, int] + ann = dict[str, ann] # type: ignore + + fields[name] = ( + ann, + Field(default_factory=dict, description=field_description), # type: ignore + ) + + else: + # Normal parameter + if default == inspect._empty: + # Required field + fields[name] = ( + ann, + Field(..., description=field_description), + ) + else: + # Parameter with a default value + fields[name] = ( + ann, + Field(default=default, description=field_description), + ) + + # 3. Dynamically build a Pydantic model + dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields) + + # 4. Build JSON schema from that model + json_schema = dynamic_model.model_json_schema() + if strict_json_schema: + json_schema = ensure_strict_json_schema(json_schema) + + # 5. Return as a FuncSchema dataclass + return FuncSchema( + name=func_name, + description=description_override or doc_info.description if doc_info else None, + params_pydantic_model=dynamic_model, + params_json_schema=json_schema, + signature=sig, + takes_context=takes_context, + strict_json_schema=strict_json_schema, + ) From 3bed524d08f1b5a1a21498adf7de3af64c3aa913 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 8 Jul 2025 11:11:47 -0400 Subject: [PATCH 180/183] Retain name on adapted function --- temporalio/contrib/openai_agents/temporal_tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py index c654f0e84..d57e20fe7 100644 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ b/temporalio/contrib/openai_agents/temporal_tools.py @@ -228,6 +228,7 @@ def adapted(input): pass adapted.__annotations__ = {"input": input_type, "return": ret_type} + adapted.__name__ = fn.__name__ return adapted From d8795eb3b60dcc0cfe8c4ad576bf763d13dd478e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 8 Jul 2025 12:06:57 -0400 Subject: [PATCH 181/183] Call nexus operation via contract --- .../contrib/openai_agents/temporal_tools.py | 47 ++++++++----------- tests/contrib/openai_agents/test_openai.py | 17 ++++--- 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py index d57e20fe7..352197a80 100644 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ b/temporalio/contrib/openai_agents/temporal_tools.py @@ -7,12 +7,12 @@ from datetime import timedelta from typing import Any, Callable, Optional, Type, get_args, get_origin, get_type_hints +import nexusrpc from pydantic import BaseModel, Field, create_model from temporalio import activity, workflow from temporalio.common import Priority, RetryPolicy from temporalio.exceptions import ApplicationError, TemporalError -from temporalio.nexus._util import get_operation_factory from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe with unsafe.imports_passed_through(): @@ -129,7 +129,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: def nexus_operation_as_tool( - fn: Callable[..., Any], + operation: nexusrpc.Operation[nexusrpc.InputT, nexusrpc.OutputT], *, service: Type[Any], endpoint: str, @@ -159,30 +159,20 @@ def nexus_operation_as_tool( ApplicationError: If the operation is not properly decorated as a Nexus operation. Example: - >>> @service_handler - >>> class WeatherServiceHandler: - ... @sync_operation - ... async def get_weather_object(self, ctx: StartOperationContext, input: WeatherInput) -> Weather: - ... return Weather( - ... city=input.city, temperature_range="14-20C", conditions="Sunny with wind." - ... ) + >>> @nexusrpc.service + ... class WeatherService: + ... get_weather_object_nexus_operation: nexusrpc.Operation[WeatherInput, Weather] >>> >>> # Create tool with custom activity options >>> tool = nexus_operation_as_tool( - ... WeatherServiceHandler.get_weather_object, - ... service=WeatherServiceHandler, + ... WeatherService.get_weather_object_nexus_operation, + ... service=WeatherService, ... endpoint="weather-service", ... ) >>> # Use tool with an OpenAI agent """ - if not get_operation_factory(fn): - raise ApplicationError( - "Function is not a Nexus operation", - "invalid_tool", - ) - schema = function_schema( - adapt_nexus_operation_function_schema(fn), + adapt_nexus_operation_function_schema(operation), globalns=function_schema_globalns, ) @@ -199,7 +189,7 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: assert len(args) == 1, "Nexus operations must have exactly one argument" [arg] = args result = await nexus_client.execute_operation( - fn, + operation, arg, schedule_to_close_timeout=schedule_to_close_timeout, ) @@ -219,16 +209,17 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: ) -def adapt_nexus_operation_function_schema(fn: Callable[..., Any]) -> Callable[..., Any]: - # Nexus operation start methods look like - # async def operation(self, ctx: StartOperationContext, input: InputType) -> OutputType - _, input_type, ret_type = fn.__annotations__.values() - - def adapted(input): - pass +def adapt_nexus_operation_function_schema( + operation: nexusrpc.Operation[nexusrpc.InputT, nexusrpc.OutputT], +) -> Callable[[nexusrpc.InputT], nexusrpc.OutputT]: + def adapted(input: nexusrpc.InputT) -> nexusrpc.OutputT: + raise NotImplementedError("This function definition is used as a type only") - adapted.__annotations__ = {"input": input_type, "return": ret_type} - adapted.__name__ = fn.__name__ + adapted.__annotations__ = { + "input": operation.input_type, + "return": operation.output_type, + } + adapted.__name__ = operation.name return adapted diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index a87853471..9a6fc60fa 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -4,8 +4,8 @@ from datetime import timedelta from typing import Any, Optional, Union, no_type_check +import nexusrpc import pytest -from nexusrpc.handler import StartOperationContext, service_handler, sync_operation from pydantic import ConfigDict, Field from temporalio import activity, workflow @@ -228,11 +228,16 @@ async def get_weather_object(input: WeatherInput) -> Weather: ) -@service_handler +@nexusrpc.service +class WeatherService: + get_weather_object_nexus_operation: nexusrpc.Operation[WeatherInput, Weather] + + +@nexusrpc.handler.service_handler(service=WeatherService) class WeatherServiceHandler: - @sync_operation + @nexusrpc.handler.sync_operation async def get_weather_object_nexus_operation( - self, ctx: StartOperationContext, input: WeatherInput + self, ctx: nexusrpc.handler.StartOperationContext, input: WeatherInput ) -> Weather: return Weather( city=input.city, temperature_range="14-20C", conditions="Sunny with wind." @@ -337,8 +342,8 @@ async def run(self, question: str) -> str: get_weather_country, start_to_close_timeout=timedelta(seconds=10) ), nexus_operation_as_tool( - WeatherServiceHandler.get_weather_object_nexus_operation, - service=WeatherServiceHandler, + WeatherService.get_weather_object_nexus_operation, + service=WeatherService, endpoint=make_nexus_endpoint_name(workflow.info().task_queue), schedule_to_close_timeout=timedelta(seconds=10), ), From 0a30af84b1defc072f77fbaa19568ae5789a5c00 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 8 Jul 2025 12:29:59 -0400 Subject: [PATCH 182/183] Revert vendoring --- .../contrib/openai_agents/temporal_tools.py | 179 +----------------- 1 file changed, 3 insertions(+), 176 deletions(-) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py index 352197a80..eb4ea8645 100644 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ b/temporalio/contrib/openai_agents/temporal_tools.py @@ -2,13 +2,11 @@ from __future__ import annotations -import inspect import json from datetime import timedelta -from typing import Any, Callable, Optional, Type, get_args, get_origin, get_type_hints +from typing import Any, Callable, Optional, Type import nexusrpc -from pydantic import BaseModel, Field, create_model from temporalio import activity, workflow from temporalio.common import Priority, RetryPolicy @@ -17,12 +15,7 @@ with unsafe.imports_passed_through(): from agents import FunctionTool, RunContextWrapper, Tool - from agents.function_schema import ( - FuncSchema, - ToolContext, - ensure_strict_json_schema, - generate_func_documentation, - ) + from agents.function_schema import function_schema class ToolSerializationError(TemporalError): @@ -171,10 +164,7 @@ def nexus_operation_as_tool( ... ) >>> # Use tool with an OpenAI agent """ - schema = function_schema( - adapt_nexus_operation_function_schema(operation), - globalns=function_schema_globalns, - ) + schema = function_schema(adapt_nexus_operation_function_schema(operation)) async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: try: @@ -221,166 +211,3 @@ def adapted(input: nexusrpc.InputT) -> nexusrpc.OutputT: } adapted.__name__ = operation.name return adapted - - -def function_schema( - func: Callable[..., Any], - docstring_style: DocstringStyle | None = None, - name_override: str | None = None, - description_override: str | None = None, - use_docstring_info: bool = True, - strict_json_schema: bool = True, - globalns: Optional[dict[str, Any]] = None, -) -> FuncSchema: - """ - Given a python function, extracts a `FuncSchema` from it, capturing the name, description, - parameter descriptions, and other metadata. - - Args: - func: The function to extract the schema from. - docstring_style: The style of the docstring to use for parsing. If not provided, we will - attempt to auto-detect the style. - name_override: If provided, use this name instead of the function's `__name__`. - description_override: If provided, use this description instead of the one derived from the - docstring. - use_docstring_info: If True, uses the docstring to generate the description and parameter - descriptions. - strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that - the schema adheres to the "strict" standard the OpenAI API expects. We **strongly** - recommend setting this to True, as it increases the likelihood of the LLM providing - correct JSON input. - - Returns: - A `FuncSchema` object containing the function's name, description, parameter descriptions, - and other metadata. - """ - - # 1. Grab docstring info - if use_docstring_info: - doc_info = generate_func_documentation(func, docstring_style) - param_descs = doc_info.param_descriptions or {} - else: - doc_info = None - param_descs = {} - - # Ensure name_override takes precedence even if docstring info is disabled. - func_name = name_override or (doc_info.name if doc_info else func.__name__) - - # 2. Inspect function signature and get type hints - sig = inspect.signature(func) - type_hints = get_type_hints(func, globalns=globalns) - params = list(sig.parameters.items()) - takes_context = False - filtered_params = [] - - if params: - first_name, first_param = params[0] - # Prefer the evaluated type hint if available - ann = type_hints.get(first_name, first_param.annotation) - if ann != inspect._empty: - origin = get_origin(ann) or ann - if origin is RunContextWrapper or origin is ToolContext: - takes_context = True # Mark that the function takes context - else: - filtered_params.append((first_name, first_param)) - else: - filtered_params.append((first_name, first_param)) - - # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext. - for name, param in params[1:]: - ann = type_hints.get(name, param.annotation) - if ann != inspect._empty: - origin = get_origin(ann) or ann - if origin is RunContextWrapper or origin is ToolContext: - raise UserError( - f"RunContextWrapper/ToolContext param found at non-first position in function" - f" {func.__name__}" - ) - filtered_params.append((name, param)) - - # We will collect field definitions for create_model as a dict: - # field_name -> (type_annotation, default_value_or_Field(...)) - fields: dict[str, Any] = {} - - for name, param in filtered_params: - ann = type_hints.get(name, param.annotation) - default = param.default - - # If there's no type hint, assume `Any` - if ann == inspect._empty: - ann = Any - - # If a docstring param description exists, use it - field_description = param_descs.get(name, None) - - # Handle different parameter kinds - if param.kind == param.VAR_POSITIONAL: - # e.g. *args: extend positional args - if get_origin(ann) is tuple: - # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int] - args_of_tuple = get_args(ann) - if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis: - ann = list[args_of_tuple[0]] # type: ignore - else: - ann = list[Any] - else: - # If user wrote *args: int, treat as List[int] - ann = list[ann] # type: ignore - - # Default factory to empty list - fields[name] = ( - ann, - Field(default_factory=list, description=field_description), # type: ignore - ) - - elif param.kind == param.VAR_KEYWORD: - # **kwargs handling - if get_origin(ann) is dict: - # e.g. def foo(**kwargs: dict[str, int]) - dict_args = get_args(ann) - if len(dict_args) == 2: - ann = dict[dict_args[0], dict_args[1]] # type: ignore - else: - ann = dict[str, Any] - else: - # e.g. def foo(**kwargs: int) -> Dict[str, int] - ann = dict[str, ann] # type: ignore - - fields[name] = ( - ann, - Field(default_factory=dict, description=field_description), # type: ignore - ) - - else: - # Normal parameter - if default == inspect._empty: - # Required field - fields[name] = ( - ann, - Field(..., description=field_description), - ) - else: - # Parameter with a default value - fields[name] = ( - ann, - Field(default=default, description=field_description), - ) - - # 3. Dynamically build a Pydantic model - dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields) - - # 4. Build JSON schema from that model - json_schema = dynamic_model.model_json_schema() - if strict_json_schema: - json_schema = ensure_strict_json_schema(json_schema) - - # 5. Return as a FuncSchema dataclass - return FuncSchema( - name=func_name, - description=description_override or doc_info.description if doc_info else None, - params_pydantic_model=dynamic_model, - params_json_schema=json_schema, - signature=sig, - takes_context=takes_context, - strict_json_schema=strict_json_schema, - ) From 7a8c5f2a90b1b6f00837ab3df77d9dabd392f225 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 8 Jul 2025 14:19:28 -0400 Subject: [PATCH 183/183] Remove now-redundant module namespace parameter --- temporalio/contrib/openai_agents/temporal_tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py index eb4ea8645..a8b2e79c5 100644 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ b/temporalio/contrib/openai_agents/temporal_tools.py @@ -127,7 +127,6 @@ def nexus_operation_as_tool( service: Type[Any], endpoint: str, schedule_to_close_timeout: Optional[timedelta] = None, - function_schema_globalns: Optional[dict[str, Any]] = None, ) -> Tool: """Convert a Nexus operation into an OpenAI agent tool.