Skip to content

Commit 962a5e8

Browse files
authored
Patch support and random/UUID helpers (#35)
1 parent 39eb19d commit 962a5e8

File tree

4 files changed

+233
-5
lines changed

4 files changed

+233
-5
lines changed

temporalio/worker/workflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,5 +271,6 @@ async def _create_workflow_instance(
271271
defn=defn,
272272
info=info,
273273
type_hint_eval_str=self._type_hint_eval_str,
274+
randomness_seed=start.randomness_seed,
274275
)
275276
)

temporalio/worker/workflow_instance.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import contextvars
88
import inspect
99
import logging
10+
import random
1011
import sys
1112
import traceback
1213
from abc import ABC, abstractmethod
@@ -26,6 +27,7 @@
2627
NoReturn,
2728
Optional,
2829
Sequence,
30+
Set,
2931
Tuple,
3032
Type,
3133
TypeVar,
@@ -92,6 +94,7 @@ class WorkflowInstanceDetails:
9294
defn: temporalio.workflow._Definition
9395
info: temporalio.workflow.Info
9496
type_hint_eval_str: bool
97+
randomness_seed: int
9598

9699

97100
class WorkflowInstance(ABC):
@@ -164,6 +167,11 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
164167
# The actual instance, instantiated on first _run_once
165168
self._object: Any = None
166169
self._is_replaying: bool = False
170+
self._random = random.Random(det.randomness_seed)
171+
172+
# Patches we have been notified of and patches that have been sent
173+
self._patches_notified: Set[str] = set()
174+
self._patches_sent: Set[str] = set()
167175

168176
# We maintain signals and queries on this class since handlers can be
169177
# added during workflow execution
@@ -293,8 +301,7 @@ def _apply(
293301
elif job.HasField("query_workflow"):
294302
self._apply_query_workflow(job.query_workflow)
295303
elif job.HasField("notify_has_patch"):
296-
# TODO(cretz): This
297-
pass
304+
self._apply_notify_has_patch(job.notify_has_patch)
298305
elif job.HasField("remove_from_cache"):
299306
# Ignore, handled externally
300307
pass
@@ -321,8 +328,7 @@ def _apply(
321328
elif job.HasField("start_workflow"):
322329
self._apply_start_workflow(job.start_workflow)
323330
elif job.HasField("update_random_seed"):
324-
# TODO(cretz): This
325-
pass
331+
self._apply_update_random_seed(job.update_random_seed)
326332
else:
327333
raise RuntimeError(f"Unrecognized job: {job.WhichOneof('variant')}")
328334

@@ -391,6 +397,11 @@ async def run_query(input: HandleQueryInput) -> None:
391397
)
392398
)
393399

400+
def _apply_notify_has_patch(
    self, job: temporalio.bridge.proto.workflow_activation.NotifyHasPatch
) -> None:
    """Remember that this activation reported a patch for the workflow.

    Args:
        job: Activation job whose ``patch_id`` is added to the set of
            notified patches.
    """
    # workflow_patch() consults this set while replaying
    notified_id = job.patch_id
    self._patches_notified.add(notified_id)
404+
394405
def _apply_resolve_activity(
395406
self, job: temporalio.bridge.proto.workflow_activation.ResolveActivity
396407
) -> None:
@@ -589,6 +600,11 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None:
589600
self._run_top_level_workflow_function(run_workflow(input))
590601
)
591602

603+
def _apply_update_random_seed(
    self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
) -> None:
    """Re-seed the workflow's random number generator.

    Args:
        job: Activation job carrying the new ``randomness_seed``.
    """
    # Seed the existing generator in place instead of replacing it, so the
    # same object returned by workflow_random() picks up the new seed
    new_seed = job.randomness_seed
    self._random.seed(new_seed)
607+
592608
#### _Runtime direct workflow call overrides ####
593609
# These are in alphabetical order and all start with "workflow_".
594610

@@ -679,6 +695,19 @@ def workflow_is_replaying(self) -> bool:
679695
def workflow_now(self) -> datetime:
    """Return the current workflow time as a naive UTC ``datetime``.

    The value is derived from the running event loop's ``time()``.
    """
    loop_time = asyncio.get_running_loop().time()
    return datetime.utcfromtimestamp(loop_time)
681697

698+
def workflow_patch(self, id: str, *, deprecated: bool) -> bool:
    """Decide whether the patched code path applies and emit the marker.

    Args:
        id: Identifier of the patch.
        deprecated: Value copied onto the patch marker command.

    Returns:
        True when the workflow is not replaying, or when it is replaying
        and a notification for this patch ID has been applied; False
        otherwise.
    """
    # NOTE(review): the result is recomputed on every call, so two calls
    # with the same ID could disagree if the replaying flag changes in
    # between - confirm whether the first answer should be memoized.
    use_patch = not self._is_replaying or id in self._patches_notified
    # Only add patch command if never sent before for this ID
    if use_patch and id not in self._patches_sent:
        command = self._add_command()
        command.set_patch_marker.patch_id = id
        command.set_patch_marker.deprecated = deprecated
        self._patches_sent.add(id)
    return use_patch
