Skip to content

Commit 34681ca

Browse files
authored
Reachability type wasn't passed through all the way (#343)
1 parent b902ec8 commit 34681ca

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

temporalio/client.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4794,6 +4794,9 @@ async def get_worker_task_reachability(
47944794
namespace=self._client.namespace,
47954795
build_ids=input.build_ids,
47964796
task_queues=input.task_queues,
4797+
reachability=input.reachability._to_proto()
4798+
if input.reachability
4799+
else temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_UNSPECIFIED,
47974800
)
47984801
resp = await self._client.workflow_service.get_worker_task_reachability(
47994802
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
@@ -5170,3 +5173,25 @@ def _from_proto(
51705173
return TaskReachabilityType.CLOSED_WORKFLOWS
51715174
else:
51725175
raise ValueError(f"Cannot convert reachability type: {reachability}")
5176+
5177+
def _to_proto(self) -> temporalio.api.enums.v1.TaskReachability.ValueType:
5178+
if self == TaskReachabilityType.NEW_WORKFLOWS:
5179+
return (
5180+
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_NEW_WORKFLOWS
5181+
)
5182+
elif self == TaskReachabilityType.EXISTING_WORKFLOWS:
5183+
return (
5184+
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_EXISTING_WORKFLOWS
5185+
)
5186+
elif self == TaskReachabilityType.OPEN_WORKFLOWS:
5187+
return (
5188+
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_OPEN_WORKFLOWS
5189+
)
5190+
elif self == TaskReachabilityType.CLOSED_WORKFLOWS:
5191+
return (
5192+
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_CLOSED_WORKFLOWS
5193+
)
5194+
else:
5195+
return (
5196+
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_UNSPECIFIED
5197+
)

tests/worker/test_worker.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import temporalio.worker._worker
1111
from temporalio import activity, workflow
12-
from temporalio.client import BuildIdOpAddNewDefault, Client
12+
from temporalio.client import BuildIdOpAddNewDefault, Client, TaskReachabilityType
1313
from temporalio.testing import WorkflowEnvironment
1414
from temporalio.worker import Worker
1515
from temporalio.workflow import VersioningIntent
@@ -198,6 +198,16 @@ async def test_worker_versioning(client: Client, env: WorkflowEnvironment):
198198
build_id="2.0",
199199
use_worker_versioning=True,
200200
):
201+
# Confirm reachability type parameter is respected. If it wasn't, list would have
202+
# `OPEN_WORKFLOWS` in it.
203+
reachability = await client.get_worker_task_reachability(
204+
build_ids=["2.0"],
205+
reachability_type=TaskReachabilityType.CLOSED_WORKFLOWS,
206+
)
207+
assert reachability.build_id_reachability["2.0"].task_queue_reachability[
208+
task_queue
209+
] == [TaskReachabilityType.NEW_WORKFLOWS]
210+
201211
await wf1.signal(WaitOnSignalWorkflow.my_signal, "finish")
202212
await wf2.signal(WaitOnSignalWorkflow.my_signal, "finish")
203213
await wf1.result()

0 commit comments

Comments
 (0)