Skip to content

Commit 044b1de

Browse files
authored
Add workflow.instance() API for obtaining current workflow instance (#739)
1 parent 150878f commit 044b1de

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,9 @@ def workflow_get_update_validator(self, name: Optional[str]) -> Optional[Callabl
10231023
def workflow_info(self) -> temporalio.workflow.Info:
10241024
return self._outbound.info()
10251025

1026+
def workflow_instance(self) -> Any:
1027+
return self._object
1028+
10261029
def workflow_is_continue_as_new_suggested(self) -> bool:
10271030
return self._continue_as_new_suggested
10281031

temporalio/workflow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,9 @@ def workflow_get_update_validator(
625625
@abstractmethod
626626
def workflow_info(self) -> Info: ...
627627

628+
@abstractmethod
629+
def workflow_instance(self) -> Any: ...
630+
628631
@abstractmethod
629632
def workflow_is_continue_as_new_suggested(self) -> bool: ...
630633

@@ -818,6 +821,15 @@ def info() -> Info:
818821
return _Runtime.current().workflow_info()
819822

820823

824+
def instance() -> Any:
825+
"""Current workflow's instance.
826+
827+
Returns:
828+
The currently running workflow instance.
829+
"""
830+
return _Runtime.current().workflow_instance()
831+
832+
821833
def memo() -> Mapping[str, Any]:
822834
"""Current workflow's memo values, converted without type hints.
823835

tests/worker/test_interceptor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,44 @@ def pop_trace(name: str, filter: Optional[Callable[[Any], bool]] = None) -> Any:
283283

284284
# Confirm no unexpected traces
285285
assert not interceptor_traces
286+
287+
288+
class WorkflowInstanceAccessInterceptor(Interceptor):
289+
def workflow_interceptor_class(
290+
self, input: WorkflowInterceptorClassInput
291+
) -> Optional[Type[WorkflowInboundInterceptor]]:
292+
return WorkflowInstanceAccessInboundInterceptor
293+
294+
295+
class WorkflowInstanceAccessInboundInterceptor(WorkflowInboundInterceptor):
296+
async def execute_workflow(self, input: ExecuteWorkflowInput) -> int:
297+
# Return integer difference between ids of workflow instance obtained from workflow run method and
298+
# from workflow.instance(). They should be the same, so the difference should be 0.
299+
from_workflow_instance_api = workflow.instance()
300+
assert from_workflow_instance_api is not None
301+
id_from_workflow_instance_api = id(from_workflow_instance_api)
302+
id_from_workflow_run_method = await super().execute_workflow(input)
303+
return id_from_workflow_run_method - id_from_workflow_instance_api
304+
305+
306+
@workflow.defn
307+
class WorkflowInstanceAccessWorkflow:
308+
@workflow.run
309+
async def run(self) -> int:
310+
return id(self)
311+
312+
313+
async def test_workflow_instance_access_from_interceptor(client: Client):
314+
task_queue = f"task_queue_{uuid.uuid4()}"
315+
async with Worker(
316+
client,
317+
task_queue=task_queue,
318+
workflows=[WorkflowInstanceAccessWorkflow],
319+
interceptors=[WorkflowInstanceAccessInterceptor()],
320+
):
321+
difference = await client.execute_workflow(
322+
WorkflowInstanceAccessWorkflow.run,
323+
id=f"workflow_{uuid.uuid4()}",
324+
task_queue=task_queue,
325+
)
326+
assert difference == 0

0 commit comments

Comments
 (0)