12
12
import traceback
13
13
import warnings
14
14
from abc import ABC , abstractmethod
15
+ from contextlib import contextmanager
15
16
from dataclasses import dataclass
16
17
from datetime import timedelta
17
18
from typing import (
21
22
Deque ,
22
23
Dict ,
23
24
Generator ,
25
+ Iterator ,
24
26
List ,
25
27
Mapping ,
26
28
MutableMapping ,
@@ -193,6 +195,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
193
195
self ._object : Any = None
194
196
self ._is_replaying : bool = False
195
197
self ._random = random .Random (det .randomness_seed )
198
+ self ._read_only = False
196
199
197
200
# Patches we have been notified of and memoized patch responses
198
201
self ._patches_notified : Set [str ] = set ()
@@ -421,36 +424,39 @@ async def run_query() -> None:
421
424
command = self ._add_command ()
422
425
command .respond_to_query .query_id = job .query_id
423
426
try :
424
- # Named query or dynamic
425
- defn = self ._queries .get (job .query_type ) or self ._queries .get (None )
426
- if not defn :
427
- known_queries = sorted ([k for k in self ._queries .keys () if k ])
428
- raise RuntimeError (
429
- f"Query handler for '{ job .query_type } ' expected but not found, "
430
- f"known queries: [{ ' ' .join (known_queries )} ]"
427
+ with self ._as_read_only ():
428
+ # Named query or dynamic
429
+ defn = self ._queries .get (job .query_type ) or self ._queries .get (None )
430
+ if not defn :
431
+ known_queries = sorted ([k for k in self ._queries .keys () if k ])
432
+ raise RuntimeError (
433
+ f"Query handler for '{ job .query_type } ' expected but not found, "
434
+ f"known queries: [{ ' ' .join (known_queries )} ]"
435
+ )
436
+
437
+ # Create input
438
+ args = self ._process_handler_args (
439
+ job .query_type ,
440
+ job .arguments ,
441
+ defn .name ,
442
+ defn .arg_types ,
443
+ defn .dynamic_vararg ,
431
444
)
432
-
433
- # Create input
434
- args = self ._process_handler_args (
435
- job .query_type ,
436
- job .arguments ,
437
- defn .name ,
438
- defn .arg_types ,
439
- defn .dynamic_vararg ,
440
- )
441
- input = HandleQueryInput (
442
- id = job .query_id ,
443
- query = job .query_type ,
444
- args = args ,
445
- headers = job .headers ,
446
- )
447
- success = await self ._inbound .handle_query (input )
448
- result_payloads = self ._payload_converter .to_payloads ([success ])
449
- if len (result_payloads ) != 1 :
450
- raise ValueError (
451
- f"Expected 1 result payload, got { len (result_payloads )} "
445
+ input = HandleQueryInput (
446
+ id = job .query_id ,
447
+ query = job .query_type ,
448
+ args = args ,
449
+ headers = job .headers ,
450
+ )
451
+ success = await self ._inbound .handle_query (input )
452
+ result_payloads = self ._payload_converter .to_payloads ([success ])
453
+ if len (result_payloads ) != 1 :
454
+ raise ValueError (
455
+ f"Expected 1 result payload, got { len (result_payloads )} "
456
+ )
457
+ command .respond_to_query .succeeded .response .CopyFrom (
458
+ result_payloads [0 ]
452
459
)
453
- command .respond_to_query .succeeded .response .CopyFrom (result_payloads [0 ])
454
460
except Exception as err :
455
461
try :
456
462
self ._failure_converter .to_failure (
@@ -695,6 +701,7 @@ def workflow_continue_as_new(
695
701
search_attributes : Optional [temporalio .common .SearchAttributes ],
696
702
versioning_intent : Optional [temporalio .workflow .VersioningIntent ],
697
703
) -> NoReturn :
704
+ self ._assert_not_read_only ("continue as new" )
698
705
# Use definition if callable
699
706
name : Optional [str ] = None
700
707
arg_types : Optional [List [Type ]] = None
@@ -795,12 +802,20 @@ def workflow_payload_converter(self) -> temporalio.converter.PayloadConverter:
795
802
return self ._payload_converter
796
803
797
804
def workflow_random (self ) -> random .Random :
805
+ self ._assert_not_read_only ("random" )
798
806
return self ._random
799
807
800
808
def workflow_set_query_handler (
801
809
self , name : Optional [str ], handler : Optional [Callable ]
802
810
) -> None :
811
+ self ._assert_not_read_only ("set query handler" )
803
812
if handler :
813
+ if inspect .iscoroutinefunction (handler ):
814
+ warnings .warn (
815
+ "Queries as async def functions are deprecated" ,
816
+ DeprecationWarning ,
817
+ stacklevel = 3 ,
818
+ )
804
819
defn = temporalio .workflow ._QueryDefinition (
805
820
name = name , fn = handler , is_method = False
806
821
)
@@ -817,6 +832,7 @@ def workflow_set_query_handler(
817
832
def workflow_set_signal_handler (
818
833
self , name : Optional [str ], handler : Optional [Callable ]
819
834
) -> None :
835
+ self ._assert_not_read_only ("set signal handler" )
820
836
if handler :
821
837
defn = temporalio .workflow ._SignalDefinition (
822
838
name = name , fn = handler , is_method = False
@@ -855,6 +871,7 @@ def workflow_start_activity(
855
871
activity_id : Optional [str ],
856
872
versioning_intent : Optional [temporalio .workflow .VersioningIntent ],
857
873
) -> temporalio .workflow .ActivityHandle [Any ]:
874
+ self ._assert_not_read_only ("start activity" )
858
875
# Get activity definition if it's callable
859
876
name : str
860
877
arg_types : Optional [List [Type ]] = None
@@ -1012,6 +1029,7 @@ def workflow_upsert_search_attributes(
1012
1029
async def workflow_wait_condition (
1013
1030
self , fn : Callable [[], bool ], * , timeout : Optional [float ] = None
1014
1031
) -> None :
1032
+ self ._assert_not_read_only ("wait condition" )
1015
1033
fut = self .create_future ()
1016
1034
self ._conditions .append ((fn , fut ))
1017
1035
await asyncio .wait_for (fut , timeout )
@@ -1153,8 +1171,24 @@ async def run_child() -> Any:
1153
1171
# These are in alphabetical order.
1154
1172
1155
1173
def _add_command (self ) -> temporalio .bridge .proto .workflow_commands .WorkflowCommand :
1174
+ self ._assert_not_read_only ("add command" )
1156
1175
return self ._current_completion .successful .commands .add ()
1157
1176
1177
+ @contextmanager
1178
+ def _as_read_only (self ) -> Iterator [None ]:
1179
+ prev_val = self ._read_only
1180
+ self ._read_only = True
1181
+ try :
1182
+ yield None
1183
+ finally :
1184
+ self ._read_only = prev_val
1185
+
1186
+ def _assert_not_read_only (self , action_attempted : str ) -> None :
1187
+ if self ._read_only :
1188
+ raise temporalio .workflow .ReadOnlyContextError (
1189
+ f"While in read-only function, action attempted: { action_attempted } "
1190
+ )
1191
+
1158
1192
async def _cancel_external_workflow (
1159
1193
self ,
1160
1194
# Should not have seq set
@@ -1258,6 +1292,7 @@ def _register_task(
1258
1292
* ,
1259
1293
name : Optional [str ],
1260
1294
) -> None :
1295
+ self ._assert_not_read_only ("create task" )
1261
1296
# Name not supported on older Python versions
1262
1297
if sys .version_info >= (3 , 8 ):
1263
1298
# Put the workflow info at the end of the task name
@@ -1423,6 +1458,7 @@ def call_soon(
1423
1458
* args : Any ,
1424
1459
context : Optional [contextvars .Context ] = None ,
1425
1460
) -> asyncio .Handle :
1461
+ self ._assert_not_read_only ("schedule task" )
1426
1462
handle = asyncio .Handle (callback , args , self , context )
1427
1463
self ._ready .append (handle )
1428
1464
return handle
@@ -1434,6 +1470,7 @@ def call_later(
1434
1470
* args : Any ,
1435
1471
context : Optional [contextvars .Context ] = None ,
1436
1472
) -> asyncio .TimerHandle :
1473
+ self ._assert_not_read_only ("schedule timer" )
1437
1474
# Delay must be positive
1438
1475
if delay < 0 :
1439
1476
raise RuntimeError ("Attempting to schedule timer with negative delay" )
@@ -1675,6 +1712,7 @@ def __init__(
1675
1712
instance ._register_task (self , name = f"activity: { input .activity } " )
1676
1713
1677
1714
def cancel (self , msg : Optional [Any ] = None ) -> bool :
1715
+ self ._instance ._assert_not_read_only ("cancel activity handle" )
1678
1716
# We override this because if it's not yet started and not done, we need
1679
1717
# to send a cancel command because the async function won't run to trap
1680
1718
# the cancel (i.e. cancelled before started)
@@ -1821,6 +1859,7 @@ async def signal(
1821
1859
* ,
1822
1860
args : Sequence [Any ] = [],
1823
1861
) -> None :
1862
+ self ._instance ._assert_not_read_only ("signal child handle" )
1824
1863
await self ._instance ._outbound .signal_child_workflow (
1825
1864
SignalChildWorkflowInput (
1826
1865
signal = temporalio .workflow ._SignalDefinition .must_name_from_fn_or_str (
@@ -1935,6 +1974,7 @@ async def signal(
1935
1974
* ,
1936
1975
args : Sequence [Any ] = [],
1937
1976
) -> None :
1977
+ self ._instance ._assert_not_read_only ("signal external handle" )
1938
1978
await self ._instance ._outbound .signal_external_workflow (
1939
1979
SignalExternalWorkflowInput (
1940
1980
signal = temporalio .workflow ._SignalDefinition .must_name_from_fn_or_str (
@@ -1949,6 +1989,7 @@ async def signal(
1949
1989
)
1950
1990
1951
1991
async def cancel (self ) -> None :
1992
+ self ._instance ._assert_not_read_only ("cancel external handle" )
1952
1993
command = self ._instance ._add_command ()
1953
1994
v = command .request_cancel_external_workflow_execution
1954
1995
v .workflow_execution .namespace = self ._instance ._info .namespace
0 commit comments