Skip to content

Commit 5480eb4

Browse files
committed
Make WorkflowRunOperationContext subclass StartOperationContext
1 parent a6bf7cb commit 5480eb4

File tree

4 files changed

+26
-10
lines changed

4 files changed

+26
-10
lines changed

temporalio/nexus/_decorators.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818

1919
from temporalio.nexus._operation_context import (
20-
TemporalStartOperationContext,
2120
WorkflowRunOperationContext,
2221
)
2322
from temporalio.nexus._operation_handlers import (
@@ -115,8 +114,11 @@ def operation_handler_factory(
115114
async def _start(
116115
ctx: StartOperationContext, input: InputT
117116
) -> WorkflowHandle[OutputT]:
118-
tctx = TemporalStartOperationContext.get()
119-
return await start(self, WorkflowRunOperationContext(tctx), input)
117+
return await start(
118+
self,
119+
WorkflowRunOperationContext.from_start_operation_context(ctx),
120+
input,
121+
)
120122

121123
_start.__doc__ = start.__doc__
122124
return WorkflowRunOperationHandler(_start, input_type, output_type)

temporalio/nexus/_operation_context.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
import logging
45
import re
56
import urllib.parse
@@ -168,17 +169,28 @@ def add_outbound_links(
168169
return workflow_handle
169170

170171

171-
@dataclass
172-
class WorkflowRunOperationContext:
173-
temporal_context: TemporalStartOperationContext
172+
@dataclass(frozen=True)
173+
class WorkflowRunOperationContext(StartOperationContext):
174+
_temporal_context: Optional[TemporalStartOperationContext] = None
175+
176+
@property
177+
def temporal_context(self) -> TemporalStartOperationContext:
178+
if not self._temporal_context:
179+
raise RuntimeError("Temporal context not set")
180+
return self._temporal_context
174181

175182
@property
176183
def nexus_context(self) -> StartOperationContext:
177184
return self.temporal_context.nexus_context
178185

179186
@classmethod
180-
def get(cls) -> WorkflowRunOperationContext:
181-
return cls(TemporalStartOperationContext.get())
187+
def from_start_operation_context(
188+
cls, ctx: StartOperationContext
189+
) -> WorkflowRunOperationContext:
190+
return cls(
191+
_temporal_context=TemporalStartOperationContext.get(),
192+
**{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)},
193+
)
182194

183195
# Overload for single-param workflow
184196
# TODO(nexus-prerelease): bring over other overloads

tests/nexus/test_workflow_caller.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ async def start(
160160
# TODO(nexus-preview): what do we want the DX to be for a user who is
161161
# starting a Nexus backing workflow from a custom start method? (They may
162162
# need to do this in order to customize the cancel method).
163-
handle = await WorkflowRunOperationContext.get().start_workflow(
163+
tctx = WorkflowRunOperationContext.from_start_operation_context(ctx)
164+
handle = await tctx.start_workflow(
164165
HandlerWorkflow.run,
165166
HandlerWfInput(op_input=input),
166167
id=input.response_type.operation_workflow_id,

tests/nexus/test_workflow_run_operation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def __init__(self):
4949
async def start(
5050
self, ctx: StartOperationContext, input: Input
5151
) -> StartOperationResultAsync:
52-
handle = await WorkflowRunOperationContext.get().start_workflow(
52+
tctx = WorkflowRunOperationContext.from_start_operation_context(ctx)
53+
handle = await tctx.start_workflow(
5354
EchoWorkflow.run,
5455
input.value,
5556
id=str(uuid.uuid4()),

0 commit comments

Comments
 (0)