Skip to content

Commit e4de8a6

Browse files
authored
Fix performance regression for JSON tagged union (#1552)
1 parent 4477692 commit e4de8a6

File tree

1 file changed

+58
-75
lines changed
  • src/serializers/type_serializers

1 file changed

+58
-75
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 58 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -117,44 +117,6 @@ fn union_serialize<S>(
117117
Ok(None)
118118
}
119119

120-
fn tagged_union_serialize<S>(
121-
discriminator_value: Option<Py<PyAny>>,
122-
lookup: &HashMap<String, usize>,
123-
// if this returns `Ok(v)`, we picked a union variant to serialize, where
124-
// `S` is intermediate state which can be passed on to the finalizer
125-
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
126-
extra: &Extra,
127-
choices: &[CombinedSerializer],
128-
retry_with_lax_check: bool,
129-
) -> PyResult<Option<S>> {
130-
let mut new_extra = extra.clone();
131-
new_extra.check = SerCheck::Strict;
132-
133-
if let Some(tag) = discriminator_value {
134-
let tag_str = tag.to_string();
135-
if let Some(&serializer_index) = lookup.get(&tag_str) {
136-
let selected_serializer = &choices[serializer_index];
137-
138-
match selector(selected_serializer, &new_extra) {
139-
Ok(v) => return Ok(Some(v)),
140-
Err(_) => {
141-
if retry_with_lax_check {
142-
new_extra.check = SerCheck::Lax;
143-
if let Ok(v) = selector(selected_serializer, &new_extra) {
144-
return Ok(Some(v));
145-
}
146-
}
147-
}
148-
}
149-
}
150-
}
151-
152-
// if we haven't returned at this point, we should fallback to the union serializer
153-
// which preserves the historical expectation that we do our best with serialization
154-
// even if that means we resort to inference
155-
union_serialize(selector, extra, choices, retry_with_lax_check)
156-
}
157-
158120
impl TypeSerializer for UnionSerializer {
159121
fn to_python(
160122
&self,
@@ -267,27 +229,21 @@ impl TypeSerializer for TaggedUnionSerializer {
267229
exclude: Option<&Bound<'_, PyAny>>,
268230
extra: &Extra,
269231
) -> PyResult<PyObject> {
270-
tagged_union_serialize(
271-
self.get_discriminator_value(value, extra),
272-
&self.lookup,
232+
self.tagged_union_serialize(
233+
value,
273234
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
274235
comb_serializer.to_python(value, include, exclude, new_extra)
275236
},
276237
extra,
277-
&self.choices,
278-
self.retry_with_lax_check(),
279238
)?
280239
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
281240
}
282241

283242
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
284-
tagged_union_serialize(
285-
self.get_discriminator_value(key, extra),
286-
&self.lookup,
243+
self.tagged_union_serialize(
244+
key,
287245
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
288246
extra,
289-
&self.choices,
290-
self.retry_with_lax_check(),
291247
)?
292248
.map_or_else(|| infer_json_key(key, extra), Ok)
293249
}
@@ -300,15 +256,12 @@ impl TypeSerializer for TaggedUnionSerializer {
300256
exclude: Option<&Bound<'_, PyAny>>,
301257
extra: &Extra,
302258
) -> Result<S::Ok, S::Error> {
303-
match tagged_union_serialize(
304-
None,
305-
&self.lookup,
259+
match self.tagged_union_serialize(
260+
value,
306261
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
307262
comb_serializer.to_python(value, include, exclude, new_extra)
308263
},
309264
extra,
310-
&self.choices,
311-
self.retry_with_lax_check(),
312265
) {
313266
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
314267
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
@@ -326,36 +279,66 @@ impl TypeSerializer for TaggedUnionSerializer {
326279
}
327280

328281
impl TaggedUnionSerializer {
329-
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
282+
fn get_discriminator_value<'py>(&self, value: &Bound<'py, PyAny>) -> Option<Bound<'py, PyAny>> {
330283
let py = value.py();
331-
let discriminator_value = match &self.discriminator {
284+
match &self.discriminator {
332285
Discriminator::LookupKey(lookup_key) => {
333286
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
334287
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
335288
// at this point. we could be more strict and only do this in lax mode...
336-
let getattr_result = match value.is_instance_of::<PyDict>() {
337-
true => {
338-
let value_dict = value.downcast::<PyDict>().unwrap();
339-
lookup_key.py_get_dict_item(value_dict).ok()
340-
}
341-
false => lookup_key.simple_py_get_attr(value).ok(),
342-
};
343-
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py)))
289+
if let Ok(value_dict) = value.downcast::<PyDict>() {
290+
lookup_key.py_get_dict_item(value_dict).ok().flatten()
291+
} else {
292+
lookup_key.simple_py_get_attr(value).ok().flatten()
293+
}
294+
.map(|(_, tag)| tag)
344295
}
345-
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
346-
};
347-
if discriminator_value.is_none() {
348-
let value_str = truncate_safe_repr(value, None);
296+
Discriminator::Function(func) => func.bind(py).call1((value,)).ok(),
297+
}
298+
}
349299

350-
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
351-
if extra.check == SerCheck::None {
352-
extra.warnings.custom_warning(
353-
format!(
354-
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
355-
)
356-
);
300+
fn tagged_union_serialize<S>(
301+
&self,
302+
value: &Bound<'_, PyAny>,
303+
// if this returns `Ok(v)`, we picked a union variant to serialize, where
304+
// `S` is intermediate state which can be passed on to the finalizer
305+
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
306+
extra: &Extra,
307+
) -> PyResult<Option<S>> {
308+
if let Some(tag) = self.get_discriminator_value(value) {
309+
let mut new_extra = extra.clone();
310+
new_extra.check = SerCheck::Strict;
311+
312+
let tag_str = tag.to_string();
313+
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
314+
let selected_serializer = &self.choices[serializer_index];
315+
316+
match selector(selected_serializer, &new_extra) {
317+
Ok(v) => return Ok(Some(v)),
318+
Err(_) => {
319+
if self.retry_with_lax_check() {
320+
new_extra.check = SerCheck::Lax;
321+
if let Ok(v) = selector(selected_serializer, &new_extra) {
322+
return Ok(Some(v));
323+
}
324+
}
325+
}
326+
}
357327
}
328+
} else if extra.check == SerCheck::None {
329+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise
330+
// this warning
331+
let value_str = truncate_safe_repr(value, None);
332+
extra.warnings.custom_warning(
333+
format!(
334+
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
335+
)
336+
);
358337
}
359-
discriminator_value
338+
339+
// if we haven't returned at this point, we should fallback to the union serializer
340+
// which preserves the historical expectation that we do our best with serialization
341+
// even if that means we resort to inference
342+
union_serialize(selector, extra, &self.choices, self.retry_with_lax_check())
360343
}
361344
}

0 commit comments

Comments
 (0)