Skip to content

Commit 7747781

Browse files
davidhewittnix010
andauthored
simplify ser-as-any mechanism (#1478)
Co-authored-by: nix <khiem3t@gmail.com>
1 parent e20b74e commit 7747781

File tree

13 files changed

+449
-258
lines changed

13 files changed

+449
-258
lines changed

src/errors/validation_exception.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::build_tools::py_schema_error_type;
1818
use crate::errors::LocItem;
1919
use crate::get_pydantic_version;
2020
use crate::input::InputType;
21-
use crate::serializers::{DuckTypingSerMode, Extra, SerMode, SerializationState};
21+
use crate::serializers::{Extra, SerMode, SerializationState};
2222
use crate::tools::{safe_repr, write_truncated_to_limited_bytes, SchemaDict};
2323

2424
use super::line_error::ValLineError;
@@ -341,17 +341,7 @@ impl ValidationError {
341341
include_input: bool,
342342
) -> PyResult<Bound<'py, PyString>> {
343343
let state = SerializationState::new("iso8601", "utf8", "constants")?;
344-
let extra = state.extra(
345-
py,
346-
&SerMode::Json,
347-
None,
348-
false,
349-
false,
350-
true,
351-
None,
352-
DuckTypingSerMode::SchemaBased,
353-
None,
354-
);
344+
let extra = state.extra(py, &SerMode::Json, None, false, false, true, None, false, None);
355345
let serializer = ValidationErrorSerializer {
356346
py,
357347
line_errors: &self.line_errors,

src/serializers/computed_fields.rs

Lines changed: 101 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@ use pyo3::prelude::*;
22
use pyo3::types::{PyDict, PyList, PyString};
33
use pyo3::{intern, PyTraverseError, PyVisit};
44
use serde::ser::SerializeMap;
5-
use serde::Serialize;
65

76
use crate::build_tools::py_schema_error_type;
87
use crate::definitions::DefinitionsBuilder;
98
use crate::py_gc::PyGcTraverse;
109
use crate::serializers::filter::SchemaFilter;
11-
use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer};
10+
use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSerializer};
1211
use crate::tools::SchemaDict;
1312

1413
use super::errors::py_err_se_err;
@@ -48,18 +47,31 @@ impl ComputedFields {
4847
exclude: Option<&Bound<'_, PyAny>>,
4948
extra: &Extra,
5049
) -> PyResult<()> {
51-
if extra.round_trip {
52-
// Do not serialize computed fields
53-
return Ok(());
54-
}
55-
for computed_field in &self.0 {
56-
let field_extra = Extra {
57-
field_name: Some(computed_field.property_name.as_str()),
58-
..*extra
59-
};
60-
computed_field.to_python(model, output_dict, filter, include, exclude, &field_extra)?;
61-
}
62-
Ok(())
50+
self.serialize_fields(
51+
model,
52+
filter,
53+
include,
54+
exclude,
55+
extra,
56+
|e| e,
57+
|ComputedFieldToSerialize {
58+
computed_field,
59+
value,
60+
include,
61+
exclude,
62+
field_extra,
63+
}| {
64+
let key = match field_extra.serialize_by_alias_or(computed_field.serialize_by_alias) {
65+
true => computed_field.alias_py.bind(model.py()),
66+
false => computed_field.property_name_py.bind(model.py()),
67+
};
68+
let value =
69+
computed_field
70+
.serializer
71+
.to_python(&value, include.as_ref(), exclude.as_ref(), &field_extra)?;
72+
output_dict.set_item(key, value)
73+
},
74+
)
6375
}
6476

