diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 3d2c5ab60..10977a158 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -219,7 +219,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._current_history_size = 0 self._continue_as_new_suggested = False # Lazily loaded - self._memo: Optional[Mapping[str, Any]] = None + self._untyped_converted_memo: Optional[MutableMapping[str, Any]] = None # Handles which are ready to run on the next event loop iteration self._ready: Deque[asyncio.Handle] = collections.deque() self._conditions: List[Tuple[Callable[[], bool], asyncio.Future]] = [] @@ -1066,12 +1066,12 @@ def workflow_is_replaying(self) -> bool: return self._is_replaying def workflow_memo(self) -> Mapping[str, Any]: - if self._memo is None: - self._memo = { - k: self._payload_converter.from_payloads([v])[0] + if self._untyped_converted_memo is None: + self._untyped_converted_memo = { + k: self._payload_converter.from_payload(v) for k, v in self._info.raw_memo.items() } - return self._memo + return self._untyped_converted_memo def workflow_memo_value( self, key: str, default: Any, *, type_hint: Optional[Type] @@ -1081,9 +1081,52 @@ def workflow_memo_value( if default is temporalio.common._arg_unset: raise KeyError(f"Memo does not have a value for key {key}") return default - return self._payload_converter.from_payloads( - [payload], [type_hint] if type_hint else None - )[0] + return self._payload_converter.from_payload( + payload, + type_hint, # type: ignore[arg-type] + ) + + def workflow_upsert_memo(self, updates: Mapping[str, Any]) -> None: + # Converting before creating a command so that we don't leave a partial command in case of conversion failure. + update_payloads = {} + removals = [] + for k, v in updates.items(): + if v is None: + # Intentionally not checking if memo exists, so that no-op removals show up in history too. + removals.append(k) + else: + update_payloads[k] = self._payload_converter.to_payload(v) + + if not update_payloads and not removals: + return + + command = self._add_command() + fields = command.modify_workflow_properties.upserted_memo.fields + + # Updating memo inside info by downcasting to mutable mapping. + mut_raw_memo = cast( + MutableMapping[str, temporalio.api.common.v1.Payload], + self._info.raw_memo, + ) + + for k, v in update_payloads.items(): + fields[k].CopyFrom(v) + mut_raw_memo[k] = v + + if removals: + null_payload = self._payload_converter.to_payload(None) + for k in removals: + fields[k].CopyFrom(null_payload) + mut_raw_memo.pop(k, None) + + # Keeping deserialized memo dict in sync, if exists + if self._untyped_converted_memo is not None: + for k, v in update_payloads.items(): + self._untyped_converted_memo[k] = self._payload_converter.from_payload( + v + ) + for k in removals: + self._untyped_converted_memo.pop(k, None) def workflow_metric_meter(self) -> temporalio.common.MetricMeter: # Create if not present, which means using an extern function diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 1a56be00d..608275a7c 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -739,6 +739,9 @@ def workflow_memo_value( self, key: str, default: Any, *, type_hint: Optional[Type] ) -> Any: ... + @abstractmethod + def workflow_upsert_memo(self, updates: Mapping[str, Any]) -> None: ... + @abstractmethod def workflow_metric_meter(self) -> temporalio.common.MetricMeter: ... @@ -986,6 +989,17 @@ def memo_value( return _Runtime.current().workflow_memo_value(key, default, type_hint=type_hint) +def upsert_memo(updates: Mapping[str, Any]) -> None: + """Adds, modifies, and/or removes memos, with upsert semantics. + + Every memo that has a matching key has its value replaced with the one specified in ``updates``. + If the value is set to ``None``, the memo is removed instead. + For every key with no existing memo, a new memo is added with specified value (unless the value is ``None``). + Memos with keys not included in ``updates`` remain unchanged. + """ + return _Runtime.current().workflow_upsert_memo(updates) + + def get_current_details() -> str: """Get the current details of the workflow which may appear in the UI/CLI. Unlike static details set at start, this value can be updated throughout diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 7373f33b5..542576ab7 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3127,24 +3127,83 @@ class MemoValue: class MemoWorkflow: @workflow.run async def run(self, run_child: bool) -> None: - # Check untyped memo - assert workflow.memo()["my_memo"] == {"field1": "foo"} - # Check typed memo - assert workflow.memo_value("my_memo", type_hint=MemoValue) == MemoValue( - field1="foo" + expected_memo = { + "dict_memo": {"field1": "dict"}, + "dataclass_memo": {"field1": "data"}, + "changed_memo": {"field1": "old value"}, + "removed_memo": {"field1": "removed"}, + } + + # Test getting all memos (child) + # Alternating order of operations between parent and child workflow for more coverage + if run_child: + assert workflow.memo() == expected_memo + + # Test getting single memo with and without type hint + assert workflow.memo_value("dict_memo", type_hint=MemoValue) == MemoValue( + field1="dict" ) - # Check default - assert workflow.memo_value("absent_memo", "blah") == "blah" - # Check key error - try: + assert workflow.memo_value("dict_memo") == {"field1": "dict"} + assert workflow.memo_value("dataclass_memo", type_hint=MemoValue) == MemoValue( + field1="data" + ) + assert workflow.memo_value("dataclass_memo") == {"field1": "data"} + + # Test getting all memos (parent) + if not run_child: + assert workflow.memo() == expected_memo + + # Test missing value handling + with pytest.raises(KeyError): + workflow.memo_value("absent_memo", type_hint=MemoValue) + with pytest.raises(KeyError): workflow.memo_value("absent_memo") - assert False - except KeyError: - pass - # Run child if requested + + # Test default value handling + assert ( + workflow.memo_value("absent_memo", "default value", type_hint=MemoValue) + == "default value" + ) + assert workflow.memo_value("absent_memo", "default value") == "default value" + assert workflow.memo_value( + "dict_memo", "default value", type_hint=MemoValue + ) == MemoValue(field1="dict") + assert workflow.memo_value("dict_memo", "default value") == {"field1": "dict"} + + # Saving original memo to pass to child workflow + old_memo = dict(workflow.memo()) + + # Test upsert + assert workflow.memo_value("changed_memo", type_hint=MemoValue) == MemoValue( + field1="old value" + ) + assert workflow.memo_value("removed_memo", type_hint=MemoValue) == MemoValue( + field1="removed" + ) + with pytest.raises(KeyError): + workflow.memo_value("added_memo", type_hint=MemoValue) + + workflow.upsert_memo( + { + "changed_memo": MemoValue(field1="new value"), + "added_memo": MemoValue(field1="added"), + "removed_memo": None, + } + ) + + assert workflow.memo_value("changed_memo", type_hint=MemoValue) == MemoValue( + field1="new value" + ) + assert workflow.memo_value("added_memo", type_hint=MemoValue) == MemoValue( + field1="added" + ) + with pytest.raises(KeyError): + workflow.memo_value("removed_memo", type_hint=MemoValue) + + # Run second time as child workflow if run_child: await workflow.execute_child_workflow( - MemoWorkflow.run, False, memo=workflow.memo() + MemoWorkflow.run, False, memo=old_memo ) @@ -3156,24 +3215,33 @@ async def test_workflow_memo(client: Client): True, id=f"workflow-{uuid.uuid4()}", task_queue=worker.task_queue, - memo={"my_memo": MemoValue(field1="foo")}, + memo={ + "dict_memo": {"field1": "dict"}, + "dataclass_memo": MemoValue(field1="data"), + "changed_memo": MemoValue(field1="old value"), + "removed_memo": MemoValue(field1="removed"), + }, ) await handle.result() desc = await handle.describe() # Check untyped memo - assert (await desc.memo())["my_memo"] == {"field1": "foo"} + assert (await desc.memo()) == { + "dict_memo": {"field1": "dict"}, + "dataclass_memo": {"field1": "data"}, + "changed_memo": {"field1": "new value"}, + "added_memo": {"field1": "added"}, + } # Check typed memo - assert (await desc.memo_value("my_memo", type_hint=MemoValue)) == MemoValue( - field1="foo" - ) + assert ( + await desc.memo_value("dataclass_memo", type_hint=MemoValue) + ) == MemoValue(field1="data") # Check default - assert (await desc.memo_value("absent_memo", "blah")) == "blah" + assert ( + await desc.memo_value("absent_memo", "default value") + ) == "default value" # Check key error - try: + with pytest.raises(KeyError): await desc.memo_value("absent_memo") - assert False - except KeyError: - pass @workflow.defn