Skip to content

Commit 9e34301

Browse files
authored
Add tests of asyncio.Lock and asyncio.Semaphore usage (#567)
* Add tests of asyncio.Lock and asyncio.Semaphore usage
1 parent 1a68b58 commit 9e34301

File tree

2 files changed

+324
-2
lines changed

2 files changed

+324
-2
lines changed

tests/helpers/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from contextlib import closing
66
from datetime import timedelta
7-
from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar
7+
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
1010
from temporalio.api.enums.v1 import IndexedValueType
@@ -14,11 +14,12 @@
1414
)
1515
from temporalio.api.update.v1 import UpdateRef
1616
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
17-
from temporalio.client import BuildIdOpAddNewDefault, Client
17+
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
1818
from temporalio.common import SearchAttributeKey
1919
from temporalio.service import RPCError, RPCStatusCode
2020
from temporalio.worker import Worker, WorkflowRunner
2121
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
22+
from temporalio.workflow import UpdateMethodMultiParam
2223

2324

2425
def new_worker(
@@ -128,3 +129,24 @@ async def workflow_update_exists(
128129
if err.status != RPCStatusCode.NOT_FOUND:
129130
raise
130131
return False
132+
133+
134+
# TODO: type update return value
135+
async def admitted_update_task(
136+
client: Client,
137+
handle: WorkflowHandle,
138+
update_method: UpdateMethodMultiParam,
139+
id: str,
140+
**kwargs,
141+
) -> asyncio.Task:
142+
"""
143+
Return an asyncio.Task for an update after waiting for it to be admitted.
144+
"""
145+
update_task = asyncio.create_task(
146+
handle.execute_update(update_method, id=id, **kwargs)
147+
)
148+
await assert_eq_eventually(
149+
True,
150+
lambda: workflow_update_exists(client, handle.id, id),
151+
)
152+
return update_task

tests/worker/test_workflow.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
WorkflowRunner,
108108
)
109109
from tests.helpers import (
110+
admitted_update_task,
110111
assert_eq_eventually,
111112
ensure_search_attributes_present,
112113
find_free_port,
@@ -6611,3 +6612,302 @@ async def test_alternate_async_loop_ordering(client: Client, env: WorkflowEnviro
66116612
task_queue=task_queue,
66126613
):
66136614
await handle.result()
6615+
6616+
6617+
# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected
6618+
# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and
6619+
# asyncio.Semaphore are used here.
6620+
6621+
6622+
@activity.defn
6623+
async def noop_activity_for_lock_or_semaphore_tests() -> None:
6624+
return None
6625+
6626+
6627+
@dataclass
6628+
class LockOrSemaphoreWorkflowConcurrencySummary:
6629+
ever_in_critical_section: int
6630+
peak_in_critical_section: int
6631+
6632+
6633+
@dataclass
6634+
class UseLockOrSemaphoreWorkflowParameters:
6635+
n_coroutines: int = 0
6636+
semaphore_initial_value: Optional[int] = None
6637+
sleep: Optional[float] = None
6638+
timeout: Optional[float] = None
6639+
6640+
6641+
@workflow.defn
6642+
class CoroutinesUseLockOrSemaphoreWorkflow:
6643+
def __init__(self) -> None:
6644+
self.params: UseLockOrSemaphoreWorkflowParameters
6645+
self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore]
6646+
self._currently_in_critical_section: set[str] = set()
6647+
self._ever_in_critical_section: set[str] = set()
6648+
self._peak_in_critical_section = 0
6649+
6650+
def init(self, params: UseLockOrSemaphoreWorkflowParameters):
6651+
self.params = params
6652+
if self.params.semaphore_initial_value is not None:
6653+
self.lock_or_semaphore = asyncio.Semaphore(
6654+
self.params.semaphore_initial_value
6655+
)
6656+
else:
6657+
self.lock_or_semaphore = asyncio.Lock()
6658+
6659+
@workflow.run
6660+
async def run(
6661+
self,
6662+
params: Optional[UseLockOrSemaphoreWorkflowParameters],
6663+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
6664+
# TODO: Use workflow init method when it exists.
6665+
assert params
6666+
self.init(params)
6667+
await asyncio.gather(
6668+
*(self.coroutine(f"{i}") for i in range(self.params.n_coroutines))
6669+
)
6670+
assert not any(self._currently_in_critical_section)
6671+
return LockOrSemaphoreWorkflowConcurrencySummary(
6672+
len(self._ever_in_critical_section),
6673+
self._peak_in_critical_section,
6674+
)
6675+
6676+
async def coroutine(self, id: str):
6677+
if self.params.timeout:
6678+
try:
6679+
await asyncio.wait_for(
6680+
self.lock_or_semaphore.acquire(), self.params.timeout
6681+
)
6682+
except asyncio.TimeoutError:
6683+
return
6684+
else:
6685+
await self.lock_or_semaphore.acquire()
6686+
self._enters_critical_section(id)
6687+
try:
6688+
if self.params.sleep:
6689+
await asyncio.sleep(self.params.sleep)
6690+
else:
6691+
await workflow.execute_activity(
6692+
noop_activity_for_lock_or_semaphore_tests,
6693+
schedule_to_close_timeout=timedelta(seconds=30),
6694+
)
6695+
finally:
6696+
self.lock_or_semaphore.release()
6697+
self._exits_critical_section(id)
6698+
6699+
def _enters_critical_section(self, id: str) -> None:
6700+
self._currently_in_critical_section.add(id)
6701+
self._ever_in_critical_section.add(id)
6702+
self._peak_in_critical_section = max(
6703+
self._peak_in_critical_section,
6704+
len(self._currently_in_critical_section),
6705+
)
6706+
6707+
def _exits_critical_section(self, id: str) -> None:
6708+
self._currently_in_critical_section.remove(id)
6709+
6710+
6711+
@workflow.defn
6712+
class HandlerCoroutinesUseLockOrSemaphoreWorkflow(CoroutinesUseLockOrSemaphoreWorkflow):
6713+
def __init__(self) -> None:
6714+
super().__init__()
6715+
self.workflow_may_exit = False
6716+
6717+
@workflow.run
6718+
async def run(
6719+
self,
6720+
_: Optional[UseLockOrSemaphoreWorkflowParameters] = None,
6721+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
6722+
await workflow.wait_condition(lambda: self.workflow_may_exit)
6723+
return LockOrSemaphoreWorkflowConcurrencySummary(
6724+
len(self._ever_in_critical_section),
6725+
self._peak_in_critical_section,
6726+
)
6727+
6728+
@workflow.update
6729+
async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters):
6730+
# TODO: Use workflow init method when it exists.
6731+
if not hasattr(self, "params"):
6732+
self.init(params)
6733+
assert (update_info := workflow.current_update_info())
6734+
await self.coroutine(update_info.id)
6735+
6736+
@workflow.signal
6737+
async def finish(self):
6738+
self.workflow_may_exit = True
6739+
6740+
6741+
async def _do_workflow_coroutines_lock_or_semaphore_test(
6742+
client: Client,
6743+
params: UseLockOrSemaphoreWorkflowParameters,
6744+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
6745+
):
6746+
async with new_worker(
6747+
client,
6748+
CoroutinesUseLockOrSemaphoreWorkflow,
6749+
activities=[noop_activity_for_lock_or_semaphore_tests],
6750+
) as worker:
6751+
summary = await client.execute_workflow(
6752+
CoroutinesUseLockOrSemaphoreWorkflow.run,
6753+
arg=params,
6754+
id=str(uuid.uuid4()),
6755+
task_queue=worker.task_queue,
6756+
)
6757+
assert summary == expectation
6758+
6759+
6760+
async def _do_update_handler_lock_or_semaphore_test(
6761+
client: Client,
6762+
env: WorkflowEnvironment,
6763+
params: UseLockOrSemaphoreWorkflowParameters,
6764+
n_updates: int,
6765+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
6766+
):
6767+
if env.supports_time_skipping:
6768+
pytest.skip(
6769+
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
6770+
)
6771+
6772+
task_queue = "tq"
6773+
handle = await client.start_workflow(
6774+
HandlerCoroutinesUseLockOrSemaphoreWorkflow.run,
6775+
id=f"wf-{str(uuid.uuid4())}",
6776+
task_queue=task_queue,
6777+
)
6778+
# Create updates in Admitted state, before the worker starts polling.
6779+
admitted_updates = [
6780+
await admitted_update_task(
6781+
client,
6782+
handle,
6783+
HandlerCoroutinesUseLockOrSemaphoreWorkflow.my_update,
6784+
arg=params,
6785+
id=f"update-{i}",
6786+
)
6787+
for i in range(n_updates)
6788+
]
6789+
async with new_worker(
6790+
client,
6791+
HandlerCoroutinesUseLockOrSemaphoreWorkflow,
6792+
activities=[noop_activity_for_lock_or_semaphore_tests],
6793+
task_queue=task_queue,
6794+
):
6795+
for update_task in admitted_updates:
6796+
await update_task
6797+
await handle.signal(HandlerCoroutinesUseLockOrSemaphoreWorkflow.finish)
6798+
summary = await handle.result()
6799+
assert summary == expectation
6800+
6801+
6802+
async def test_workflow_coroutines_can_use_lock(client: Client):
6803+
await _do_workflow_coroutines_lock_or_semaphore_test(
6804+
client,
6805+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5),
6806+
# The lock limits concurrency to 1
6807+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6808+
ever_in_critical_section=5, peak_in_critical_section=1
6809+
),
6810+
)
6811+
6812+
6813+
async def test_update_handler_can_use_lock_to_serialize_handler_executions(
6814+
client: Client, env: WorkflowEnvironment
6815+
):
6816+
await _do_update_handler_lock_or_semaphore_test(
6817+
client,
6818+
env,
6819+
UseLockOrSemaphoreWorkflowParameters(),
6820+
n_updates=5,
6821+
# The lock limits concurrency to 1
6822+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6823+
ever_in_critical_section=5, peak_in_critical_section=1
6824+
),
6825+
)
6826+
6827+
6828+
async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client):
6829+
await _do_workflow_coroutines_lock_or_semaphore_test(
6830+
client,
6831+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, sleep=0.5, timeout=0.1),
6832+
# Second and subsequent coroutines fail to acquire the lock due to the timeout.
6833+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6834+
ever_in_critical_section=1, peak_in_critical_section=1
6835+
),
6836+
)
6837+
6838+
6839+
async def test_update_handler_lock_acquisition_respects_timeout(
6840+
client: Client, env: WorkflowEnvironment
6841+
):
6842+
await _do_update_handler_lock_or_semaphore_test(
6843+
client,
6844+
env,
6845+
# Second and subsequent handler executions fail to acquire the lock due to the timeout.
6846+
UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1),
6847+
n_updates=5,
6848+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6849+
ever_in_critical_section=1, peak_in_critical_section=1
6850+
),
6851+
)
6852+
6853+
6854+
async def test_workflow_coroutines_can_use_semaphore(client: Client):
6855+
await _do_workflow_coroutines_lock_or_semaphore_test(
6856+
client,
6857+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, semaphore_initial_value=3),
6858+
# The semaphore limits concurrency to 3
6859+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6860+
ever_in_critical_section=5, peak_in_critical_section=3
6861+
),
6862+
)
6863+
6864+
6865+
async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency(
6866+
client: Client, env: WorkflowEnvironment
6867+
):
6868+
await _do_update_handler_lock_or_semaphore_test(
6869+
client,
6870+
env,
6871+
# The semaphore limits concurrency to 3
6872+
UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3),
6873+
n_updates=5,
6874+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6875+
ever_in_critical_section=5, peak_in_critical_section=3
6876+
),
6877+
)
6878+
6879+
6880+
async def test_workflow_coroutine_semaphore_acquisition_respects_timeout(
6881+
client: Client,
6882+
):
6883+
await _do_workflow_coroutines_lock_or_semaphore_test(
6884+
client,
6885+
UseLockOrSemaphoreWorkflowParameters(
6886+
n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1
6887+
),
6888+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
6889+
# slot fail.
6890+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6891+
ever_in_critical_section=3, peak_in_critical_section=3
6892+
),
6893+
)
6894+
6895+
6896+
async def test_update_handler_semaphore_acquisition_respects_timeout(
6897+
client: Client, env: WorkflowEnvironment
6898+
):
6899+
await _do_update_handler_lock_or_semaphore_test(
6900+
client,
6901+
env,
6902+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
6903+
# slot fail.
6904+
UseLockOrSemaphoreWorkflowParameters(
6905+
semaphore_initial_value=3,
6906+
sleep=0.5,
6907+
timeout=0.1,
6908+
),
6909+
n_updates=5,
6910+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6911+
ever_in_critical_section=3, peak_in_critical_section=3
6912+
),
6913+
)

0 commit comments

Comments
 (0)