|
4 | 4 | from datetime import timedelta
|
5 | 5 | from enum import IntEnum
|
6 | 6 | from itertools import zip_longest
|
7 |
| -from typing import Any, Callable, Literal, Union |
| 7 | +from typing import Any, Awaitable, Callable, Literal, Union |
8 | 8 |
|
9 | 9 | import nexusrpc
|
10 | 10 | import nexusrpc.handler
|
|
31 | 31 | import temporalio.api.operatorservice
|
32 | 32 | import temporalio.api.operatorservice.v1
|
33 | 33 | import temporalio.exceptions
|
| 34 | +import temporalio.nexus |
34 | 35 | from temporalio import nexus, workflow
|
35 | 36 | from temporalio.client import (
|
36 | 37 | Client,
|
@@ -1554,3 +1555,149 @@ async def test_timeout_error_raised_by_nexus_operation(client: Client):
|
1554 | 1555 | assert isinstance(err, WorkflowFailureError)
|
1555 | 1556 | assert isinstance(err.__cause__, NexusOperationError)
|
1556 | 1557 | assert isinstance(err.__cause__.__cause__, TimeoutError)
|
| 1558 | + |
| 1559 | + |
| 1560 | +# Test overloads |
| 1561 | + |
| 1562 | + |
| 1563 | +@workflow.defn |
| 1564 | +class OverloadTestHandlerWorkflow: |
| 1565 | + @workflow.run |
| 1566 | + async def run(self, input: int) -> int: |
| 1567 | + return input * 2 |
| 1568 | + |
| 1569 | + |
| 1570 | +@workflow.defn |
| 1571 | +class OverloadTestHandlerWorkflowNoParam: |
| 1572 | + @workflow.run |
| 1573 | + async def run(self) -> int: |
| 1574 | + return 0 |
| 1575 | + |
| 1576 | + |
| 1577 | +@nexusrpc.handler.service_handler |
| 1578 | +class OverloadTestServiceHandler: |
| 1579 | + @workflow_run_operation |
| 1580 | + async def no_param( |
| 1581 | + self, |
| 1582 | + ctx: WorkflowRunOperationContext, |
| 1583 | + _: int, |
| 1584 | + ) -> nexus.WorkflowHandle[int]: |
| 1585 | + return await ctx.start_workflow( |
| 1586 | + OverloadTestHandlerWorkflowNoParam.run, |
| 1587 | + id=str(uuid.uuid4()), |
| 1588 | + ) |
| 1589 | + |
| 1590 | + @workflow_run_operation |
| 1591 | + async def single_param( |
| 1592 | + self, ctx: WorkflowRunOperationContext, input: int |
| 1593 | + ) -> nexus.WorkflowHandle[int]: |
| 1594 | + return await ctx.start_workflow( |
| 1595 | + OverloadTestHandlerWorkflow.run, |
| 1596 | + input, |
| 1597 | + id=str(uuid.uuid4()), |
| 1598 | + ) |
| 1599 | + |
| 1600 | + @workflow_run_operation |
| 1601 | + async def multi_param( |
| 1602 | + self, ctx: WorkflowRunOperationContext, input: int |
| 1603 | + ) -> nexus.WorkflowHandle[int]: |
| 1604 | + return await ctx.start_workflow( |
| 1605 | + OverloadTestHandlerWorkflow.run, |
| 1606 | + args=[input], |
| 1607 | + id=str(uuid.uuid4()), |
| 1608 | + ) |
| 1609 | + |
| 1610 | + @workflow_run_operation |
| 1611 | + async def by_name( |
| 1612 | + self, ctx: WorkflowRunOperationContext, input: int |
| 1613 | + ) -> nexus.WorkflowHandle[int]: |
| 1614 | + return await ctx.start_workflow( |
| 1615 | + "OverloadTestHandlerWorkflow", |
| 1616 | + input, |
| 1617 | + id=str(uuid.uuid4()), |
| 1618 | + result_type=OverloadTestValue, |
| 1619 | + ) |
| 1620 | + |
| 1621 | + @workflow_run_operation |
| 1622 | + async def by_name_multi_param( |
| 1623 | + self, ctx: WorkflowRunOperationContext, input: int |
| 1624 | + ) -> nexus.WorkflowHandle[int]: |
| 1625 | + return await ctx.start_workflow( |
| 1626 | + "OverloadTestHandlerWorkflow", |
| 1627 | + args=[input], |
| 1628 | + id=str(uuid.uuid4()), |
| 1629 | + ) |
| 1630 | + |
| 1631 | + |
| 1632 | +@dataclass |
| 1633 | +class OverloadTestInput: |
| 1634 | + op: Callable[ |
| 1635 | + [Any, WorkflowRunOperationContext, Any], |
| 1636 | + Awaitable[temporalio.nexus.WorkflowHandle[Any]], |
| 1637 | + ] |
| 1638 | + input: Any |
| 1639 | + output: Any |
| 1640 | + |
| 1641 | + |
| 1642 | +@workflow.defn |
| 1643 | +class OverloadTestCallerWorkflow: |
| 1644 | + @workflow.run |
| 1645 | + async def run(self, op: str, input: int) -> int: |
| 1646 | + nexus_client = workflow.NexusClient( |
| 1647 | + service=OverloadTestServiceHandler, |
| 1648 | + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), |
| 1649 | + ) |
| 1650 | + if op == "no_param": |
| 1651 | + return await nexus_client.execute_operation( |
| 1652 | + OverloadTestServiceHandler.no_param, input |
| 1653 | + ) |
| 1654 | + elif op == "single_param": |
| 1655 | + return await nexus_client.execute_operation( |
| 1656 | + OverloadTestServiceHandler.single_param, input |
| 1657 | + ) |
| 1658 | + elif op == "multi_param": |
| 1659 | + return await nexus_client.execute_operation( |
| 1660 | + OverloadTestServiceHandler.multi_param, input |
| 1661 | + ) |
| 1662 | + elif op == "by_name": |
| 1663 | + return await nexus_client.execute_operation( |
| 1664 | + OverloadTestServiceHandler.by_name, input |
| 1665 | + ) |
| 1666 | + elif op == "by_name_multi_param": |
| 1667 | + return await nexus_client.execute_operation( |
| 1668 | + OverloadTestServiceHandler.by_name_multi_param, input |
| 1669 | + ) |
| 1670 | + else: |
| 1671 | + raise ValueError(f"Unknown op: {op}") |
| 1672 | + |
| 1673 | + |
| 1674 | +@pytest.mark.parametrize( |
| 1675 | + "op", |
| 1676 | + [ |
| 1677 | + "no_param", |
| 1678 | + "single_param", |
| 1679 | + "multi_param", |
| 1680 | + "by_name", |
| 1681 | + "by_name_multi_param", |
| 1682 | + ], |
| 1683 | +) |
| 1684 | +async def test_workflow_run_operation_overloads(client: Client, op: str): |
| 1685 | + task_queue = str(uuid.uuid4()) |
| 1686 | + async with Worker( |
| 1687 | + client, |
| 1688 | + task_queue=task_queue, |
| 1689 | + workflows=[ |
| 1690 | + OverloadTestCallerWorkflow, |
| 1691 | + OverloadTestHandlerWorkflow, |
| 1692 | + OverloadTestHandlerWorkflowNoParam, |
| 1693 | + ], |
| 1694 | + nexus_service_handlers=[OverloadTestServiceHandler()], |
| 1695 | + ): |
| 1696 | + await create_nexus_endpoint(task_queue, client) |
| 1697 | + res = await client.execute_workflow( |
| 1698 | + OverloadTestCallerWorkflow.run, |
| 1699 | + args=[op, 2], |
| 1700 | + id=str(uuid.uuid4()), |
| 1701 | + task_queue=task_queue, |
| 1702 | + ) |
| 1703 | + assert res == (4 if op != "no_param" else 0) |
0 commit comments