diff --git a/temporalio/nexus/_link_conversion.py b/temporalio/nexus/_link_conversion.py new file mode 100644 index 000000000..87027333b --- /dev/null +++ b/temporalio/nexus/_link_conversion.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import logging +import re +import urllib.parse +from typing import ( + Any, + Optional, +) + +import nexusrpc + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.client + +logger = logging.getLogger(__name__) + +_LINK_URL_PATH_REGEX = re.compile( + r"^/namespaces/(?P[^/]+)/workflows/(?P[^/]+)/(?P[^/]+)/history$" +) +LINK_EVENT_ID_PARAM_NAME = "eventID" +LINK_EVENT_TYPE_PARAM_NAME = "eventType" + + +def workflow_handle_to_workflow_execution_started_event_link( + handle: temporalio.client.WorkflowHandle[Any, Any], +) -> temporalio.api.common.v1.Link.WorkflowEvent: + """Create a WorkflowEvent link corresponding to a started workflow""" + if handle.first_execution_run_id is None: + raise ValueError( + f"Workflow handle {handle} has no first execution run ID. " + f"Cannot create WorkflowExecutionStarted event link." + ) + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=handle._client.namespace, + workflow_id=handle.id, + run_id=handle.first_execution_run_id, + event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + event_id=1, + event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, + ), + # TODO(nexus-preview): RequestIdReference + ) + + +def workflow_event_to_nexus_link( + workflow_event: temporalio.api.common.v1.Link.WorkflowEvent, +) -> nexusrpc.Link: + """Convert a WorkflowEvent link into a nexusrpc link + + Used when propagating links from a StartWorkflow response to a Nexus start operation + response. + """ + scheme = "temporal" + namespace = urllib.parse.quote(workflow_event.namespace) + workflow_id = urllib.parse.quote(workflow_event.workflow_id) + run_id = urllib.parse.quote(workflow_event.run_id) + path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history" + query_params = _event_reference_to_query_params(workflow_event.event_ref) + return nexusrpc.Link( + url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")), + type=workflow_event.DESCRIPTOR.full_name, + ) + + +def nexus_link_to_workflow_event( + link: nexusrpc.Link, +) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]: + """Convert a nexus link into a WorkflowEvent link + + This is used when propagating links from a Nexus start operation request to a + StartWorklow request. + """ + url = urllib.parse.urlparse(link.url) + match = _LINK_URL_PATH_REGEX.match(url.path) + if not match: + logger.warning( + f"Invalid Nexus link: {link}. Expected path to match {_LINK_URL_PATH_REGEX.pattern}" + ) + return None + try: + event_ref = _query_params_to_event_reference(url.query) + except ValueError as err: + logger.warning( + f"Failed to parse event reference from Nexus link URL query parameters: {link} ({err})" + ) + return None + + groups = match.groupdict() + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=urllib.parse.unquote(groups["namespace"]), + workflow_id=urllib.parse.unquote(groups["workflow_id"]), + run_id=urllib.parse.unquote(groups["run_id"]), + event_ref=event_ref, + ) + + +def _event_reference_to_query_params( + event_ref: temporalio.api.common.v1.Link.WorkflowEvent.EventReference, +) -> str: + event_type_name = temporalio.api.enums.v1.EventType.Name(event_ref.event_type) + if event_type_name.startswith("EVENT_TYPE_"): + event_type_name = _event_type_constant_case_to_pascal_case( + event_type_name.removeprefix("EVENT_TYPE_") + ) + return urllib.parse.urlencode( + { + "eventID": event_ref.event_id, + "eventType": event_type_name, + "referenceType": "EventReference", + } + ) + + +def _query_params_to_event_reference( + raw_query_params: str, +) -> temporalio.api.common.v1.Link.WorkflowEvent.EventReference: + """Return an EventReference from the query params or raise ValueError.""" + query_params = urllib.parse.parse_qs(raw_query_params) + + [reference_type] = query_params.get("referenceType") or [""] + if reference_type != "EventReference": + raise ValueError( + f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}" + ) + # event type + [raw_event_type_name] = query_params.get(LINK_EVENT_TYPE_PARAM_NAME) or [""] + if not raw_event_type_name: + raise ValueError(f"query params do not contain event type: {query_params}") + if raw_event_type_name.startswith("EVENT_TYPE_"): + event_type_name = raw_event_type_name + elif re.match("[A-Z][a-z]", raw_event_type_name): + event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case( + raw_event_type_name + ) + else: + raise ValueError(f"Invalid event type name: {raw_event_type_name}") + + # event id + event_id = 0 + [raw_event_id] = query_params.get(LINK_EVENT_ID_PARAM_NAME) or [""] + if raw_event_id: + try: + event_id = int(raw_event_id) + except ValueError: + raise ValueError(f"Query params contain invalid event id: {raw_event_id}") + + return temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + event_type=temporalio.api.enums.v1.EventType.Value(event_type_name), + event_id=event_id, + ) + + +def _event_type_constant_case_to_pascal_case(s: str) -> str: + """Convert a CONSTANT_CASE string to PascalCase. + + >>> _event_type_constant_case_to_pascal_case("NEXUS_OPERATION_SCHEDULED") + "NexusOperationScheduled" + """ + return re.sub(r"(\b|_)([a-z])", lambda m: m.groups()[1].upper(), s.lower()) + + +def _event_type_pascal_case_to_constant_case(s: str) -> str: + """Convert a PascalCase string to CONSTANT_CASE. + + >>> _event_type_pascal_case_to_constant_case("NexusOperationScheduled") + "NEXUS_OPERATION_SCHEDULED" + """ + return re.sub(r"([A-Z])", r"_\1", s).lstrip("_").upper() diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index ce129e4da..52e6f7b1d 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -2,8 +2,6 @@ import dataclasses import logging -import re -import urllib.parse from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta @@ -20,14 +18,13 @@ overload, ) -import nexusrpc.handler from nexusrpc.handler import CancelOperationContext, StartOperationContext from typing_extensions import Concatenate import temporalio.api.common.v1 -import temporalio.api.enums.v1 import temporalio.client import temporalio.common +from temporalio.nexus import _link_conversion from temporalio.nexus._token import WorkflowHandle from temporalio.types import ( MethodAsyncNoParam, @@ -128,11 +125,6 @@ def _get_callbacks( ctx = self.nexus_context return ( [ - # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus - # request, it needs to copy the links to the callback in - # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links - # (for backwards compatibility). PR reference in Go SDK: - # https://github.com/temporalio/sdk-go/pull/1945 temporalio.client.NexusCallback( url=ctx.callback_url, headers=ctx.callback_headers, @@ -147,7 +139,7 @@ def _get_workflow_event_links( ) -> list[temporalio.api.common.v1.Link.WorkflowEvent]: event_links = [] for inbound_link in self.nexus_context.inbound_links: - if link := _nexus_link_to_workflow_event(inbound_link): + if link := _link_conversion.nexus_link_to_workflow_event(inbound_link): event_links.append(link) return event_links @@ -155,8 +147,8 @@ def _add_outbound_links( self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] ): try: - link = _workflow_event_to_nexus_link( - _workflow_handle_to_workflow_execution_started_event_link( + link = _link_conversion.workflow_event_to_nexus_link( + _link_conversion.workflow_handle_to_workflow_execution_started_event_link( workflow_handle ) ) @@ -479,91 +471,6 @@ def set(self) -> None: _temporal_cancel_operation_context.set(self) -def _workflow_handle_to_workflow_execution_started_event_link( - handle: temporalio.client.WorkflowHandle[Any, Any], -) -> temporalio.api.common.v1.Link.WorkflowEvent: - if handle.first_execution_run_id is None: - raise ValueError( - f"Workflow handle {handle} has no first execution run ID. " - "Cannot create WorkflowExecutionStarted event link." - ) - return temporalio.api.common.v1.Link.WorkflowEvent( - namespace=handle._client.namespace, - workflow_id=handle.id, - run_id=handle.first_execution_run_id, - event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( - event_id=1, - event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, - ), - # TODO(nexus-prerelease): RequestIdReference? - ) - - -def _workflow_event_to_nexus_link( - workflow_event: temporalio.api.common.v1.Link.WorkflowEvent, -) -> nexusrpc.Link: - scheme = "temporal" - namespace = urllib.parse.quote(workflow_event.namespace) - workflow_id = urllib.parse.quote(workflow_event.workflow_id) - run_id = urllib.parse.quote(workflow_event.run_id) - path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history" - query_params = urllib.parse.urlencode( - { - "eventType": temporalio.api.enums.v1.EventType.Name( - workflow_event.event_ref.event_type - ), - "referenceType": "EventReference", - } - ) - return nexusrpc.Link( - url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")), - type=workflow_event.DESCRIPTOR.full_name, - ) - - -_LINK_URL_PATH_REGEX = re.compile( - r"^/namespaces/(?P[^/]+)/workflows/(?P[^/]+)/(?P[^/]+)/history$" -) - - -def _nexus_link_to_workflow_event( - link: nexusrpc.Link, -) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]: - url = urllib.parse.urlparse(link.url) - match = _LINK_URL_PATH_REGEX.match(url.path) - if not match: - logger.warning( - f"Invalid Nexus link: {link}. Expected path to match {_LINK_URL_PATH_REGEX.pattern}" - ) - return None - try: - query_params = urllib.parse.parse_qs(url.query) - [reference_type] = query_params.get("referenceType", []) - if reference_type != "EventReference": - raise ValueError( - f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}" - ) - [event_type_name] = query_params.get("eventType", []) - event_ref = temporalio.api.common.v1.Link.WorkflowEvent.EventReference( - # TODO(nexus-prerelease): confirm that it is correct not to use event_id. - # Should the proto say explicitly that it's optional or how it behaves when it's missing? - event_type=temporalio.api.enums.v1.EventType.Value(event_type_name) - ) - except ValueError as err: - logger.warning( - f"Failed to parse event type from Nexus link URL query parameters: {link} ({err})" - ) - event_ref = None - - groups = match.groupdict() - return temporalio.api.common.v1.Link.WorkflowEvent( - namespace=urllib.parse.unquote(groups["namespace"]), - workflow_id=urllib.parse.unquote(groups["workflow_id"]), - run_id=urllib.parse.unquote(groups["run_id"]), - event_ref=event_ref, - ) - - class LoggerAdapter(logging.LoggerAdapter): """Logger adapter that adds Nexus operation context information.""" diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index 3740c853e..6e8c62e59 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -1,6 +1,7 @@ import dataclasses from dataclasses import dataclass from typing import Any, Mapping, Optional +from urllib.parse import urlparse import temporalio.api.failure.v1 import temporalio.api.nexus.v1 @@ -8,6 +9,7 @@ import temporalio.workflow from temporalio.client import Client from temporalio.converter import FailureConverter, PayloadConverter +from temporalio.testing import WorkflowEnvironment with temporalio.workflow.unsafe.imports_passed_through(): import httpx @@ -58,7 +60,7 @@ async def start_operation( # TODO(nexus-preview): Support callback URL as query param async with httpx.AsyncClient() as http_client: return await http_client.post( - f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}", + f"http://{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}", json=body, headers=headers, ) @@ -70,11 +72,20 @@ async def cancel_operation( ) -> httpx.Response: async with httpx.AsyncClient() as http_client: return await http_client.post( - f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/cancel", + f"http://{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/cancel", # Token can also be sent as "Nexus-Operation-Token" header params={"token": token}, ) + @staticmethod + def default_server_address(env: WorkflowEnvironment) -> str: + # TODO(nexus-preview): nexus tests are making http requests directly but this is + # not officially supported. + parsed = urlparse(env.client.service_client.config.target_host) + host = parsed.hostname or "127.0.0.1" + http_port = getattr(env, "_http_port", 7243) + return f"{host}:{http_port}" + def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: """ diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 8177d8b27..0eef14b84 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -10,9 +10,7 @@ from temporalio.nexus._util import get_operation_factory from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import create_nexus_endpoint - -HTTP_PORT = 7243 +from tests.helpers.nexus import ServiceClient, create_nexus_endpoint @workflow.defn @@ -102,9 +100,10 @@ async def test_run_nexus_service_from_programmatically_created_service_handler( task_queue=task_queue, nexus_service_handlers=[service_handler], ): + server_address = ServiceClient.default_server_address(env) async with httpx.AsyncClient() as http_client: response = await http_client.post( - f"http://127.0.0.1:{HTTP_PORT}/nexus/endpoints/{endpoint}/services/{service_name}/increment", + f"http://{server_address}/nexus/endpoints/{endpoint}/services/{service_name}/increment", json=1, ) assert response.status_code == 201 @@ -147,7 +146,9 @@ async def _increment_op( @pytest.mark.skip( reason="Dynamic creation of service contract using type() is not supported" ) -async def test_dynamic_creation_of_user_handler_classes(client: Client): +async def test_dynamic_creation_of_user_handler_classes( + client: Client, env: WorkflowEnvironment +): task_queue = str(uuid.uuid4()) service_cls, handler_cls = ( @@ -165,9 +166,10 @@ async def test_dynamic_creation_of_user_handler_classes(client: Client): task_queue=task_queue, nexus_service_handlers=[handler_cls()], ): + server_address = ServiceClient.default_server_address(env) async with httpx.AsyncClient() as http_client: response = await http_client.post( - f"http://127.0.0.1:{HTTP_PORT}/nexus/endpoints/{endpoint}/services/{service_name}/increment", + f"http://{server_address}/nexus/endpoints/{endpoint}/services/{service_name}/increment", json=1, ) assert response.status_code == 200 diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 07378fdc4..2d0e59908 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -62,8 +62,6 @@ dataclass_as_dict, ) -HTTP_PORT = 7243 - @dataclass class Input: @@ -622,7 +620,7 @@ async def _test_start_operation_with_service_definition( task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=(test_case.service_defn), ) @@ -656,7 +654,7 @@ async def _test_start_operation_without_service_definition( task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=MyServiceHandler.__name__, ) @@ -744,7 +742,7 @@ async def test_start_operation_without_type_annotations( task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=MyServiceWithOperationsWithoutTypeAnnotations.__name__, ) @@ -791,7 +789,7 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A resp = await create_nexus_endpoint(task_queue, env.client) endpoint = resp.endpoint.id service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=service_name, ) @@ -950,7 +948,7 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=MyService.__name__, ) @@ -982,7 +980,7 @@ async def test_request_id_is_received_by_sync_operation( task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=MyService.__name__, ) @@ -1056,7 +1054,7 @@ async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnviron task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=ServiceHandlerForRequestIdTest.__name__, ) @@ -1124,8 +1122,3 @@ async def start_two_workflows_in_a_single_operation( await start_two_workflows_in_a_single_operation( request_id_1, 500, "Workflow execution already started" ) - - -def server_address(env: WorkflowEnvironment) -> str: - http_port = getattr(env, "_http_port", 7243) - return f"http://127.0.0.1:{http_port}" diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py index 808e93b3c..df245d0ff 100644 --- a/tests/nexus/test_handler_async_operation.py +++ b/tests/nexus/test_handler_async_operation.py @@ -151,7 +151,7 @@ async def test_async_operation_lifecycle( task_queue = str(uuid.uuid4()) endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id service_client = ServiceClient( - f"http://127.0.0.1:{env._http_port}", # type: ignore + ServiceClient.default_server_address(env), endpoint, service_handler_cls.__name__, ) diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py index 89e7da2bd..8db3c7ddc 100644 --- a/tests/nexus/test_handler_interface_implementation.py +++ b/tests/nexus/test_handler_interface_implementation.py @@ -8,8 +8,6 @@ from temporalio import nexus from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation -HTTP_PORT = 7243 - class _InterfaceImplementationTestCase: Interface: Type[Any] diff --git a/tests/nexus/test_link_conversion.py b/tests/nexus/test_link_conversion.py new file mode 100644 index 000000000..4170f515b --- /dev/null +++ b/tests/nexus/test_link_conversion.py @@ -0,0 +1,92 @@ +import urllib.parse +from typing import Any + +import pytest + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.nexus._link_conversion + + +@pytest.mark.parametrize( + ["query_param_str", "expected_event_ref"], + [ + ( + "eventType=NexusOperationScheduled&referenceType=EventReference&eventID=7", + { + "event_type": temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + "event_id": 7, + }, + ), + # event ID is optional in query params; we set it to 0 in the event ref if missing + ( + "eventType=NexusOperationScheduled&referenceType=EventReference", + { + "event_type": temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + "event_id": 0, + }, + ), + # Older server sends EVENT_TYPE_CONSTANT_CASE event type name + ( + "eventType=EVENT_TYPE_NEXUS_OPERATION_SCHEDULED&referenceType=EventReference", + { + "event_type": temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + "event_id": 0, + }, + ), + ], +) +def test_query_params_to_event_reference( + query_param_str: str, expected_event_ref: dict[str, Any] +): + event_ref = temporalio.nexus._link_conversion._query_params_to_event_reference( + query_param_str + ) + for k, v in expected_event_ref.items(): + assert getattr(event_ref, k) == v + + +@pytest.mark.parametrize( + ["event_ref", "expected_query_param_str"], + [ + # We always send PascalCase event type names (no EventType prefix) + ( + { + "event_type": temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + "event_id": 7, + }, + "eventType=NexusOperationScheduled&referenceType=EventReference&eventID=7", + ), + ], +) +def test_event_reference_to_query_params( + event_ref: dict[str, Any], expected_query_param_str: str +): + query_params_str = ( + temporalio.nexus._link_conversion._event_reference_to_query_params( + temporalio.api.common.v1.Link.WorkflowEvent.EventReference(**event_ref) + ) + ) + query_params = urllib.parse.parse_qs(query_params_str) + expected_query_params = urllib.parse.parse_qs(expected_query_param_str) + assert query_params == expected_query_params + + +def test_link_conversion_utilities(): + p2c = temporalio.nexus._link_conversion._event_type_pascal_case_to_constant_case + c2p = temporalio.nexus._link_conversion._event_type_constant_case_to_pascal_case + + for p, c in [ + ("", ""), + ("A", "A"), + ("Ab", "AB"), + ("AbCd", "AB_CD"), + ("AbCddE", "AB_CDD_E"), + ("ContainsAOneLetterWord", "CONTAINS_A_ONE_LETTER_WORD"), + ("NexusOperationScheduled", "NEXUS_OPERATION_SCHEDULED"), + ]: + assert p2c(p) == c + assert c2p(c) == p + + assert p2c("a") == "A" + assert c2p("A") == "A" diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 9d66238f5..c9417ef58 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -23,14 +23,9 @@ from nexusrpc.handler._decorators import operation_handler import temporalio.api -import temporalio.api.common import temporalio.api.common.v1 import temporalio.api.enums.v1 -import temporalio.api.nexus -import temporalio.api.nexus.v1 -import temporalio.api.operatorservice -import temporalio.api.operatorservice.v1 -import temporalio.exceptions +import temporalio.api.history.v1 import temporalio.nexus._operation_handlers from temporalio import nexus, workflow from temporalio.client import ( @@ -430,6 +425,79 @@ async def run( # +async def test_sync_operation_happy_path(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[ServiceImpl()], + workflows=[CallerWorkflow, HandlerWorkflow], + task_queue=task_queue, + workflow_failure_exception_types=[Exception], + ): + await create_nexus_endpoint(task_queue, client) + wf_output = await client.execute_workflow( + CallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=SyncResponse( + op_definition_type=OpDefinitionType.SHORTHAND, + use_async_def=True, + exception_in_operation_start=False, + ), + headers={}, + caller_reference=CallerReference.IMPL_WITH_INTERFACE, + ), + ), + False, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert wf_output.op_output.value == "sync response" + + +async def test_workflow_run_operation_happy_path( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[ServiceImpl()], + workflows=[CallerWorkflow, HandlerWorkflow], + task_queue=task_queue, + workflow_failure_exception_types=[Exception], + ): + await create_nexus_endpoint(task_queue, client) + wf_output = await client.execute_workflow( + CallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=AsyncResponse( + operation_workflow_id=str(uuid.uuid4()), + block_forever_waiting_for_cancellation=False, + op_definition_type=OpDefinitionType.SHORTHAND, + exception_in_operation_start=False, + ), + headers={}, + caller_reference=CallerReference.IMPL_WITH_INTERFACE, + ), + ), + False, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert wf_output.op_output.value == "workflow result" + + # TODO(nexus-preview): cross-namespace tests # TODO(nexus-preview): nexus endpoint pytest fixture? # TODO(nexus-prerelease): test headers @@ -568,12 +636,12 @@ async def test_async_response( WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.COMPLETED, ] - await assert_caller_workflow_has_link_to_handler_workflow( - caller_wf_handle, handler_wf_handle, handler_wf_info.run_id - ) await assert_handler_workflow_has_link_to_caller_workflow( caller_wf_handle, handler_wf_handle ) + await assert_caller_workflow_has_link_to_handler_workflow( + caller_wf_handle, handler_wf_handle, handler_wf_info.run_id + ) if request_cancel: # The operation response was asynchronous and so request_cancel is honored. See @@ -1047,11 +1115,12 @@ async def assert_handler_workflow_has_link_to_caller_workflow( == temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ) ) - if not len(wf_started_event.links) == 1: + links = _get_links_from_workflow_execution_started_event(wf_started_event) + if not len(links) == 1: pytest.fail( - f"Expected 1 link on WorkflowExecutionStarted event, got {len(wf_started_event.links)}" + f"Expected 1 link on WorkflowExecutionStarted event, got {len(links)}" ) - [link] = wf_started_event.links + [link] = links assert link.workflow_event.namespace == caller_wf_handle._client.namespace assert link.workflow_event.workflow_id == caller_wf_handle.id assert link.workflow_event.run_id @@ -1062,6 +1131,16 @@ async def assert_handler_workflow_has_link_to_caller_workflow( ) +def _get_links_from_workflow_execution_started_event( + event: temporalio.api.history.v1.HistoryEvent, +) -> list[temporalio.api.common.v1.Link]: + [callback] = event.workflow_execution_started_event_attributes.completion_callbacks + if links := callback.links: + return list(links) + else: + return list(event.links) + + # When request_cancel is True, the NexusOperationHandle in the workflow evolves # through the following states: # start_fut result_fut handle_task w/ fut_waiter (task._must_cancel) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index b784893f4..0869a1d00 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -26,8 +26,6 @@ dataclass_as_dict, ) -HTTP_PORT = 7243 - @dataclass class Input: @@ -97,7 +95,7 @@ async def test_workflow_run_operation( endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id assert (service_defn := nexusrpc.get_service_definition(service_handler_cls)) service_client = ServiceClient( - server_address=server_address(env), + server_address=ServiceClient.default_server_address(env), endpoint=endpoint, service=service_defn.name, ) @@ -117,8 +115,3 @@ async def test_workflow_run_operation( assert re.search(message, failure.message) else: assert resp.status_code == 201 - - -def server_address(env: WorkflowEnvironment) -> str: - http_port = getattr(env, "_http_port", 7243) - return f"http://127.0.0.1:{http_port}"