Skip to content

Commit 257f143

Browse files
authored
Upsert memo support (#858)
Fixes #190
1 parent 2864297 commit 257f143

File tree

3 files changed

+157
-32
lines changed

3 files changed

+157
-32
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
219219
self._current_history_size = 0
220220
self._continue_as_new_suggested = False
221221
# Lazily loaded
222-
self._memo: Optional[Mapping[str, Any]] = None
222+
self._untyped_converted_memo: Optional[MutableMapping[str, Any]] = None
223223
# Handles which are ready to run on the next event loop iteration
224224
self._ready: Deque[asyncio.Handle] = collections.deque()
225225
self._conditions: List[Tuple[Callable[[], bool], asyncio.Future]] = []
@@ -1066,12 +1066,12 @@ def workflow_is_replaying(self) -> bool:
10661066
return self._is_replaying
10671067

10681068
def workflow_memo(self) -> Mapping[str, Any]:
1069-
if self._memo is None:
1070-
self._memo = {
1071-
k: self._payload_converter.from_payloads([v])[0]
1069+
if self._untyped_converted_memo is None:
1070+
self._untyped_converted_memo = {
1071+
k: self._payload_converter.from_payload(v)
10721072
for k, v in self._info.raw_memo.items()
10731073
}
1074-
return self._memo
1074+
return self._untyped_converted_memo
10751075

10761076
def workflow_memo_value(
10771077
self, key: str, default: Any, *, type_hint: Optional[Type]
@@ -1081,9 +1081,52 @@ def workflow_memo_value(
10811081
if default is temporalio.common._arg_unset:
10821082
raise KeyError(f"Memo does not have a value for key {key}")
10831083
return default
1084-
return self._payload_converter.from_payloads(
1085-
[payload], [type_hint] if type_hint else None
1086-
)[0]
1084+
return self._payload_converter.from_payload(
1085+
payload,
1086+
type_hint, # type: ignore[arg-type]
1087+
)
1088+
1089+
def workflow_upsert_memo(self, updates: Mapping[str, Any]) -> None:
1090+
# Converting before creating a command so that we don't leave a partial command in case of conversion failure.
1091+
update_payloads = {}
1092+
removals = []
1093+
for k, v in updates.items():
1094+
if v is None:
1095+
# Intentionally not checking if memo exists, so that no-op removals show up in history too.
1096+
removals.append(k)
1097+
else:
1098+
update_payloads[k] = self._payload_converter.to_payload(v)
1099+
1100+
if not update_payloads and not removals:
1101+
return
1102+
1103+
command = self._add_command()
1104+
fields = command.modify_workflow_properties.upserted_memo.fields
1105+
1106+
# Updating memo inside info by downcasting to mutable mapping.
1107+
mut_raw_memo = cast(
1108+
MutableMapping[str, temporalio.api.common.v1.Payload],
1109+
self._info.raw_memo,
1110+
)
1111+
1112+
for k, v in update_payloads.items():
1113+
fields[k].CopyFrom(v)
1114+
mut_raw_memo[k] = v
1115+
1116+
if removals:
1117+
null_payload = self._payload_converter.to_payload(None)
1118+
for k in removals:
1119+
fields[k].CopyFrom(null_payload)
1120+
mut_raw_memo.pop(k, None)
1121+
1122+
# Keeping deserialized memo dict in sync, if exists
1123+
if self._untyped_converted_memo is not None:
1124+
for k, v in update_payloads.items():
1125+
self._untyped_converted_memo[k] = self._payload_converter.from_payload(
1126+
v
1127+
)
1128+
for k in removals:
1129+
self._untyped_converted_memo.pop(k, None)
10871130

10881131
def workflow_metric_meter(self) -> temporalio.common.MetricMeter:
10891132
# Create if not present, which means using an extern function

temporalio/workflow.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,9 @@ def workflow_memo_value(
739739
self, key: str, default: Any, *, type_hint: Optional[Type]
740740
) -> Any: ...
741741

742+
@abstractmethod
743+
def workflow_upsert_memo(self, updates: Mapping[str, Any]) -> None: ...
744+
742745
@abstractmethod
743746
def workflow_metric_meter(self) -> temporalio.common.MetricMeter: ...
744747

@@ -986,6 +989,17 @@ def memo_value(
986989
return _Runtime.current().workflow_memo_value(key, default, type_hint=type_hint)
987990

988991

992+
def upsert_memo(updates: Mapping[str, Any]) -> None:
993+
"""Adds, modifies, and/or removes memos, with upsert semantics.
994+
995+
Every memo that has a matching key has its value replaced with the one specified in ``updates``.
996+
If the value is set to ``None``, the memo is removed instead.
997+
For every key with no existing memo, a new memo is added with specified value (unless the value is ``None``).
998+
Memos with keys not included in ``updates`` remain unchanged.
999+
"""
1000+
return _Runtime.current().workflow_upsert_memo(updates)
1001+
1002+
9891003
def get_current_details() -> str:
9901004
"""Get the current details of the workflow which may appear in the UI/CLI.
9911005
Unlike static details set at start, this value can be updated throughout

tests/worker/test_workflow.py

Lines changed: 92 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3127,24 +3127,83 @@ class MemoValue:
31273127
class MemoWorkflow:
31283128
@workflow.run
31293129
async def run(self, run_child: bool) -> None:
3130-
# Check untyped memo
3131-
assert workflow.memo()["my_memo"] == {"field1": "foo"}
3132-
# Check typed memo
3133-
assert workflow.memo_value("my_memo", type_hint=MemoValue) == MemoValue(
3134-
field1="foo"
3130+
expected_memo = {
3131+
"dict_memo": {"field1": "dict"},
3132+
"dataclass_memo": {"field1": "data"},
3133+
"changed_memo": {"field1": "old value"},
3134+
"removed_memo": {"field1": "removed"},
3135+
}
3136+
3137+
# Test getting all memos (child)
3138+
# Alternating order of operations between parent and child workflow for more coverage
3139+
if run_child:
3140+
assert workflow.memo() == expected_memo
3141+
3142+
# Test getting single memo with and without type hint
3143+
assert workflow.memo_value("dict_memo", type_hint=MemoValue) == MemoValue(
3144+
field1="dict"
31353145
)
3136-
# Check default
3137-
assert workflow.memo_value("absent_memo", "blah") == "blah"
3138-
# Check key error
3139-
try:
3146+
assert workflow.memo_value("dict_memo") == {"field1": "dict"}
3147+
assert workflow.memo_value("dataclass_memo", type_hint=MemoValue) == MemoValue(
3148+
field1="data"
3149+
)
3150+
assert workflow.memo_value("dataclass_memo") == {"field1": "data"}
3151+
3152+
# Test getting all memos (parent)
3153+
if not run_child:
3154+
assert workflow.memo() == expected_memo
3155+
3156+
# Test missing value handling
3157+
with pytest.raises(KeyError):
3158+
workflow.memo_value("absent_memo", type_hint=MemoValue)
3159+
with pytest.raises(KeyError):
31403160
workflow.memo_value("absent_memo")
3141-
assert False
3142-
except KeyError:
3143-
pass
3144-
# Run child if requested
3161+
3162+
# Test default value handling
3163+
assert (
3164+
workflow.memo_value("absent_memo", "default value", type_hint=MemoValue)
3165+
== "default value"
3166+
)
3167+
assert workflow.memo_value("absent_memo", "default value") == "default value"
3168+
assert workflow.memo_value(
3169+
"dict_memo", "default value", type_hint=MemoValue
3170+
) == MemoValue(field1="dict")
3171+
assert workflow.memo_value("dict_memo", "default value") == {"field1": "dict"}
3172+
3173+
# Saving original memo to pass to child workflow
3174+
old_memo = dict(workflow.memo())
3175+
3176+
# Test upsert
3177+
assert workflow.memo_value("changed_memo", type_hint=MemoValue) == MemoValue(
3178+
field1="old value"
3179+
)
3180+
assert workflow.memo_value("removed_memo", type_hint=MemoValue) == MemoValue(
3181+
field1="removed"
3182+
)
3183+
with pytest.raises(KeyError):
3184+
workflow.memo_value("added_memo", type_hint=MemoValue)
3185+
3186+
workflow.upsert_memo(
3187+
{
3188+
"changed_memo": MemoValue(field1="new value"),
3189+
"added_memo": MemoValue(field1="added"),
3190+
"removed_memo": None,
3191+
}
3192+
)
3193+
3194+
assert workflow.memo_value("changed_memo", type_hint=MemoValue) == MemoValue(
3195+
field1="new value"
3196+
)
3197+
assert workflow.memo_value("added_memo", type_hint=MemoValue) == MemoValue(
3198+
field1="added"
3199+
)
3200+
with pytest.raises(KeyError):
3201+
workflow.memo_value("removed_memo", type_hint=MemoValue)
3202+
3203+
# Run second time as child workflow
31453204
if run_child:
31463205
await workflow.execute_child_workflow(
3147-
MemoWorkflow.run, False, memo=workflow.memo()
3206+
MemoWorkflow.run, False, memo=old_memo
31483207
)
31493208

31503209

@@ -3156,24 +3215,33 @@ async def test_workflow_memo(client: Client):
31563215
True,
31573216
id=f"workflow-{uuid.uuid4()}",
31583217
task_queue=worker.task_queue,
3159-
memo={"my_memo": MemoValue(field1="foo")},
3218+
memo={
3219+
"dict_memo": {"field1": "dict"},
3220+
"dataclass_memo": MemoValue(field1="data"),
3221+
"changed_memo": MemoValue(field1="old value"),
3222+
"removed_memo": MemoValue(field1="removed"),
3223+
},
31603224
)
31613225
await handle.result()
31623226
desc = await handle.describe()
31633227
# Check untyped memo
3164-
assert (await desc.memo())["my_memo"] == {"field1": "foo"}
3228+
assert (await desc.memo()) == {
3229+
"dict_memo": {"field1": "dict"},
3230+
"dataclass_memo": {"field1": "data"},
3231+
"changed_memo": {"field1": "new value"},
3232+
"added_memo": {"field1": "added"},
3233+
}
31653234
# Check typed memo
3166-
assert (await desc.memo_value("my_memo", type_hint=MemoValue)) == MemoValue(
3167-
field1="foo"
3168-
)
3235+
assert (
3236+
await desc.memo_value("dataclass_memo", type_hint=MemoValue)
3237+
) == MemoValue(field1="data")
31693238
# Check default
3170-
assert (await desc.memo_value("absent_memo", "blah")) == "blah"
3239+
assert (
3240+
await desc.memo_value("absent_memo", "default value")
3241+
) == "default value"
31713242
# Check key error
3172-
try:
3243+
with pytest.raises(KeyError):
31733244
await desc.memo_value("absent_memo")
3174-
assert False
3175-
except KeyError:
3176-
pass
31773245

31783246

31793247
@workflow.defn

0 commit comments

Comments
 (0)