Skip to content

Commit 3901df8

Browse files
authored
Memoize patched calls and support UUID conversion (#192)
1 parent 87fd193 commit 3901df8

File tree

5 files changed

+113
-6
lines changed

5 files changed

+113
-6
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,11 @@ The default data converter supports converting multiple types including:
189189
* Iterables including ones JSON dump may not support by default, e.g. `set`
190190
* Any class with a `dict()` method and a static `parse_obj()` method, e.g.
191191
[Pydantic models](https://pydantic-docs.helpmanual.io/usage/models)
192+
* Note, this doesn't mean every Pydantic field can be converted, only fields which the data converter supports
192193
* [IntEnum, StrEnum](https://docs.python.org/3/library/enum.html) based enumerates
194+
* [UUID](https://docs.python.org/3/library/uuid.html)
195+
196+
This notably doesn't include any `date`, `time`, or `datetime` objects as they may not work across SDKs.
193197

194198
For converting from JSON, the workflow/activity type hint is taken into account to convert to the proper type. Care has
195199
been taken to support all common typings including `Optional`, `Union`, all forms of iterables and mappings, `NewType`,

temporalio/converter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import json
1010
import sys
1111
import traceback
12+
import uuid
1213
from abc import ABC, abstractmethod
1314
from dataclasses import dataclass
1415
from datetime import datetime
@@ -426,6 +427,9 @@ def default(self, o: Any) -> Any:
426427
# Support for non-list iterables like set
427428
if not isinstance(o, list) and isinstance(o, collections.abc.Iterable):
428429
return list(o)
430+
# Support for UUID
431+
if isinstance(o, uuid.UUID):
432+
return str(o)
429433
return super().default(o)
430434

431435

@@ -1273,6 +1277,10 @@ def value_to_type(hint: Type, value: Any) -> Any:
12731277
)
12741278
return hint(value)
12751279

1280+
# UUID
1281+
if inspect.isclass(hint) and issubclass(hint, uuid.UUID):
1282+
return hint(value)
1283+
12761284
# Iterable. We intentionally put this last as it catches several others.
12771285
if inspect.isclass(origin) and issubclass(origin, collections.abc.Iterable):
12781286
if not isinstance(value, collections.abc.Iterable):

temporalio/worker/_workflow_instance.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
191191
self._is_replaying: bool = False
192192
self._random = random.Random(det.randomness_seed)
193193

194-
# Patches we have been notified of and patches that have been sent
194+
# Patches we have been notified of and memoized patch responses
195195
self._patches_notified: Set[str] = set()
196-
self._patches_sent: Set[str] = set()
196+
self._patches_memoized: Dict[str, bool] = {}
197197

198198
# Tasks stored by asyncio are weak references and therefore can get GC'd
199199
# which can cause warnings like "Task was destroyed but it is pending!".
@@ -739,13 +739,18 @@ def workflow_memo_value(
739739
)[0]
740740

741741
def workflow_patch(self, id: str, *, deprecated: bool) -> bool:
742+
# We use a previous memoized result of this if present. If this is being
743+
# deprecated, we can still use memoized result and skip the command.
744+
use_patch = self._patches_memoized.get(id)
745+
if use_patch is not None:
746+
return use_patch
747+
742748
use_patch = not self._is_replaying or id in self._patches_notified
743-
# Only add patch command if never sent before for this ID
744-
if use_patch and not id in self._patches_sent:
749+
self._patches_memoized[id] = use_patch
750+
if use_patch:
745751
command = self._add_command()
746752
command.set_patch_marker.patch_id = id
747753
command.set_patch_marker.deprecated = deprecated
748-
self._patches_sent.add(id)
749754
return use_patch
750755

751756
def workflow_random(self) -> random.Random:

tests/test_converter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Type,
2626
Union,
2727
)
28+
from uuid import UUID, uuid4
2829

2930
import pydantic
3031
import pytest
@@ -224,6 +225,7 @@ class NestedDataClass:
224225
foo: str
225226
bar: List[NestedDataClass] = dataclasses.field(default_factory=list)
226227
baz: Optional[NestedDataClass] = None
228+
qux: Optional[UUID] = None
227229

228230

229231
class MyTypedDict(TypedDict):
@@ -239,6 +241,7 @@ class MyTypedDictNotTotal(TypedDict, total=False):
239241
class MyPydanticClass(pydantic.BaseModel):
240242
foo: str
241243
bar: List[MyPydanticClass]
244+
baz: Optional[UUID] = None
242245

243246

244247
def test_json_type_hints():
@@ -287,6 +290,7 @@ def fail(hint: Type, value: Any) -> None:
287290
ok(NestedDataClass, NestedDataClass("foo"))
288291
ok(NestedDataClass, NestedDataClass("foo", baz=NestedDataClass("bar")))
289292
ok(NestedDataClass, NestedDataClass("foo", bar=[NestedDataClass("bar")]))
293+
ok(NestedDataClass, NestedDataClass("foo", qux=uuid4()))
290294
# Missing required dataclass fields causes failure
291295
ok(NestedDataClass, {"foo": "bar"}, NestedDataClass("bar"))
292296
fail(NestedDataClass, {})
@@ -346,6 +350,8 @@ def fail(hint: Type, value: Any) -> None:
346350
ok(SerializableEnum, SerializableEnum.FOO)
347351
ok(List[SerializableEnum], [SerializableEnum.FOO, SerializableEnum.FOO])
348352

353+
# UUID
354+
349355
# StrEnum is available in 3.11+
350356
if sys.version_info >= (3, 11):
351357
# StrEnum
@@ -364,7 +370,9 @@ def fail(hint: Type, value: Any) -> None:
364370
# Pydantic
365371
ok(
366372
MyPydanticClass,
367-
MyPydanticClass(foo="foo", bar=[MyPydanticClass(foo="baz", bar=[])]),
373+
MyPydanticClass(
374+
foo="foo", bar=[MyPydanticClass(foo="baz", bar=[])], baz=uuid4()
375+
),
368376
)
369377
ok(List[MyPydanticClass], [MyPydanticClass(foo="foo", bar=[])])
370378
fail(List[MyPydanticClass], [MyPydanticClass(foo="foo", bar=[]), 5])

tests/worker/test_workflow.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from temporalio.testing import WorkflowEnvironment
6868
from temporalio.worker import (
6969
UnsandboxedWorkflowRunner,
70+
Worker,
7071
WorkflowInstance,
7172
WorkflowInstanceDetails,
7273
WorkflowRunner,
@@ -2069,6 +2070,87 @@ async def query_result(handle: WorkflowHandle) -> str:
20692070
# await query_result(patch_handle)
20702071

20712072

2073+
@workflow.defn(name="patch-memoized")
2074+
class PatchMemoizedWorkflowUnpatched:
2075+
def __init__(self, *, should_patch: bool = False) -> None:
2076+
self.should_patch = should_patch
2077+
self._waiting_signal = True
2078+
2079+
@workflow.run
2080+
async def run(self) -> List[str]:
2081+
results: List[str] = []
2082+
if self.should_patch and workflow.patched("some-patch"):
2083+
results.append("pre-patch")
2084+
self._waiting_signal = True
2085+
await workflow.wait_condition(lambda: not self._waiting_signal)
2086+
results.append("some-value")
2087+
if self.should_patch and workflow.patched("some-patch"):
2088+
results.append("post-patch")
2089+
return results
2090+
2091+
@workflow.signal
2092+
def signal(self) -> None:
2093+
self._waiting_signal = False
2094+
2095+
@workflow.query
2096+
def waiting_signal(self) -> bool:
2097+
return self._waiting_signal
2098+
2099+
2100+
@workflow.defn(name="patch-memoized")
2101+
class PatchMemoizedWorkflowPatched(PatchMemoizedWorkflowUnpatched):
2102+
def __init__(self) -> None:
2103+
super().__init__(should_patch=True)
2104+
2105+
@workflow.run
2106+
async def run(self) -> List[str]:
2107+
return await super().run()
2108+
2109+
2110+
async def test_workflow_patch_memoized(client: Client):
2111+
# Start a worker with the workflow unpatched and wait until halfway through
2112+
task_queue = f"tq-{uuid.uuid4()}"
2113+
async with Worker(
2114+
client, task_queue=task_queue, workflows=[PatchMemoizedWorkflowUnpatched]
2115+
):
2116+
pre_patch_handle = await client.start_workflow(
2117+
PatchMemoizedWorkflowUnpatched.run,
2118+
id=f"workflow-{uuid.uuid4()}",
2119+
task_queue=task_queue,
2120+
)
2121+
2122+
# Need to wait until it has gotten halfway through
2123+
async def waiting_signal() -> bool:
2124+
return await pre_patch_handle.query(
2125+
PatchMemoizedWorkflowUnpatched.waiting_signal
2126+
)
2127+
2128+
await assert_eq_eventually(True, waiting_signal)
2129+
2130+
# Now start the worker again, but this time with a patched workflow
2131+
async with Worker(
2132+
client, task_queue=task_queue, workflows=[PatchMemoizedWorkflowPatched]
2133+
):
2134+
# Start a new workflow post patch
2135+
post_patch_handle = await client.start_workflow(
2136+
PatchMemoizedWorkflowPatched.run,
2137+
id=f"workflow-{uuid.uuid4()}",
2138+
task_queue=task_queue,
2139+
)
2140+
2141+
# Send signal to both and check results
2142+
await pre_patch_handle.signal(PatchMemoizedWorkflowPatched.signal)
2143+
await post_patch_handle.signal(PatchMemoizedWorkflowPatched.signal)
2144+
2145+
# Confirm expected values
2146+
assert ["some-value"] == await pre_patch_handle.result()
2147+
assert [
2148+
"pre-patch",
2149+
"some-value",
2150+
"post-patch",
2151+
] == await post_patch_handle.result()
2152+
2153+
20722154
@workflow.defn
20732155
class UUIDWorkflow:
20742156
def __init__(self) -> None:

0 commit comments

Comments
 (0)