Skip to content

Commit 334a9b0

Browse files
authored
Do not reuse validator and serializer when unpickling (#1693)
1 parent 3414703 commit 334a9b0

File tree

5 files changed

+97
-15
lines changed

5 files changed

+97
-15
lines changed

src/serializers/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValu
1414
use extra::{CollectWarnings, SerRecursionState, WarningsMode};
1515
pub(crate) use extra::{DuckTypingSerMode, Extra, SerMode, SerializationState};
1616
pub use shared::CombinedSerializer;
17-
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
17+
use shared::{to_json_bytes, TypeSerializer};
1818

1919
mod computed_fields;
2020
mod config;
@@ -91,7 +91,7 @@ impl SchemaSerializer {
9191
#[pyo3(signature = (schema, config=None))]
9292
pub fn py_new(schema: Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
9393
let mut definitions_builder = DefinitionsBuilder::new();
94-
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
94+
let serializer = CombinedSerializer::build_base(schema.downcast()?, config, &mut definitions_builder)?;
9595
Ok(Self {
9696
serializer,
9797
definitions: definitions_builder.finish()?,

src/serializers/shared.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,21 @@ combined_serializer! {
149149
}
150150

151151
impl CombinedSerializer {
152+
// Used when creating the base serializer instance, to avoid reusing the instance
153+
// when unpickling:
154+
pub fn build_base(
155+
schema: &Bound<'_, PyDict>,
156+
config: Option<&Bound<'_, PyDict>>,
157+
definitions: &mut DefinitionsBuilder<CombinedSerializer>,
158+
) -> PyResult<CombinedSerializer> {
159+
Self::_build(schema, config, definitions, false)
160+
}
161+
152162
fn _build(
153163
schema: &Bound<'_, PyDict>,
154164
config: Option<&Bound<'_, PyDict>>,
155165
definitions: &mut DefinitionsBuilder<CombinedSerializer>,
166+
use_prebuilt: bool,
156167
) -> PyResult<CombinedSerializer> {
157168
let py = schema.py();
158169
let type_key = intern!(py, "type");
@@ -199,9 +210,13 @@ impl CombinedSerializer {
199210
let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?;
200211
let type_ = type_.to_str()?;
201212

202-
// if we have a SchemaValidator on the type already, use it
203-
if let Ok(Some(prebuilt_serializer)) = super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) {
204-
return Ok(prebuilt_serializer);
213+
if use_prebuilt {
214+
// if we have a SchemaValidator on the type already, use it
215+
if let Ok(Some(prebuilt_serializer)) =
216+
super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema)
217+
{
218+
return Ok(prebuilt_serializer);
219+
}
205220
}
206221

207222
Self::find_serializer(type_, schema, config, definitions)
@@ -217,7 +232,7 @@ impl BuildSerializer for CombinedSerializer {
217232
config: Option<&Bound<'_, PyDict>>,
218233
definitions: &mut DefinitionsBuilder<CombinedSerializer>,
219234
) -> PyResult<CombinedSerializer> {
220-
Self::_build(schema, config, definitions)
235+
Self::_build(schema, config, definitions, true)
221236
}
222237
}
223238

src/validators/mod.rs

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ impl SchemaValidator {
127127
pub fn py_new(py: Python, schema: &Bound<'_, PyAny>, config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
128128
let mut definitions_builder = DefinitionsBuilder::new();
129129

130-
let validator = build_validator(schema, config, &mut definitions_builder)?;
130+
let validator = build_validator_base(schema, config, &mut definitions_builder)?;
131131
let definitions = definitions_builder.finish()?;
132132
let py_schema = schema.clone().unbind();
133133
let py_config = match config {
@@ -159,11 +159,6 @@ impl SchemaValidator {
159159
})
160160
}
161161

162-
pub fn __reduce__<'py>(slf: &Bound<'py, Self>) -> PyResult<(Bound<'py, PyType>, Bound<'py, PyTuple>)> {
163-
let init_args = (&slf.get().py_schema, &slf.get().py_config).into_pyobject(slf.py())?;
164-
Ok((slf.get_type(), init_args))
165-
}
166-
167162
#[allow(clippy::too_many_arguments)]
168163
#[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None, allow_partial=PartialMode::Off, by_alias=None, by_name=None))]
169164
pub fn validate_python(
@@ -357,6 +352,11 @@ impl SchemaValidator {
357352
}
358353
}
359354

355+
pub fn __reduce__<'py>(slf: &Bound<'py, Self>) -> PyResult<(Bound<'py, PyType>, Bound<'py, PyTuple>)> {
356+
let init_args = (&slf.get().py_schema, &slf.get().py_config).into_pyobject(slf.py())?;
357+
Ok((slf.get_type(), init_args))
358+
}
359+
360360
pub fn __repr__(&self, py: Python) -> String {
361361
format!(
362362
"SchemaValidator(title={:?}, validator={:#?}, definitions={:#?}, cache_strings={})",
@@ -555,19 +555,40 @@ macro_rules! validator_match {
555555
};
556556
}
557557

558+
// Used when creating the base validator instance, to avoid reusing the instance
559+
// when unpickling:
560+
pub fn build_validator_base(
561+
schema: &Bound<'_, PyAny>,
562+
config: Option<&Bound<'_, PyDict>>,
563+
definitions: &mut DefinitionsBuilder<CombinedValidator>,
564+
) -> PyResult<CombinedValidator> {
565+
build_validator_inner(schema, config, definitions, false)
566+
}
567+
558568
pub fn build_validator(
559569
schema: &Bound<'_, PyAny>,
560570
config: Option<&Bound<'_, PyDict>>,
561571
definitions: &mut DefinitionsBuilder<CombinedValidator>,
572+
) -> PyResult<CombinedValidator> {
573+
build_validator_inner(schema, config, definitions, true)
574+
}
575+
576+
fn build_validator_inner(
577+
schema: &Bound<'_, PyAny>,
578+
config: Option<&Bound<'_, PyDict>>,
579+
definitions: &mut DefinitionsBuilder<CombinedValidator>,
580+
use_prebuilt: bool,
562581
) -> PyResult<CombinedValidator> {
563582
let dict = schema.downcast::<PyDict>()?;
564583
let py = schema.py();
565584
let type_: Bound<'_, PyString> = dict.get_as_req(intern!(py, "type"))?;
566585
let type_ = type_.to_str()?;
567586

568-
// if we have a SchemaValidator on the type already, use it
569-
if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) {
570-
return Ok(prebuilt_validator);
587+
if use_prebuilt {
588+
// if we have a SchemaValidator on the type already, use it
589+
if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) {
590+
return Ok(prebuilt_validator);
591+
}
571592
}
572593

573594
validator_match!(

tests/serializers/test_pickling.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,26 @@ def test_schema_serializer_containing_config():
4848
assert s.to_python(timedelta(seconds=4, microseconds=500_000)) == timedelta(seconds=4, microseconds=500_000)
4949
assert s.to_python(timedelta(seconds=4, microseconds=500_000), mode='json') == 4.5
5050
assert s.to_json(timedelta(seconds=4, microseconds=500_000)) == b'4.5'
51+
52+
53+
# Should be defined at the module level for pickling to work:
54+
class Model:
55+
__pydantic_serializer__: SchemaSerializer
56+
__pydantic_complete__ = True
57+
58+
59+
def test_schema_serializer_not_reused_when_unpickling() -> None:
60+
s = SchemaSerializer(
61+
core_schema.model_schema(
62+
cls=Model,
63+
schema=core_schema.model_fields_schema(fields={}, model_name='Model'),
64+
config={'title': 'Model'},
65+
ref='Model:123',
66+
)
67+
)
68+
69+
Model.__pydantic_serializer__ = s
70+
assert 'Prebuilt' not in str(Model.__pydantic_serializer__)
71+
72+
reconstructed = pickle.loads(pickle.dumps(Model.__pydantic_serializer__))
73+
assert 'Prebuilt' not in str(reconstructed)

tests/validators/test_pickling.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,26 @@ def test_schema_validator_tz_pickle() -> None:
5151
validated = v.validate_python('2022-06-08T12:13:14-12:15')
5252
assert validated == original
5353
assert pickle.loads(pickle.dumps(validated)) == validated == original
54+
55+
56+
# Should be defined at the module level for pickling to work:
57+
class Model:
58+
__pydantic_validator__: SchemaValidator
59+
__pydantic_complete__ = True
60+
61+
62+
def test_schema_validator_not_reused_when_unpickling() -> None:
63+
s = SchemaValidator(
64+
core_schema.model_schema(
65+
cls=Model,
66+
schema=core_schema.model_fields_schema(fields={}, model_name='Model'),
67+
config={'title': 'Model'},
68+
ref='Model:123',
69+
)
70+
)
71+
72+
Model.__pydantic_validator__ = s
73+
assert 'Prebuilt' not in str(Model.__pydantic_validator__)
74+
75+
reconstructed = pickle.loads(pickle.dumps(Model.__pydantic_validator__))
76+
assert 'Prebuilt' not in str(reconstructed)

0 commit comments

Comments
 (0)