Skip to content

Commit bbd5bf6

Browse files
committed
Use a dataclass
1 parent d121812 commit bbd5bf6

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

tests/nexus/test_workflow_caller.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,18 +1560,23 @@ async def test_timeout_error_raised_by_nexus_operation(client: Client):
15601560
# Test overloads
15611561

15621562

1563+
@dataclass
1564+
class OverloadTestValue:
1565+
value: int
1566+
1567+
15631568
@workflow.defn
15641569
class OverloadTestHandlerWorkflow:
15651570
@workflow.run
1566-
async def run(self, input: int) -> int:
1567-
return input * 2
1571+
async def run(self, input: OverloadTestValue) -> OverloadTestValue:
1572+
return OverloadTestValue(value=input.value * 2)
15681573

15691574

15701575
@workflow.defn
15711576
class OverloadTestHandlerWorkflowNoParam:
15721577
@workflow.run
1573-
async def run(self) -> int:
1574-
return 0
1578+
async def run(self) -> OverloadTestValue:
1579+
return OverloadTestValue(value=0)
15751580

15761581

15771582
@nexusrpc.handler.service_handler
@@ -1580,17 +1585,17 @@ class OverloadTestServiceHandler:
15801585
async def no_param(
15811586
self,
15821587
ctx: WorkflowRunOperationContext,
1583-
_: int,
1584-
) -> nexus.WorkflowHandle[int]:
1588+
_: OverloadTestValue,
1589+
) -> nexus.WorkflowHandle[OverloadTestValue]:
15851590
return await ctx.start_workflow(
15861591
OverloadTestHandlerWorkflowNoParam.run,
15871592
id=str(uuid.uuid4()),
15881593
)
15891594

15901595
@workflow_run_operation
15911596
async def single_param(
1592-
self, ctx: WorkflowRunOperationContext, input: int
1593-
) -> nexus.WorkflowHandle[int]:
1597+
self, ctx: WorkflowRunOperationContext, input: OverloadTestValue
1598+
) -> nexus.WorkflowHandle[OverloadTestValue]:
15941599
return await ctx.start_workflow(
15951600
OverloadTestHandlerWorkflow.run,
15961601
input,
@@ -1599,8 +1604,8 @@ async def single_param(
15991604

16001605
@workflow_run_operation
16011606
async def multi_param(
1602-
self, ctx: WorkflowRunOperationContext, input: int
1603-
) -> nexus.WorkflowHandle[int]:
1607+
self, ctx: WorkflowRunOperationContext, input: OverloadTestValue
1608+
) -> nexus.WorkflowHandle[OverloadTestValue]:
16041609
return await ctx.start_workflow(
16051610
OverloadTestHandlerWorkflow.run,
16061611
args=[input],
@@ -1609,8 +1614,8 @@ async def multi_param(
16091614

16101615
@workflow_run_operation
16111616
async def by_name(
1612-
self, ctx: WorkflowRunOperationContext, input: int
1613-
) -> nexus.WorkflowHandle[int]:
1617+
self, ctx: WorkflowRunOperationContext, input: OverloadTestValue
1618+
) -> nexus.WorkflowHandle[OverloadTestValue]:
16141619
return await ctx.start_workflow(
16151620
"OverloadTestHandlerWorkflow",
16161621
input,
@@ -1620,8 +1625,8 @@ async def by_name(
16201625

16211626
@workflow_run_operation
16221627
async def by_name_multi_param(
1623-
self, ctx: WorkflowRunOperationContext, input: int
1624-
) -> nexus.WorkflowHandle[int]:
1628+
self, ctx: WorkflowRunOperationContext, input: OverloadTestValue
1629+
) -> nexus.WorkflowHandle[OverloadTestValue]:
16251630
return await ctx.start_workflow(
16261631
"OverloadTestHandlerWorkflow",
16271632
args=[input],
@@ -1642,7 +1647,7 @@ class OverloadTestInput:
16421647
@workflow.defn
16431648
class OverloadTestCallerWorkflow:
16441649
@workflow.run
1645-
async def run(self, op: str, input: int) -> int:
1650+
async def run(self, op: str, input: OverloadTestValue) -> OverloadTestValue:
16461651
nexus_client = workflow.NexusClient(
16471652
service=OverloadTestServiceHandler,
16481653
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
@@ -1696,8 +1701,12 @@ async def test_workflow_run_operation_overloads(client: Client, op: str):
16961701
await create_nexus_endpoint(task_queue, client)
16971702
res = await client.execute_workflow(
16981703
OverloadTestCallerWorkflow.run,
1699-
args=[op, 2],
1704+
args=[op, OverloadTestValue(value=2)],
17001705
id=str(uuid.uuid4()),
17011706
task_queue=task_queue,
17021707
)
1703-
assert res == (4 if op != "no_param" else 0)
1708+
assert res == (
1709+
OverloadTestValue(value=4)
1710+
if op != "no_param"
1711+
else OverloadTestValue(value=0)
1712+
)

0 commit comments

Comments
 (0)