6577
pub fn serde_serialize<S: serde::ser::Serializer>(
@@ -71,44 +83,96 @@ impl ComputedFields {
7183
exclude: Option<&Bound<'_, PyAny>>,
7284
extra: &Extra,
7385
) -> Result<(), S::Error> {
86+
self.serialize_fields(
87+
model,
88+
filter,
89+
include,
90+
exclude,
91+
extra,
92+
py_err_se_err,
93+
|ComputedFieldToSerialize {
94+
computed_field,
95+
value,
96+
include,
97+
exclude,
98+
field_extra,
99+
}| {
100+
let key = match field_extra.serialize_by_alias_or(computed_field.serialize_by_alias) {
101+
true => &computed_field.alias,
102+
false => &computed_field.property_name,
103+
};
104+
let s = PydanticSerializer::new(
105+
&value,
106+
&computed_field.serializer,
107+
include.as_ref(),
108+
exclude.as_ref(),
109+
&field_extra,
110+
);
111+
map.serialize_entry(key, &s)
112+
},
113+
)
114+
}
115+
116+
/// Iterate each field for serialization, filtering on
117+
/// `include` and `exclude` if provided.
118+
#[allow(clippy::too_many_arguments)]
119+
fn serialize_fields<'a, 'py, E>(
120+
&'a self,
121+
model: &'a Bound<'py, PyAny>,
122+
filter: &'a SchemaFilter<isize>,
123+
include: Option<&'a Bound<'py, PyAny>>,
124+
exclude: Option<&'a Bound<'py, PyAny>>,
125+
extra: &'a Extra,
126+
convert_error: impl FnOnce(PyErr) -> E,
127+
mut serialize: impl FnMut(ComputedFieldToSerialize<'a, 'py>) -> Result<(), E>,
128+
) -> Result<(), E> {
74129
if extra.round_trip {
75130
// Do not serialize computed fields
76131
return Ok(());
77132
}
78133

79134
for computed_field in &self.0 {
80135
let property_name_py = computed_field.property_name_py.bind(model.py());
136+
let (next_include, next_exclude) = match filter.key_filter(property_name_py, include, exclude) {
137+
Ok(Some((next_include, next_exclude))) => (next_include, next_exclude),
138+
Ok(None) => continue,
139+
Err(e) => return Err(convert_error(e)),
140+
};
81141

82-
if let Some((next_include, next_exclude)) = filter
83-
.key_filter(property_name_py, include, exclude)
84-
.map_err(py_err_se_err)?
85-
{
86-
let value = model.getattr(property_name_py).map_err(py_err_se_err)?;
87-
if extra.exclude_none && value.is_none() {
88-
continue;
142+
let value = match model.getattr(property_name_py) {
143+
Ok(field_value) => field_value,
144+
Err(e) => {
145+
return Err(convert_error(e));
89146
}
90-
let field_extra = Extra {
91-
field_name: Some(computed_field.property_name.as_str()),
92-
..*extra
93-
};
94-
let cfs = ComputedFieldSerializer {
95-
model,
96-
computed_field,
97-
include: next_include.as_ref(),
98-
exclude: next_exclude.as_ref(),
99-
extra: &field_extra,
100-
};
101-
let key = match extra.serialize_by_alias_or(computed_field.serialize_by_alias) {
102-
true => computed_field.alias.as_str(),
103-
false => computed_field.property_name.as_str(),
104-
};
105-
map.serialize_entry(key, &cfs)?;
147+
};
148+
if extra.exclude_none && value.is_none() {
149+
continue;
106150
}
151+
152+
let field_extra = Extra {
153+
field_name: Some(&computed_field.property_name),
154+
..*extra
155+
};
156+
serialize(ComputedFieldToSerialize {
157+
computed_field,
158+
value,
159+
include: next_include,
160+
exclude: next_exclude,
161+
field_extra,
162+
})?;
107163
}
108164
Ok(())
109165
}
110166
}
111167

168+
struct ComputedFieldToSerialize<'a, 'py> {
169+
computed_field: &'a ComputedField,
170+
value: Bound<'py, PyAny>,
171+
include: Option<Bound<'py, PyAny>>,
172+
exclude: Option<Bound<'py, PyAny>>,
173+
field_extra: Extra<'a>,
174+
}
175+
112176
#[derive(Debug)]
113177
struct ComputedField {
114178
property_name: String,
@@ -143,44 +207,6 @@ impl ComputedField {
143207
serialize_by_alias: config.get_as(intern!(py, "serialize_by_alias"))?,
144208
})
145209
}
146-
147-
fn to_python(
148-
&self,
149-
model: &Bound<'_, PyAny>,
150-
output_dict: &Bound<'_, PyDict>,
151-
filter: &SchemaFilter<isize>,
152-
include: Option<&Bound<'_, PyAny>>,
153-
exclude: Option<&Bound<'_, PyAny>>,
154-
extra: &Extra,
155-
) -> PyResult<()> {
156-
let py = model.py();
157-
let property_name_py = self.property_name_py.bind(py);
158-
159-
if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? {
160-
let next_value = model.getattr(property_name_py)?;
161-
162-
let value = self
163-
.serializer
164-
.to_python(&next_value, next_include.as_ref(), next_exclude.as_ref(), extra)?;
165-
if extra.exclude_none && value.is_none(py) {
166-
return Ok(());
167-
}
168-
let key = match extra.serialize_by_alias_or(self.serialize_by_alias) {
169-
true => self.alias_py.bind(py),
170-
false => property_name_py,
171-
};
172-
output_dict.set_item(key, value)?;
173-
}
174-
Ok(())
175-
}
176-
}
177-
178-
pub(crate) struct ComputedFieldSerializer<'py> {
179-
model: &'py Bound<'py, PyAny>,
180-
computed_field: &'py ComputedField,
181-
include: Option<&'py Bound<'py, PyAny>>,
182-
exclude: Option<&'py Bound<'py, PyAny>>,
183-
extra: &'py Extra<'py>,
184210
}
185211

186212
impl_py_gc_traverse!(ComputedField { serializer });
@@ -190,21 +216,3 @@ impl PyGcTraverse for ComputedFields {
190216
self.0.py_gc_traverse(visit)
191217
}
192218
}
193-
194-
impl_py_gc_traverse!(ComputedFieldSerializer<'_> { computed_field });
195-
196-
impl Serialize for ComputedFieldSerializer<'_> {
197-
fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
198-
let py = self.model.py();
199-
let property_name_py = self.computed_field.property_name_py.bind(py);
200-
let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?;
201-
let s = PydanticSerializer::new(
202-
&next_value,
203-
&self.computed_field.serializer,
204-
self.include,
205-
self.exclude,
206-
self.extra,
207-
);
208-
s.serialize(serializer)
209-
}
210-
}

