@@ -117,44 +117,6 @@ fn union_serialize<S>(
117
117
Ok ( None )
118
118
}
119
119
120
- fn tagged_union_serialize < S > (
121
- discriminator_value : Option < Py < PyAny > > ,
122
- lookup : & HashMap < String , usize > ,
123
- // if this returns `Ok(v)`, we picked a union variant to serialize, where
124
- // `S` is intermediate state which can be passed on to the finalizer
125
- mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
126
- extra : & Extra ,
127
- choices : & [ CombinedSerializer ] ,
128
- retry_with_lax_check : bool ,
129
- ) -> PyResult < Option < S > > {
130
- let mut new_extra = extra. clone ( ) ;
131
- new_extra. check = SerCheck :: Strict ;
132
-
133
- if let Some ( tag) = discriminator_value {
134
- let tag_str = tag. to_string ( ) ;
135
- if let Some ( & serializer_index) = lookup. get ( & tag_str) {
136
- let selected_serializer = & choices[ serializer_index] ;
137
-
138
- match selector ( selected_serializer, & new_extra) {
139
- Ok ( v) => return Ok ( Some ( v) ) ,
140
- Err ( _) => {
141
- if retry_with_lax_check {
142
- new_extra. check = SerCheck :: Lax ;
143
- if let Ok ( v) = selector ( selected_serializer, & new_extra) {
144
- return Ok ( Some ( v) ) ;
145
- }
146
- }
147
- }
148
- }
149
- }
150
- }
151
-
152
- // if we haven't returned at this point, we should fallback to the union serializer
153
- // which preserves the historical expectation that we do our best with serialization
154
- // even if that means we resort to inference
155
- union_serialize ( selector, extra, choices, retry_with_lax_check)
156
- }
157
-
158
120
impl TypeSerializer for UnionSerializer {
159
121
fn to_python (
160
122
& self ,
@@ -267,27 +229,21 @@ impl TypeSerializer for TaggedUnionSerializer {
267
229
exclude : Option < & Bound < ' _ , PyAny > > ,
268
230
extra : & Extra ,
269
231
) -> PyResult < PyObject > {
270
- tagged_union_serialize (
271
- self . get_discriminator_value ( value, extra) ,
272
- & self . lookup ,
232
+ self . tagged_union_serialize (
233
+ value,
273
234
|comb_serializer : & CombinedSerializer , new_extra : & Extra | {
274
235
comb_serializer. to_python ( value, include, exclude, new_extra)
275
236
} ,
276
237
extra,
277
- & self . choices ,
278
- self . retry_with_lax_check ( ) ,
279
238
) ?
280
239
. map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok )
281
240
}
282
241
283
242
fn json_key < ' a > ( & self , key : & ' a Bound < ' _ , PyAny > , extra : & Extra ) -> PyResult < Cow < ' a , str > > {
284
- tagged_union_serialize (
285
- self . get_discriminator_value ( key, extra) ,
286
- & self . lookup ,
243
+ self . tagged_union_serialize (
244
+ key,
287
245
|comb_serializer : & CombinedSerializer , new_extra : & Extra | comb_serializer. json_key ( key, new_extra) ,
288
246
extra,
289
- & self . choices ,
290
- self . retry_with_lax_check ( ) ,
291
247
) ?
292
248
. map_or_else ( || infer_json_key ( key, extra) , Ok )
293
249
}
@@ -300,15 +256,12 @@ impl TypeSerializer for TaggedUnionSerializer {
300
256
exclude : Option < & Bound < ' _ , PyAny > > ,
301
257
extra : & Extra ,
302
258
) -> Result < S :: Ok , S :: Error > {
303
- match tagged_union_serialize (
304
- None ,
305
- & self . lookup ,
259
+ match self . tagged_union_serialize (
260
+ value,
306
261
|comb_serializer : & CombinedSerializer , new_extra : & Extra | {
307
262
comb_serializer. to_python ( value, include, exclude, new_extra)
308
263
} ,
309
264
extra,
310
- & self . choices ,
311
- self . retry_with_lax_check ( ) ,
312
265
) {
313
266
Ok ( Some ( v) ) => return infer_serialize ( v. bind ( value. py ( ) ) , serializer, None , None , extra) ,
314
267
Ok ( None ) => infer_serialize ( value, serializer, include, exclude, extra) ,
@@ -326,36 +279,66 @@ impl TypeSerializer for TaggedUnionSerializer {
326
279
}
327
280
328
281
impl TaggedUnionSerializer {
329
- fn get_discriminator_value ( & self , value : & Bound < ' _ , PyAny > , extra : & Extra ) -> Option < Py < PyAny > > {
282
+ fn get_discriminator_value < ' py > ( & self , value : & Bound < ' py , PyAny > ) -> Option < Bound < ' py , PyAny > > {
330
283
let py = value. py ( ) ;
331
- let discriminator_value = match & self . discriminator {
284
+ match & self . discriminator {
332
285
Discriminator :: LookupKey ( lookup_key) => {
333
286
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
334
287
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
335
288
// at this point. we could be more strict and only do this in lax mode...
336
- let getattr_result = match value. is_instance_of :: < PyDict > ( ) {
337
- true => {
338
- let value_dict = value. downcast :: < PyDict > ( ) . unwrap ( ) ;
339
- lookup_key. py_get_dict_item ( value_dict) . ok ( )
340
- }
341
- false => lookup_key. simple_py_get_attr ( value) . ok ( ) ,
342
- } ;
343
- getattr_result. and_then ( |opt| opt. map ( |( _, bound) | bound. to_object ( py) ) )
289
+ if let Ok ( value_dict) = value. downcast :: < PyDict > ( ) {
290
+ lookup_key. py_get_dict_item ( value_dict) . ok ( ) . flatten ( )
291
+ } else {
292
+ lookup_key. simple_py_get_attr ( value) . ok ( ) . flatten ( )
293
+ }
294
+ . map ( |( _, tag) | tag)
344
295
}
345
- Discriminator :: Function ( func) => func. call1 ( py, ( value, ) ) . ok ( ) ,
346
- } ;
347
- if discriminator_value. is_none ( ) {
348
- let value_str = truncate_safe_repr ( value, None ) ;
296
+ Discriminator :: Function ( func) => func. bind ( py) . call1 ( ( value, ) ) . ok ( ) ,
297
+ }
298
+ }
349
299
350
- // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
351
- if extra. check == SerCheck :: None {
352
- extra. warnings . custom_warning (
353
- format ! (
354
- "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
355
- )
356
- ) ;
300
+ fn tagged_union_serialize < S > (
301
+ & self ,
302
+ value : & Bound < ' _ , PyAny > ,
303
+ // if this returns `Ok(v)`, we picked a union variant to serialize, where
304
+ // `S` is intermediate state which can be passed on to the finalizer
305
+ mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
306
+ extra : & Extra ,
307
+ ) -> PyResult < Option < S > > {
308
+ if let Some ( tag) = self . get_discriminator_value ( value) {
309
+ let mut new_extra = extra. clone ( ) ;
310
+ new_extra. check = SerCheck :: Strict ;
311
+
312
+ let tag_str = tag. to_string ( ) ;
313
+ if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
314
+ let selected_serializer = & self . choices [ serializer_index] ;
315
+
316
+ match selector ( selected_serializer, & new_extra) {
317
+ Ok ( v) => return Ok ( Some ( v) ) ,
318
+ Err ( _) => {
319
+ if self . retry_with_lax_check ( ) {
320
+ new_extra. check = SerCheck :: Lax ;
321
+ if let Ok ( v) = selector ( selected_serializer, & new_extra) {
322
+ return Ok ( Some ( v) ) ;
323
+ }
324
+ }
325
+ }
326
+ }
357
327
}
328
+ } else if extra. check == SerCheck :: None {
329
+ // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise
330
+ // this warning
331
+ let value_str = truncate_safe_repr ( value, None ) ;
332
+ extra. warnings . custom_warning (
333
+ format ! (
334
+ "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
335
+ )
336
+ ) ;
358
337
}
359
- discriminator_value
338
+
339
+ // if we haven't returned at this point, we should fallback to the union serializer
340
+ // which preserves the historical expectation that we do our best with serialization
341
+ // even if that means we resort to inference
342
+ union_serialize ( selector, extra, & self . choices , self . retry_with_lax_check ( ) )
360
343
}
361
344
}
0 commit comments