@@ -1560,18 +1560,23 @@ async def test_timeout_error_raised_by_nexus_operation(client: Client):
1560
1560
# Test overloads
1561
1561
1562
1562
1563
+ @dataclass
1564
+ class OverloadTestValue :
1565
+ value : int
1566
+
1567
+
1563
1568
@workflow .defn
1564
1569
class OverloadTestHandlerWorkflow :
1565
1570
@workflow .run
1566
- async def run (self , input : int ) -> int :
1567
- return input * 2
1571
+ async def run (self , input : OverloadTestValue ) -> OverloadTestValue :
1572
+ return OverloadTestValue ( value = input . value * 2 )
1568
1573
1569
1574
1570
1575
@workflow .defn
1571
1576
class OverloadTestHandlerWorkflowNoParam :
1572
1577
@workflow .run
1573
- async def run (self ) -> int :
1574
- return 0
1578
+ async def run (self ) -> OverloadTestValue :
1579
+ return OverloadTestValue ( value = 0 )
1575
1580
1576
1581
1577
1582
@nexusrpc .handler .service_handler
@@ -1580,17 +1585,17 @@ class OverloadTestServiceHandler:
1580
1585
async def no_param (
1581
1586
self ,
1582
1587
ctx : WorkflowRunOperationContext ,
1583
- _ : int ,
1584
- ) -> nexus .WorkflowHandle [int ]:
1588
+ _ : OverloadTestValue ,
1589
+ ) -> nexus .WorkflowHandle [OverloadTestValue ]:
1585
1590
return await ctx .start_workflow (
1586
1591
OverloadTestHandlerWorkflowNoParam .run ,
1587
1592
id = str (uuid .uuid4 ()),
1588
1593
)
1589
1594
1590
1595
@workflow_run_operation
1591
1596
async def single_param (
1592
- self , ctx : WorkflowRunOperationContext , input : int
1593
- ) -> nexus .WorkflowHandle [int ]:
1597
+ self , ctx : WorkflowRunOperationContext , input : OverloadTestValue
1598
+ ) -> nexus .WorkflowHandle [OverloadTestValue ]:
1594
1599
return await ctx .start_workflow (
1595
1600
OverloadTestHandlerWorkflow .run ,
1596
1601
input ,
@@ -1599,8 +1604,8 @@ async def single_param(
1599
1604
1600
1605
@workflow_run_operation
1601
1606
async def multi_param (
1602
- self , ctx : WorkflowRunOperationContext , input : int
1603
- ) -> nexus .WorkflowHandle [int ]:
1607
+ self , ctx : WorkflowRunOperationContext , input : OverloadTestValue
1608
+ ) -> nexus .WorkflowHandle [OverloadTestValue ]:
1604
1609
return await ctx .start_workflow (
1605
1610
OverloadTestHandlerWorkflow .run ,
1606
1611
args = [input ],
@@ -1609,8 +1614,8 @@ async def multi_param(
1609
1614
1610
1615
@workflow_run_operation
1611
1616
async def by_name (
1612
- self , ctx : WorkflowRunOperationContext , input : int
1613
- ) -> nexus .WorkflowHandle [int ]:
1617
+ self , ctx : WorkflowRunOperationContext , input : OverloadTestValue
1618
+ ) -> nexus .WorkflowHandle [OverloadTestValue ]:
1614
1619
return await ctx .start_workflow (
1615
1620
"OverloadTestHandlerWorkflow" ,
1616
1621
input ,
@@ -1620,8 +1625,8 @@ async def by_name(
1620
1625
1621
1626
@workflow_run_operation
1622
1627
async def by_name_multi_param (
1623
- self , ctx : WorkflowRunOperationContext , input : int
1624
- ) -> nexus .WorkflowHandle [int ]:
1628
+ self , ctx : WorkflowRunOperationContext , input : OverloadTestValue
1629
+ ) -> nexus .WorkflowHandle [OverloadTestValue ]:
1625
1630
return await ctx .start_workflow (
1626
1631
"OverloadTestHandlerWorkflow" ,
1627
1632
args = [input ],
@@ -1642,7 +1647,7 @@ class OverloadTestInput:
1642
1647
@workflow .defn
1643
1648
class OverloadTestCallerWorkflow :
1644
1649
@workflow .run
1645
- async def run (self , op : str , input : int ) -> int :
1650
+ async def run (self , op : str , input : OverloadTestValue ) -> OverloadTestValue :
1646
1651
nexus_client = workflow .NexusClient (
1647
1652
service = OverloadTestServiceHandler ,
1648
1653
endpoint = make_nexus_endpoint_name (workflow .info ().task_queue ),
@@ -1696,8 +1701,12 @@ async def test_workflow_run_operation_overloads(client: Client, op: str):
1696
1701
await create_nexus_endpoint (task_queue , client )
1697
1702
res = await client .execute_workflow (
1698
1703
OverloadTestCallerWorkflow .run ,
1699
- args = [op , 2 ],
1704
+ args = [op , OverloadTestValue ( value = 2 ) ],
1700
1705
id = str (uuid .uuid4 ()),
1701
1706
task_queue = task_queue ,
1702
1707
)
1703
- assert res == (4 if op != "no_param" else 0 )
1708
+ assert res == (
1709
+ OverloadTestValue (value = 4 )
1710
+ if op != "no_param"
1711
+ else OverloadTestValue (value = 0 )
1712
+ )
0 commit comments