Skip to content

Commit ee6f8d7

Browse files
authored
Instantiate interceptors later in workflow instance construction so that more variables are available. Notably, allows signal registration to work (#887)
1 parent 520aefd commit ee6f8d7

File tree

2 files changed

+48
-18
lines changed

2 files changed

+48
-18
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -301,24 +301,6 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
301301
str, List[temporalio.bridge.proto.workflow_activation.SignalWorkflow]
302302
] = {}
303303

304-
# Create interceptors. We do this with our runtime on the loop just in
305-
# case they want to access info() during init().
306-
temporalio.workflow._Runtime.set_on_loop(asyncio.get_running_loop(), self)
307-
try:
308-
root_inbound = _WorkflowInboundImpl(self)
309-
self._inbound: WorkflowInboundInterceptor = root_inbound
310-
for interceptor_class in reversed(list(det.interceptor_classes)):
311-
self._inbound = interceptor_class(self._inbound)
312-
# During init we set ourselves on the current loop
313-
self._inbound.init(_WorkflowOutboundImpl(self))
314-
self._outbound = root_inbound._outbound
315-
finally:
316-
# Remove our runtime from the loop
317-
temporalio.workflow._Runtime.set_on_loop(asyncio.get_running_loop(), None)
318-
319-
# Set ourselves on our own loop
320-
temporalio.workflow._Runtime.set_on_loop(self, self)
321-
322304
# When we evict, we have to mark the workflow as deleting so we don't
323305
# add any commands and we swallow exceptions on tear down
324306
self._deleting = False
@@ -342,6 +324,24 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
342324
Sequence[type[BaseException]]
343325
] = None
344326

327+
# Create interceptors. We do this with our runtime on the loop just in
328+
# case they want to access info() during init(). This should remain at the end of the constructor so that variables are defined during interceptor creation
329+
temporalio.workflow._Runtime.set_on_loop(asyncio.get_running_loop(), self)
330+
try:
331+
root_inbound = _WorkflowInboundImpl(self)
332+
self._inbound: WorkflowInboundInterceptor = root_inbound
333+
for interceptor_class in reversed(list(det.interceptor_classes)):
334+
self._inbound = interceptor_class(self._inbound)
335+
# During init we set ourselves on the current loop
336+
self._inbound.init(_WorkflowOutboundImpl(self))
337+
self._outbound = root_inbound._outbound
338+
finally:
339+
# Remove our runtime from the loop
340+
temporalio.workflow._Runtime.set_on_loop(asyncio.get_running_loop(), None)
341+
342+
# Set ourselves on our own loop
343+
temporalio.workflow._Runtime.set_on_loop(self, self)
344+
345345
def get_thread_id(self) -> Optional[int]:
346346
return self._current_thread_id
347347

tests/worker/test_workflow.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7973,3 +7973,33 @@ async def test_quick_activity_swallows_cancellation(client: Client):
79737973
assert cause.message == "Workflow cancelled"
79747974

79757975
temporalio.worker._workflow_instance._raise_on_cancelling_completed_activity_override = False
7976+
7977+
7978+
class SignalInterceptor(temporalio.worker.Interceptor):
7979+
def workflow_interceptor_class(
7980+
self, input: temporalio.worker.WorkflowInterceptorClassInput
7981+
) -> Type[SignalInboundInterceptor]:
7982+
return SignalInboundInterceptor
7983+
7984+
7985+
class SignalInboundInterceptor(temporalio.worker.WorkflowInboundInterceptor):
7986+
def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
7987+
def unblock() -> None:
7988+
return None
7989+
7990+
workflow.set_signal_handler("my_random_signal", unblock)
7991+
super().init(outbound)
7992+
7993+
7994+
async def test_signal_handler_in_interceptor(client: Client):
7995+
async with new_worker(
7996+
client,
7997+
HelloWorkflow,
7998+
interceptors=[SignalInterceptor()],
7999+
) as worker:
8000+
await client.execute_workflow(
8001+
HelloWorkflow.run,
8002+
"Temporal",
8003+
id=f"workflow-{uuid.uuid4()}",
8004+
task_queue=worker.task_queue,
8005+
)

0 commit comments

Comments
 (0)