Skip to content

Commit a17c0ef

Browse files
authored
Disallow most workflow operations in read-only context (#351)
Fixes #250
1 parent b9df212 commit a17c0ef

File tree

5 files changed

+160
-37
lines changed

5 files changed

+160
-37
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ class GreetingWorkflow:
499499
self._complete.set()
500500

501501
@workflow.query
502-
async def current_greeting(self) -> str:
502+
def current_greeting(self) -> str:
503503
return self._current_greeting
504504

505505
```
@@ -566,7 +566,8 @@ Here are the decorators that can be applied:
566566
* Return value is ignored
567567
* `@workflow.query` - Defines a method as a query
568568
* All the same constraints as `@workflow.signal` but should return a value
569-
* Temporal queries should never mutate anything in the workflow
569+
* Should not be `async` (defining a query as an `async def` function is deprecated)
570+
* Temporal queries should never mutate anything in the workflow or make any calls that would mutate the workflow
570571

571572
#### Running
572573

temporalio/worker/_workflow_instance.py

Lines changed: 69 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import traceback
1313
import warnings
1414
from abc import ABC, abstractmethod
15+
from contextlib import contextmanager
1516
from dataclasses import dataclass
1617
from datetime import timedelta
1718
from typing import (
@@ -21,6 +22,7 @@
2122
Deque,
2223
Dict,
2324
Generator,
25+
Iterator,
2426
List,
2527
Mapping,
2628
MutableMapping,
@@ -193,6 +195,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
193195
self._object: Any = None
194196
self._is_replaying: bool = False
195197
self._random = random.Random(det.randomness_seed)
198+
self._read_only = False
196199

197200
# Patches we have been notified of and memoized patch responses
198201
self._patches_notified: Set[str] = set()
@@ -421,36 +424,39 @@ async def run_query() -> None:
421424
command = self._add_command()
422425
command.respond_to_query.query_id = job.query_id
423426
try:
424-
# Named query or dynamic
425-
defn = self._queries.get(job.query_type) or self._queries.get(None)
426-
if not defn:
427-
known_queries = sorted([k for k in self._queries.keys() if k])
428-
raise RuntimeError(
429-
f"Query handler for '{job.query_type}' expected but not found, "
430-
f"known queries: [{' '.join(known_queries)}]"
427+
with self._as_read_only():
428+
# Named query or dynamic
429+
defn = self._queries.get(job.query_type) or self._queries.get(None)
430+
if not defn:
431+
known_queries = sorted([k for k in self._queries.keys() if k])
432+
raise RuntimeError(
433+
f"Query handler for '{job.query_type}' expected but not found, "
434+
f"known queries: [{' '.join(known_queries)}]"
435+
)
436+
437+
# Create input
438+
args = self._process_handler_args(
439+
job.query_type,
440+
job.arguments,
441+
defn.name,
442+
defn.arg_types,
443+
defn.dynamic_vararg,
431444
)
432-
433-
# Create input
434-
args = self._process_handler_args(
435-
job.query_type,
436-
job.arguments,
437-
defn.name,
438-
defn.arg_types,
439-
defn.dynamic_vararg,
440-
)
441-
input = HandleQueryInput(
442-
id=job.query_id,
443-
query=job.query_type,
444-
args=args,
445-
headers=job.headers,
446-
)
447-
success = await self._inbound.handle_query(input)
448-
result_payloads = self._payload_converter.to_payloads([success])
449-
if len(result_payloads) != 1:
450-
raise ValueError(
451-
f"Expected 1 result payload, got {len(result_payloads)}"
445+
input = HandleQueryInput(
446+
id=job.query_id,
447+
query=job.query_type,
448+
args=args,
449+
headers=job.headers,
450+
)
451+
success = await self._inbound.handle_query(input)
452+
result_payloads = self._payload_converter.to_payloads([success])
453+
if len(result_payloads) != 1:
454+
raise ValueError(
455+
f"Expected 1 result payload, got {len(result_payloads)}"
456+
)
457+
command.respond_to_query.succeeded.response.CopyFrom(
458+
result_payloads[0]
452459
)
453-
command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0])
454460
except Exception as err:
455461
try:
456462
self._failure_converter.to_failure(
@@ -695,6 +701,7 @@ def workflow_continue_as_new(
695701
search_attributes: Optional[temporalio.common.SearchAttributes],
696702
versioning_intent: Optional[temporalio.workflow.VersioningIntent],
697703
) -> NoReturn:
704+
self._assert_not_read_only("continue as new")
698705
# Use definition if callable
699706
name: Optional[str] = None
700707
arg_types: Optional[List[Type]] = None
@@ -795,12 +802,20 @@ def workflow_payload_converter(self) -> temporalio.converter.PayloadConverter:
795802
return self._payload_converter
796803

797804
def workflow_random(self) -> random.Random:
805+
self._assert_not_read_only("random")
798806
return self._random
799807

800808
def workflow_set_query_handler(
801809
self, name: Optional[str], handler: Optional[Callable]
802810
) -> None:
811+
self._assert_not_read_only("set query handler")
803812
if handler:
813+
if inspect.iscoroutinefunction(handler):
814+
warnings.warn(
815+
"Queries as async def functions are deprecated",
816+
DeprecationWarning,
817+
stacklevel=3,
818+
)
804819
defn = temporalio.workflow._QueryDefinition(
805820
name=name, fn=handler, is_method=False
806821
)
@@ -817,6 +832,7 @@ def workflow_set_query_handler(
817832
def workflow_set_signal_handler(
818833
self, name: Optional[str], handler: Optional[Callable]
819834
) -> None:
835+
self._assert_not_read_only("set signal handler")
820836
if handler:
821837
defn = temporalio.workflow._SignalDefinition(
822838
name=name, fn=handler, is_method=False
@@ -855,6 +871,7 @@ def workflow_start_activity(
855871
activity_id: Optional[str],
856872
versioning_intent: Optional[temporalio.workflow.VersioningIntent],
857873
) -> temporalio.workflow.ActivityHandle[Any]:
874+
self._assert_not_read_only("start activity")
858875
# Get activity definition if it's callable
859876
name: str
860877
arg_types: Optional[List[Type]] = None
@@ -1012,6 +1029,7 @@ def workflow_upsert_search_attributes(
10121029
async def workflow_wait_condition(
10131030
self, fn: Callable[[], bool], *, timeout: Optional[float] = None
10141031
) -> None:
1032+
self._assert_not_read_only("wait condition")
10151033
fut = self.create_future()
10161034
self._conditions.append((fn, fut))
10171035
await asyncio.wait_for(fut, timeout)
@@ -1153,8 +1171,24 @@ async def run_child() -> Any:
11531171
# These are in alphabetical order.
11541172

11551173
def _add_command(self) -> temporalio.bridge.proto.workflow_commands.WorkflowCommand:
1174+
self._assert_not_read_only("add command")
11561175
return self._current_completion.successful.commands.add()
11571176

1177+
@contextmanager
1178+
def _as_read_only(self) -> Iterator[None]:
1179+
prev_val = self._read_only
1180+
self._read_only = True
1181+
try:
1182+
yield None
1183+
finally:
1184+
self._read_only = prev_val
1185+
1186+
def _assert_not_read_only(self, action_attempted: str) -> None:
1187+
if self._read_only:
1188+
raise temporalio.workflow.ReadOnlyContextError(
1189+
f"While in read-only function, action attempted: {action_attempted}"
1190+
)
1191+
11581192
async def _cancel_external_workflow(
11591193
self,
11601194
# Should not have seq set
@@ -1258,6 +1292,7 @@ def _register_task(
12581292
*,
12591293
name: Optional[str],
12601294
) -> None:
1295+
self._assert_not_read_only("create task")
12611296
# Name not supported on older Python versions
12621297
if sys.version_info >= (3, 8):
12631298
# Put the workflow info at the end of the task name
@@ -1423,6 +1458,7 @@ def call_soon(
14231458
*args: Any,
14241459
context: Optional[contextvars.Context] = None,
14251460
) -> asyncio.Handle:
1461+
self._assert_not_read_only("schedule task")
14261462
handle = asyncio.Handle(callback, args, self, context)
14271463
self._ready.append(handle)
14281464
return handle
@@ -1434,6 +1470,7 @@ def call_later(
14341470
*args: Any,
14351471
context: Optional[contextvars.Context] = None,
14361472
) -> asyncio.TimerHandle:
1473+
self._assert_not_read_only("schedule timer")
14371474
# Delay must be positive
14381475
if delay < 0:
14391476
raise RuntimeError("Attempting to schedule timer with negative delay")
@@ -1675,6 +1712,7 @@ def __init__(
16751712
instance._register_task(self, name=f"activity: {input.activity}")
16761713

16771714
def cancel(self, msg: Optional[Any] = None) -> bool:
1715+
self._instance._assert_not_read_only("cancel activity handle")
16781716
# We override this because if it's not yet started and not done, we need
16791717
# to send a cancel command because the async function won't run to trap
16801718
# the cancel (i.e. cancelled before started)
@@ -1821,6 +1859,7 @@ async def signal(
18211859
*,
18221860
args: Sequence[Any] = [],
18231861
) -> None:
1862+
self._instance._assert_not_read_only("signal child handle")
18241863
await self._instance._outbound.signal_child_workflow(
18251864
SignalChildWorkflowInput(
18261865
signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
@@ -1935,6 +1974,7 @@ async def signal(
19351974
*,
19361975
args: Sequence[Any] = [],
19371976
) -> None:
1977+
self._instance._assert_not_read_only("signal external handle")
19381978
await self._instance._outbound.signal_external_workflow(
19391979
SignalExternalWorkflowInput(
19401980
signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
@@ -1949,6 +1989,7 @@ async def signal(
19491989
)
19501990

19511991
async def cancel(self) -> None:
1992+
self._instance._assert_not_read_only("cancel external handle")
19521993
command = self._instance._add_command()
19531994
v = command.request_cancel_external_workflow_execution
19541995
v.workflow_execution.namespace = self._instance._info.namespace

temporalio/workflow.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,8 @@ def query(
242242
):
243243
"""Decorator for a workflow query method.
244244
245-
This is set on any async or non-async method that expects to handle a
246-
query. If a function overrides one with this decorator, it too must be
247-
decorated.
245+
This is set on any non-async method that expects to handle a query. If a
246+
function overrides one with this decorator, it too must be decorated.
248247
249248
Query methods can only have positional parameters. Best practice for
250249
non-dynamic query methods is to only take a single object/dataclass
@@ -262,7 +261,15 @@ def query(
262261
present.
263262
"""
264263

265-
def with_name(name: Optional[str], fn: CallableType) -> CallableType:
264+
def with_name(
265+
name: Optional[str], fn: CallableType, *, bypass_async_check: bool = False
266+
) -> CallableType:
267+
if not bypass_async_check and inspect.iscoroutinefunction(fn):
268+
warnings.warn(
269+
"Queries as async def functions are deprecated",
270+
DeprecationWarning,
271+
stacklevel=2,
272+
)
266273
defn = _QueryDefinition(name=name, fn=fn, is_method=True)
267274
setattr(fn, "__temporal_query_definition", defn)
268275
if defn.dynamic_vararg:
@@ -279,7 +286,13 @@ def with_name(name: Optional[str], fn: CallableType) -> CallableType:
279286
return partial(with_name, name)
280287
if fn is None:
281288
raise RuntimeError("Cannot create query without function or name or dynamic")
282-
return with_name(fn.__name__, fn)
289+
if inspect.iscoroutinefunction(fn):
290+
warnings.warn(
291+
"Queries as async def functions are deprecated",
292+
DeprecationWarning,
293+
stacklevel=2,
294+
)
295+
return with_name(fn.__name__, fn, bypass_async_check=True)
283296

284297

285298
@dataclass(frozen=True)
@@ -3919,6 +3932,17 @@ def __init__(self, message: str) -> None:
39193932
self.message = message
39203933

39213934

3935+
class ReadOnlyContextError(temporalio.exceptions.TemporalError):
3936+
"""Error raised when trying to make mutating workflow calls in a read-only
3937+
context like a query or update validator.
3938+
"""
3939+
3940+
def __init__(self, message: str) -> None:
3941+
"""Initialize a read-only context error."""
3942+
super().__init__(message)
3943+
self.message = message
3944+
3945+
39223946
class _NotInWorkflowEventLoopError(temporalio.exceptions.TemporalError):
39233947
def __init__(self, *args: object) -> None:
39243948
super().__init__("Not in workflow event loop")

tests/testing/test_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def run(self) -> str:
2828
return "all done"
2929

3030
@workflow.query
31-
async def current_time(self) -> float:
31+
def current_time(self) -> float:
3232
return workflow.now().timestamp()
3333

3434
@workflow.signal

tests/worker/test_workflow.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3014,7 +3014,7 @@ async def signal(self) -> None:
30143014
self._signal_count += 1
30153015

30163016
@workflow.query
3017-
async def signal_count(self) -> int:
3017+
def signal_count(self) -> int:
30183018
return self._signal_count
30193019

30203020

@@ -3097,6 +3097,63 @@ async def test_workflow_dynamic(client: Client):
30973097
assert result == DynamicWorkflowValue("some-workflow - val1 - val2")
30983098

30993099

3100+
@workflow.defn
3101+
class QueriesDoingBadThingsWorkflow:
3102+
@workflow.run
3103+
async def run(self) -> None:
3104+
await workflow.wait_condition(lambda: False)
3105+
3106+
@workflow.query
3107+
async def bad_query(self, bad_thing: str) -> str:
3108+
if bad_thing == "wait_condition":
3109+
await workflow.wait_condition(lambda: True)
3110+
elif bad_thing == "continue_as_new":
3111+
workflow.continue_as_new()
3112+
elif bad_thing == "upsert_search_attribute":
3113+
workflow.upsert_search_attributes({"foo": ["bar"]})
3114+
elif bad_thing == "start_activity":
3115+
workflow.start_activity(
3116+
"some-activity", start_to_close_timeout=timedelta(minutes=10)
3117+
)
3118+
elif bad_thing == "start_child_workflow":
3119+
await workflow.start_child_workflow("some-workflow")
3120+
elif bad_thing == "random":
3121+
workflow.random().random()
3122+
elif bad_thing == "set_query_handler":
3123+
workflow.set_query_handler("some-handler", lambda: "whatever")
3124+
elif bad_thing == "patch":
3125+
workflow.patched("some-patch")
3126+
elif bad_thing == "signal_external_handle":
3127+
await workflow.get_external_workflow_handle("some-id").signal("some-signal")
3128+
return "should never get here"
3129+
3130+
3131+
async def test_workflow_queries_doing_bad_things(client: Client):
3132+
async with new_worker(client, QueriesDoingBadThingsWorkflow) as worker:
3133+
handle = await client.start_workflow(
3134+
QueriesDoingBadThingsWorkflow.run,
3135+
id=f"wf-{uuid.uuid4()}",
3136+
task_queue=worker.task_queue,
3137+
)
3138+
3139+
async def assert_bad_query(bad_thing: str) -> None:
3140+
with pytest.raises(WorkflowQueryFailedError) as err:
3141+
_ = await handle.query(
3142+
QueriesDoingBadThingsWorkflow.bad_query, bad_thing
3143+
)
3144+
assert "While in read-only function, action attempted" in str(err)
3145+
3146+
await assert_bad_query("wait_condition")
3147+
await assert_bad_query("continue_as_new")
3148+
await assert_bad_query("upsert_search_attribute")
3149+
await assert_bad_query("start_activity")
3150+
await assert_bad_query("start_child_workflow")
3151+
await assert_bad_query("random")
3152+
await assert_bad_query("set_query_handler")
3153+
await assert_bad_query("patch")
3154+
await assert_bad_query("signal_external_handle")
3155+
3156+
31003157
# typing.Self only in 3.11+
31013158
if sys.version_info >= (3, 11):
31023159

0 commit comments

Comments
 (0)