@@ -3127,24 +3127,83 @@ class MemoValue:
3127
3127
class MemoWorkflow :
3128
3128
@workflow .run
3129
3129
async def run (self , run_child : bool ) -> None :
3130
- # Check untyped memo
3131
- assert workflow .memo ()["my_memo" ] == {"field1" : "foo" }
3132
- # Check typed memo
3133
- assert workflow .memo_value ("my_memo" , type_hint = MemoValue ) == MemoValue (
3134
- field1 = "foo"
3130
+ expected_memo = {
3131
+ "dict_memo" : {"field1" : "dict" },
3132
+ "dataclass_memo" : {"field1" : "data" },
3133
+ "changed_memo" : {"field1" : "old value" },
3134
+ "removed_memo" : {"field1" : "removed" },
3135
+ }
3136
+
3137
+ # Test getting all memos (child)
3138
+ # Alternating order of operations between parent and child workflow for more coverage
3139
+ if run_child :
3140
+ assert workflow .memo () == expected_memo
3141
+
3142
+ # Test getting single memo with and without type hint
3143
+ assert workflow .memo_value ("dict_memo" , type_hint = MemoValue ) == MemoValue (
3144
+ field1 = "dict"
3135
3145
)
3136
- # Check default
3137
- assert workflow .memo_value ("absent_memo" , "blah" ) == "blah"
3138
- # Check key error
3139
- try :
3146
+ assert workflow .memo_value ("dict_memo" ) == {"field1" : "dict" }
3147
+ assert workflow .memo_value ("dataclass_memo" , type_hint = MemoValue ) == MemoValue (
3148
+ field1 = "data"
3149
+ )
3150
+ assert workflow .memo_value ("dataclass_memo" ) == {"field1" : "data" }
3151
+
3152
+ # Test getting all memos (parent)
3153
+ if not run_child :
3154
+ assert workflow .memo () == expected_memo
3155
+
3156
+ # Test missing value handling
3157
+ with pytest .raises (KeyError ):
3158
+ workflow .memo_value ("absent_memo" , type_hint = MemoValue )
3159
+ with pytest .raises (KeyError ):
3140
3160
workflow .memo_value ("absent_memo" )
3141
- assert False
3142
- except KeyError :
3143
- pass
3144
- # Run child if requested
3161
+
3162
+ # Test default value handling
3163
+ assert (
3164
+ workflow .memo_value ("absent_memo" , "default value" , type_hint = MemoValue )
3165
+ == "default value"
3166
+ )
3167
+ assert workflow .memo_value ("absent_memo" , "default value" ) == "default value"
3168
+ assert workflow .memo_value (
3169
+ "dict_memo" , "default value" , type_hint = MemoValue
3170
+ ) == MemoValue (field1 = "dict" )
3171
+ assert workflow .memo_value ("dict_memo" , "default value" ) == {"field1" : "dict" }
3172
+
3173
+ # Saving original memo to pass to child workflow
3174
+ old_memo = dict (workflow .memo ())
3175
+
3176
+ # Test upsert
3177
+ assert workflow .memo_value ("changed_memo" , type_hint = MemoValue ) == MemoValue (
3178
+ field1 = "old value"
3179
+ )
3180
+ assert workflow .memo_value ("removed_memo" , type_hint = MemoValue ) == MemoValue (
3181
+ field1 = "removed"
3182
+ )
3183
+ with pytest .raises (KeyError ):
3184
+ workflow .memo_value ("added_memo" , type_hint = MemoValue )
3185
+
3186
+ workflow .upsert_memo (
3187
+ {
3188
+ "changed_memo" : MemoValue (field1 = "new value" ),
3189
+ "added_memo" : MemoValue (field1 = "added" ),
3190
+ "removed_memo" : None ,
3191
+ }
3192
+ )
3193
+
3194
+ assert workflow .memo_value ("changed_memo" , type_hint = MemoValue ) == MemoValue (
3195
+ field1 = "new value"
3196
+ )
3197
+ assert workflow .memo_value ("added_memo" , type_hint = MemoValue ) == MemoValue (
3198
+ field1 = "added"
3199
+ )
3200
+ with pytest .raises (KeyError ):
3201
+ workflow .memo_value ("removed_memo" , type_hint = MemoValue )
3202
+
3203
+ # Run second time as child workflow
3145
3204
if run_child :
3146
3205
await workflow .execute_child_workflow (
3147
- MemoWorkflow .run , False , memo = workflow . memo ()
3206
+ MemoWorkflow .run , False , memo = old_memo
3148
3207
)
3149
3208
3150
3209
@@ -3156,24 +3215,33 @@ async def test_workflow_memo(client: Client):
3156
3215
True ,
3157
3216
id = f"workflow-{ uuid .uuid4 ()} " ,
3158
3217
task_queue = worker .task_queue ,
3159
- memo = {"my_memo" : MemoValue (field1 = "foo" )},
3218
+ memo = {
3219
+ "dict_memo" : {"field1" : "dict" },
3220
+ "dataclass_memo" : MemoValue (field1 = "data" ),
3221
+ "changed_memo" : MemoValue (field1 = "old value" ),
3222
+ "removed_memo" : MemoValue (field1 = "removed" ),
3223
+ },
3160
3224
)
3161
3225
await handle .result ()
3162
3226
desc = await handle .describe ()
3163
3227
# Check untyped memo
3164
- assert (await desc .memo ())["my_memo" ] == {"field1" : "foo" }
3228
+ assert (await desc .memo ()) == {
3229
+ "dict_memo" : {"field1" : "dict" },
3230
+ "dataclass_memo" : {"field1" : "data" },
3231
+ "changed_memo" : {"field1" : "new value" },
3232
+ "added_memo" : {"field1" : "added" },
3233
+ }
3165
3234
# Check typed memo
3166
- assert (await desc . memo_value ( "my_memo" , type_hint = MemoValue )) == MemoValue (
3167
- field1 = "foo"
3168
- )
3235
+ assert (
3236
+ await desc . memo_value ( "dataclass_memo" , type_hint = MemoValue )
3237
+ ) == MemoValue ( field1 = "data" )
3169
3238
# Check default
3170
- assert (await desc .memo_value ("absent_memo" , "blah" )) == "blah"
3239
+ assert (
3240
+ await desc .memo_value ("absent_memo" , "default value" )
3241
+ ) == "default value"
3171
3242
# Check key error
3172
- try :
3243
+ with pytest . raises ( KeyError ) :
3173
3244
await desc .memo_value ("absent_memo" )
3174
- assert False
3175
- except KeyError :
3176
- pass
3177
3245
3178
3246
3179
3247
@workflow .defn
0 commit comments