707+
708+
def workflow_random(self) -> random.Random:
    """Return the random number generator held by this workflow instance."""
    rng = self._random
    return rng
710+
682711
def workflow_set_query_handler(
683712
self, name: Optional[str], handler: Optional[Callable]
684713
) -> None:

temporalio/workflow.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import asyncio
66
import inspect
77
import logging
8+
import uuid
89
from abc import ABC, abstractmethod
910
from dataclasses import dataclass
1011
from datetime import datetime, timedelta
1112
from enum import IntEnum
1213
from functools import partial
14+
from random import Random
1315
from typing import (
1416
TYPE_CHECKING,
1517
Any,
@@ -365,6 +367,14 @@ def workflow_is_replaying(self) -> bool:
365367
def workflow_now(self) -> datetime:
366368
...
367369

370+
@abstractmethod
def workflow_patch(self, id: str, *, deprecated: bool) -> bool:
    """Return whether the patched code path should be taken for ``id``."""
    ...
373+
374+
@abstractmethod
def workflow_random(self) -> Random:
    """Return the random number generator for the current workflow."""
    ...
377+
368378
@abstractmethod
369379
def workflow_set_query_handler(
370380
self, name: Optional[str], handler: Optional[Callable]
@@ -436,6 +446,20 @@ async def workflow_wait_condition(
436446
...
437447

438448

449+
def deprecate_patch(id: str) -> None:
    """Mark a patch as deprecated.

    Use this in place of :py:func:`patched` once every workflow that could
    still take the old code path is done and will never be queried again.
    At that point the old code path is removed as well.

    Args:
        id: The identifier originally used with :py:func:`patched`.
    """
    runtime = _Runtime.current()
    # The boolean result is intentionally discarded; only the marker matters
    runtime.workflow_patch(id, deprecated=True)
461+
462+
439463
def info() -> Info:
440464
"""Current workflow's info.
441465
@@ -454,6 +478,51 @@ def now() -> datetime:
454478
return _Runtime.current().workflow_now()
455479

456480

481+
def patched(id: str) -> bool:
    """Check whether the newer code path guarded by a patch should run.

    This returns true when the workflow is not replaying, or when it is
    replaying and has seen this patch before. Otherwise the older path
    applies.

    Once all workflows are done and will never be queried again, switch to
    :py:func:`deprecate_patch`. The old code path can be removed at that
    time too.

    Args:
        id: The identifier for this patch. The same identifier may be used
            repeatedly in the same workflow to represent the same patch.

    Returns:
        True if this should take the newer path, false if it should take the
        older path.
    """
    runtime = _Runtime.current()
    return runtime.workflow_patch(id, deprecated=False)
500+
501+
502+
def random() -> Random:
    """Get the current workflow's deterministic pseudo-random generator.

    Note, this generator is not cryptographically safe and should not be
    used for security purposes.

    Returns:
        The deterministically-seeded pseudo-random number generator.
    """
    runtime = _Runtime.current()
    return runtime.workflow_random()
512+
513+
514+
def uuid4() -> uuid.UUID:
    """Get a new, determinism-safe v4 UUID based on :py:func:`random`.

    Note, this UUID is not cryptographically safe and should not be used for
    security purposes.

    Returns:
        A deterministically-seeded v4 UUID.
    """
    # 128 bits drawn from the workflow generator; version=4 makes UUID
    # overwrite the variant and version bits accordingly
    random_bits = random().getrandbits(16 * 8)
    return uuid.UUID(int=random_bits, version=4)
524+
525+
457526
async def wait_condition(
458527
fn: Callable[[], bool], *, timeout: Optional[float] = None
459528
) -> None:

tests/worker/test_workflow.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import temporalio.api.common.v1
2929
from temporalio import activity, workflow
30-
from temporalio.client import Client, WorkflowFailureError
30+
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
3131
from temporalio.common import RetryPolicy
3232
from temporalio.converter import DataConverter, PayloadCodec
3333
from temporalio.exceptions import (
@@ -1245,6 +1245,135 @@ async def test_workflow_child_already_started(client: Client):
12451245
assert err.value.cause.message == "Already started"
12461246

12471247

1248+
class PatchWorkflowBase:
    """Shared result storage and query handler for the patch workflows."""

    def __init__(self) -> None:
        # Placeholder until a run method records which path was taken
        self._result = "<unset>"

    @workflow.query
    def result(self) -> str:
        """Return the value recorded by the run method."""
        return self._result
1255+
1256+
1257+
@workflow.defn(name="patch-workflow")
class PrePatchWorkflow(PatchWorkflowBase):
    """Workflow version written before any patch call existed."""

    @workflow.run
    async def run(self) -> None:
        """Record the pre-patch result."""
        self._result = "pre-patch"
1262+
1263+
1264+
@workflow.defn(name="patch-workflow")
class PatchWorkflow(PatchWorkflowBase):
    """Workflow version that chooses its path with ``workflow.patched``."""

    @workflow.run
    async def run(self) -> None:
        """Record which path the patch check selected."""
        took_new_path = workflow.patched("my-patch")
        self._result = "post-patch" if took_new_path else "pre-patch"
1272+
1273+
1274+
@workflow.defn(name="patch-workflow")
class DeprecatePatchWorkflow(PatchWorkflowBase):
    """Workflow version that deprecates the patch and keeps only new code."""

    @workflow.run
    async def run(self) -> None:
        """Deprecate the patch, then record the post-patch result."""
        workflow.deprecate_patch("my-patch")
        self._result = "post-patch"
1280+
1281+
1282+
@workflow.defn(name="patch-workflow")
class PostPatchWorkflow(PatchWorkflowBase):
    """Workflow version with all patch calls removed."""

    @workflow.run
    async def run(self) -> None:
        """Record the post-patch result."""
        self._result = "post-patch"
1287+
1288+
1289+
async def test_workflow_patch(client: Client):
    """Exercise patching across workers serving different implementations.

    All four workflow classes register under the same workflow name, so each
    worker below serves that name with a different implementation while
    handles from earlier workers are queried again.
    """
    workflow_run = PrePatchWorkflow.run
    task_queue = str(uuid.uuid4())

    async def execute() -> WorkflowHandle:
        # Start a fresh run and wait for it to complete before returning
        started = await client.start_workflow(
            workflow_run, id=f"workflow-{uuid.uuid4()}", task_queue=task_queue
        )
        await started.result()
        return started

    async def query_result(handle: WorkflowHandle) -> str:
        answer = await handle.query(PatchWorkflowBase.result)
        return answer

    # Run a simple pre-patch workflow
    async with new_worker(client, PrePatchWorkflow, task_queue=task_queue):
        pre_handle = await execute()
        assert await query_result(pre_handle) == "pre-patch"

    # Confirm patched workflow gives old result for pre-patched but new result
    # for patched
    async with new_worker(client, PatchWorkflow, task_queue=task_queue):
        patched_handle = await execute()
        assert await query_result(patched_handle) == "post-patch"
        assert await query_result(pre_handle) == "pre-patch"

    # Confirm what works during deprecated
    async with new_worker(client, DeprecatePatchWorkflow, task_queue=task_queue):
        deprecated_handle = await execute()
        assert await query_result(deprecated_handle) == "post-patch"
        assert await query_result(patched_handle) == "post-patch"

    # Confirm what works when deprecation gone
    async with new_worker(client, PostPatchWorkflow, task_queue=task_queue):
        post_handle = await execute()
        assert await query_result(post_handle) == "post-patch"
        assert await query_result(deprecated_handle) == "post-patch"
        # TODO(cretz): This causes a non-determinism failure due to having the
        # patch marker, but we don't have an easy way to test it
        # await query_result(patch_handle)
1329+
1330+
1331+
@workflow.defn
class UUIDWorkflow:
    """Workflow that records one ``workflow.uuid4()`` value for querying."""

    def __init__(self) -> None:
        # Placeholder until run() stores the generated UUID
        self._result = "<unset>"

    @workflow.run
    async def run(self) -> None:
        """Generate a UUID and keep its string form."""
        generated = workflow.uuid4()
        self._result = str(generated)

    @workflow.query
    def result(self) -> str:
        """Return the stored UUID string."""
        return self._result
1343+
1344+
1345+
async def test_workflow_uuid(client: Client):
    """Check that workflow UUIDs differ per run but are stable per run."""
    task_queue = str(uuid.uuid4())
    handles = []
    first_results = []

    async with new_worker(client, UUIDWorkflow, task_queue=task_queue):
        # Get two handle UUID results, one run at a time
        for _ in range(2):
            handle = await client.start_workflow(
                UUIDWorkflow.run,
                id=f"workflow-{uuid.uuid4()}",
                task_queue=task_queue,
            )
            await handle.result()
            handles.append(handle)
            first_results.append(await handle.query(UUIDWorkflow.result))

        # Confirm they aren't equal to each other but they are equal to retries
        # of the same query
        assert first_results[0] != first_results[1]
        for handle, expected in zip(handles, first_results):
            assert expected == await handle.query(UUIDWorkflow.result)

    # Now confirm those results are the same even on a new worker
    async with new_worker(client, UUIDWorkflow, task_queue=task_queue):
        for handle, expected in zip(handles, first_results):
            assert expected == await handle.query(UUIDWorkflow.result)
1375+
1376+
12481377
# TODO:
12491378
# * Use typed dicts for activity, local activity, and child workflow configs
12501379
# * Local activity invalid options

0 commit comments

Comments
 (0)