Skip to content

Commit 38e2056

Browse files
committed
RTU: Copy operation factory getter/setter from nexusrpc
1 parent 42094c4 commit 38e2056

File tree

4 files changed

+39
-5
lines changed

4 files changed

+39
-5
lines changed

temporalio/nexus/_util.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,40 @@ def get_callable_name(fn: Callable[..., Any]) -> str:
125125
return method_name
126126

127127

128+
# TODO(nexus-preview) Copied from nexusrpc
129+
def get_operation_factory(
130+
obj: Any,
131+
) -> tuple[
132+
Optional[Callable[[Any], OperationHandler[InputT, OutputT]]],
133+
Optional[nexusrpc.Operation[InputT, OutputT]],
134+
]:
135+
"""Return the :py:class:`Operation` for the object along with the factory function.
136+
137+
``obj`` should be a decorated operation start method.
138+
"""
139+
op_defn = get_operation_definition(obj)
140+
if op_defn:
141+
factory = obj
142+
else:
143+
if factory := getattr(obj, "__nexus_operation_factory__", None):
144+
op_defn = get_operation_definition(factory)
145+
if not isinstance(op_defn, nexusrpc.Operation):
146+
return None, None
147+
return factory, op_defn
148+
149+
150+
# TODO(nexus-preview) Copied from nexusrpc
151+
def set_operation_factory(
152+
obj: Any,
153+
operation_factory: Callable[[Any], OperationHandler[InputT, OutputT]],
154+
) -> None:
155+
"""Set the :py:class:`OperationHandler` factory for this object.
156+
157+
``obj`` should be an operation start method.
158+
"""
159+
setattr(obj, "__nexus_operation_factory__", operation_factory)
160+
161+
128162
# Copied from https://github.com/modelcontextprotocol/python-sdk
129163
#
130164
# Copyright (c) 2024 Anthropic, PBC.

temporalio/worker/_interceptor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import temporalio.api.common.v1
2626
import temporalio.common
2727
import temporalio.nexus
28+
import temporalio.nexus._util
2829
import temporalio.workflow
2930
from temporalio.workflow import VersioningIntent
3031

@@ -313,7 +314,7 @@ def __post_init__(self) -> None:
313314
self._operation_name = self.operation
314315
self._input_type = None
315316
elif isinstance(self.operation, Callable):
316-
_, op = nexusrpc.get_operation_factory(self.operation)
317+
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
317318
if isinstance(op, nexusrpc.Operation):
318319
self._operation_name = op.name
319320
self._input_type = op.input_type

tests/nexus/test_dynamic_creation_of_user_handler_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import httpx
44
import nexusrpc.handler
55
import pytest
6-
from nexusrpc import get_operation_factory
76
from nexusrpc.handler import sync_operation
87

98
from temporalio.client import Client
9+
from temporalio.nexus._util import get_operation_factory
1010
from temporalio.worker import Worker
1111
from tests.helpers.nexus import create_nexus_endpoint
1212

tests/nexus/test_handler_operation_definitions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from temporalio import nexus
1313
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
14+
from temporalio.nexus._util import get_operation_factory
1415

1516

1617
@dataclass
@@ -92,9 +93,7 @@ async def test_collected_operation_names(
9293
assert isinstance(service_defn, nexusrpc.ServiceDefinition)
9394
assert service_defn.name == "Service"
9495
for method_name, expected_op in test_case.expected_operations.items():
95-
_, actual_op = nexusrpc.get_operation_factory(
96-
getattr(test_case.Service, method_name)
97-
)
96+
_, actual_op = get_operation_factory(getattr(test_case.Service, method_name))
9897
assert isinstance(actual_op, nexusrpc.Operation)
9998
assert actual_op.name == expected_op.name
10099
assert actual_op.input_type == expected_op.input_type

0 commit comments

Comments
 (0)