src/serializers/extra.rs

Lines changed: 8 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,45 +27,6 @@ pub(crate) struct SerializationState {
2727
config: SerializationConfig,
2828
}
2929

30-
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
31-
pub enum DuckTypingSerMode {
32-
// Don't check the type of the value, use the type of the schema
33-
SchemaBased,
34-
// Check the type of the value, use the type of the value
35-
NeedsInference,
36-
// We already checked the type of the value
37-
// we don't want to infer again, but if we recurse down
38-
// we do want to flip this back to NeedsInference for the
39-
// fields / keys / items of any inner serializers
40-
Inferred,
41-
}
42-
43-
impl DuckTypingSerMode {
44-
pub fn from_bool(serialize_as_any: bool) -> Self {
45-
if serialize_as_any {
46-
DuckTypingSerMode::NeedsInference
47-
} else {
48-
DuckTypingSerMode::SchemaBased
49-
}
50-
}
51-
52-
pub fn to_bool(self) -> bool {
53-
match self {
54-
DuckTypingSerMode::SchemaBased => false,
55-
DuckTypingSerMode::NeedsInference => true,
56-
DuckTypingSerMode::Inferred => true,
57-
}
58-
}
59-
60-
pub fn next_mode(self) -> Self {
61-
match self {
62-
DuckTypingSerMode::SchemaBased => DuckTypingSerMode::SchemaBased,
63-
DuckTypingSerMode::NeedsInference => DuckTypingSerMode::Inferred,
64-
DuckTypingSerMode::Inferred => DuckTypingSerMode::NeedsInference,
65-
}
66-
}
67-
}
68-
6930
impl SerializationState {
7031
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
7132
let warnings = CollectWarnings::new(WarningsMode::None);
@@ -88,7 +49,7 @@ impl SerializationState {
8849
round_trip: bool,
8950
serialize_unknown: bool,
9051
fallback: Option<&'py Bound<'_, PyAny>>,
91-
duck_typing_ser_mode: DuckTypingSerMode,
52+
serialize_as_any: bool,
9253
context: Option<&'py Bound<'_, PyAny>>,
9354
) -> Extra<'py> {
9455
Extra::new(
@@ -104,7 +65,7 @@ impl SerializationState {
10465
&self.rec_guard,
10566
serialize_unknown,
10667
fallback,
107-
duck_typing_ser_mode,
68+
serialize_as_any,
10869
context,
10970
)
11071
}
@@ -137,7 +98,7 @@ pub(crate) struct Extra<'a> {
13798
pub field_name: Option<&'a str>,
13899
pub serialize_unknown: bool,
139100
pub fallback: Option<&'a Bound<'a, PyAny>>,
140-
pub duck_typing_ser_mode: DuckTypingSerMode,
101+
pub serialize_as_any: bool,
141102
pub context: Option<&'a Bound<'a, PyAny>>,
142103
}
143104

@@ -156,7 +117,7 @@ impl<'a> Extra<'a> {
156117
rec_guard: &'a SerRecursionState,
157118
serialize_unknown: bool,
158119
fallback: Option<&'a Bound<'a, PyAny>>,
159-
duck_typing_ser_mode: DuckTypingSerMode,
120+
serialize_as_any: bool,
160121
context: Option<&'a Bound<'a, PyAny>>,
161122
) -> Self {
162123
Self {
@@ -175,7 +136,7 @@ impl<'a> Extra<'a> {
175136
field_name: None,
176137
serialize_unknown,
177138
fallback,
178-
duck_typing_ser_mode,
139+
serialize_as_any,
179140
context,
180141
}
181142
}
@@ -243,7 +204,7 @@ pub(crate) struct ExtraOwned {
243204
field_name: Option<String>,
244205
serialize_unknown: bool,
245206
pub fallback: Option<PyObject>,
246-
duck_typing_ser_mode: DuckTypingSerMode,
207+
serialize_as_any: bool,
247208
pub context: Option<PyObject>,
248209
}
249210

@@ -264,7 +225,7 @@ impl ExtraOwned {
264225
field_name: extra.field_name.map(ToString::to_string),
265226
serialize_unknown: extra.serialize_unknown,
266227
fallback: extra.fallback.map(|model| model.clone().into()),
267-
duck_typing_ser_mode: extra.duck_typing_ser_mode,
228+
serialize_as_any: extra.serialize_as_any,
268229
context: extra.context.map(|model| model.clone().into()),
269230
}
270231
}
@@ -286,7 +247,7 @@ impl ExtraOwned {
286247
field_name: self.field_name.as_deref(),
287248
serialize_unknown: self.serialize_unknown,
288249
fallback: self.fallback.as_ref().map(|m| m.bind(py)),
289-
duck_typing_ser_mode: self.duck_typing_ser_mode,
250+
serialize_as_any: self.serialize_as_any,
290251
context: self.context.as_ref().map(|m| m.bind(py)),
291252
}
292253
}

0 commit comments

Comments
 (0)