Skip to content

Commit b2cce8c

Browse files
committed
Test start_workflow overloads
1 parent 605bcb3 commit b2cce8c

File tree

1 file changed

+148
-1
lines changed

1 file changed

+148
-1
lines changed

tests/nexus/test_workflow_caller.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import timedelta
55
from enum import IntEnum
66
from itertools import zip_longest
7-
from typing import Any, Callable, Literal, Union
7+
from typing import Any, Awaitable, Callable, Literal, Union
88

99
import nexusrpc
1010
import nexusrpc.handler
@@ -31,6 +31,7 @@
3131
import temporalio.api.operatorservice
3232
import temporalio.api.operatorservice.v1
3333
import temporalio.exceptions
34+
import temporalio.nexus
3435
from temporalio import nexus, workflow
3536
from temporalio.client import (
3637
Client,
@@ -1554,3 +1555,149 @@ async def test_timeout_error_raised_by_nexus_operation(client: Client):
15541555
assert isinstance(err, WorkflowFailureError)
15551556
assert isinstance(err.__cause__, NexusOperationError)
15561557
assert isinstance(err.__cause__.__cause__, TimeoutError)
1558+
1559+
1560+
# Test overloads
1561+
1562+
1563+
@workflow.defn
1564+
class OverloadTestHandlerWorkflow:
1565+
@workflow.run
1566+
async def run(self, input: int) -> int:
1567+
return input * 2
1568+
1569+
1570+
@workflow.defn
1571+
class OverloadTestHandlerWorkflowNoParam:
1572+
@workflow.run
1573+
async def run(self) -> int:
1574+
return 0
1575+
1576+
1577+
@nexusrpc.handler.service_handler
1578+
class OverloadTestServiceHandler:
1579+
@workflow_run_operation
1580+
async def no_param(
1581+
self,
1582+
ctx: WorkflowRunOperationContext,
1583+
_: int,
1584+
) -> nexus.WorkflowHandle[int]:
1585+
return await ctx.start_workflow(
1586+
OverloadTestHandlerWorkflowNoParam.run,
1587+
id=str(uuid.uuid4()),
1588+
)
1589+
1590+
@workflow_run_operation
1591+
async def single_param(
1592+
self, ctx: WorkflowRunOperationContext, input: int
1593+
) -> nexus.WorkflowHandle[int]:
1594+
return await ctx.start_workflow(
1595+
OverloadTestHandlerWorkflow.run,
1596+
input,
1597+
id=str(uuid.uuid4()),
1598+
)
1599+
1600+
@workflow_run_operation
1601+
async def multi_param(
1602+
self, ctx: WorkflowRunOperationContext, input: int
1603+
) -> nexus.WorkflowHandle[int]:
1604+
return await ctx.start_workflow(
1605+
OverloadTestHandlerWorkflow.run,
1606+
args=[input],
1607+
id=str(uuid.uuid4()),
1608+
)
1609+
1610+
@workflow_run_operation
1611+
async def by_name(
1612+
self, ctx: WorkflowRunOperationContext, input: int
1613+
) -> nexus.WorkflowHandle[int]:
1614+
return await ctx.start_workflow(
1615+
"OverloadTestHandlerWorkflow",
1616+
input,
1617+
id=str(uuid.uuid4()),
1618+
result_type=OverloadTestValue,
1619+
)
1620+
1621+
@workflow_run_operation
1622+
async def by_name_multi_param(
1623+
self, ctx: WorkflowRunOperationContext, input: int
1624+
) -> nexus.WorkflowHandle[int]:
1625+
return await ctx.start_workflow(
1626+
"OverloadTestHandlerWorkflow",
1627+
args=[input],
1628+
id=str(uuid.uuid4()),
1629+
)
1630+
1631+
1632+
@dataclass
1633+
class OverloadTestInput:
1634+
op: Callable[
1635+
[Any, WorkflowRunOperationContext, Any],
1636+
Awaitable[temporalio.nexus.WorkflowHandle[Any]],
1637+
]
1638+
input: Any
1639+
output: Any
1640+
1641+
1642+
@workflow.defn
1643+
class OverloadTestCallerWorkflow:
1644+
@workflow.run
1645+
async def run(self, op: str, input: int) -> int:
1646+
nexus_client = workflow.NexusClient(
1647+
service=OverloadTestServiceHandler,
1648+
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
1649+
)
1650+
if op == "no_param":
1651+
return await nexus_client.execute_operation(
1652+
OverloadTestServiceHandler.no_param, input
1653+
)
1654+
elif op == "single_param":
1655+
return await nexus_client.execute_operation(
1656+
OverloadTestServiceHandler.single_param, input
1657+
)
1658+
elif op == "multi_param":
1659+
return await nexus_client.execute_operation(
1660+
OverloadTestServiceHandler.multi_param, input
1661+
)
1662+
elif op == "by_name":
1663+
return await nexus_client.execute_operation(
1664+
OverloadTestServiceHandler.by_name, input
1665+
)
1666+
elif op == "by_name_multi_param":
1667+
return await nexus_client.execute_operation(
1668+
OverloadTestServiceHandler.by_name_multi_param, input
1669+
)
1670+
else:
1671+
raise ValueError(f"Unknown op: {op}")
1672+
1673+
1674+
@pytest.mark.parametrize(
1675+
"op",
1676+
[
1677+
"no_param",
1678+
"single_param",
1679+
"multi_param",
1680+
"by_name",
1681+
"by_name_multi_param",
1682+
],
1683+
)
1684+
async def test_workflow_run_operation_overloads(client: Client, op: str):
1685+
task_queue = str(uuid.uuid4())
1686+
async with Worker(
1687+
client,
1688+
task_queue=task_queue,
1689+
workflows=[
1690+
OverloadTestCallerWorkflow,
1691+
OverloadTestHandlerWorkflow,
1692+
OverloadTestHandlerWorkflowNoParam,
1693+
],
1694+
nexus_service_handlers=[OverloadTestServiceHandler()],
1695+
):
1696+
await create_nexus_endpoint(task_queue, client)
1697+
res = await client.execute_workflow(
1698+
OverloadTestCallerWorkflow.run,
1699+
args=[op, 2],
1700+
id=str(uuid.uuid4()),
1701+
task_queue=task_queue,
1702+
)
1703+
assert res == (4 if op != "no_param" else 0)

0 commit comments

Comments
 (0)