Skip to content

Commit b33dc3e

Browse files
authored
Functions inside models (#209)
* Allow functions within model_class validators * make fields_set flexible * improve coverage * fix error message test to support pypy
1 parent 4ef1da0 commit b33dc3e

File tree

13 files changed

+305
-73
lines changed

13 files changed

+305
-73
lines changed

pydantic_core/_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class LiteralSchema(TypedDict):
112112
class ModelClassSchema(TypedDict):
113113
type: Literal['model-class']
114114
class_type: type
115-
schema: TypedDictSchema
115+
schema: Schema
116116
strict: NotRequired[bool]
117117
ref: NotRequired[str]
118118
config: NotRequired[Config]

src/validators/arguments.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ impl BuildValidator for ArgumentsValidator {
7676
.get_as_req(intern!(py, "schema"))
7777
.map_err(|err| SchemaError::new_err(format!("Argument \"{}\":\n {}", name, err)))?;
7878

79-
let (validator, _) = build_validator(schema, config, build_context)?;
79+
let validator = build_validator(schema, config, build_context)?;
8080

8181
let default = arg.get_as(intern!(py, "default"))?;
8282
let default_factory = arg.get_as(intern!(py, "default_factory"))?;
@@ -102,11 +102,11 @@ impl BuildValidator for ArgumentsValidator {
102102
arguments,
103103
positional_args_count,
104104
var_args_validator: match schema.get_item(intern!(py, "var_args_schema")) {
105-
Some(v) => Some(Box::new(build_validator(v, config, build_context)?.0)),
105+
Some(v) => Some(Box::new(build_validator(v, config, build_context)?)),
106106
None => None,
107107
},
108108
var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema")) {
109-
Some(v) => Some(Box::new(build_validator(v, config, build_context)?.0)),
109+
Some(v) => Some(Box::new(build_validator(v, config, build_context)?)),
110110
None => None,
111111
},
112112
}

src/validators/dict.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ impl BuildValidator for DictValidator {
3030
) -> PyResult<CombinedValidator> {
3131
let py = schema.py();
3232
let key_validator = match schema.get_item(intern!(py, "keys_schema")) {
33-
Some(schema) => Box::new(build_validator(schema, config, build_context)?.0),
33+
Some(schema) => Box::new(build_validator(schema, config, build_context)?),
3434
None => Box::new(AnyValidator::build(schema, config, build_context)?),
3535
};
3636
let value_validator = match schema.get_item(intern!(py, "values_schema")) {
37-
Some(d) => Box::new(build_validator(d, config, build_context)?.0),
37+
Some(d) => Box::new(build_validator(d, config, build_context)?),
3838
None => Box::new(AnyValidator::build(schema, config, build_context)?),
3939
};
4040
let name = format!(

src/validators/function.rs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@ use pyo3::intern;
33
use pyo3::prelude::*;
44
use pyo3::types::{PyAny, PyDict};
55

6-
use crate::build_tools::{py_error, SchemaDict};
6+
use crate::build_tools::SchemaDict;
77
use crate::errors::{ErrorKind, PydanticValueError, ValError, ValResult, ValidationError};
88
use crate::input::Input;
99
use crate::recursion_guard::RecursionGuard;
1010

1111
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1212

13-
#[derive(Debug)]
1413
pub struct FunctionBuilder;
1514

1615
impl BuildValidator for FunctionBuilder {
@@ -25,9 +24,9 @@ impl BuildValidator for FunctionBuilder {
2524
match mode {
2625
"before" => FunctionBeforeValidator::build(schema, config, build_context),
2726
"after" => FunctionAfterValidator::build(schema, config, build_context),
28-
"plain" => FunctionPlainValidator::build(schema, config),
2927
"wrap" => FunctionWrapValidator::build(schema, config, build_context),
30-
_ => py_error!("Unexpected function mode {:?}", mode),
28+
// must be "plain"
29+
_ => FunctionPlainValidator::build(schema, config),
3130
}
3231
}
3332
}
@@ -47,7 +46,7 @@ macro_rules! impl_build {
4746
build_context: &mut BuildContext,
4847
) -> PyResult<CombinedValidator> {
4948
let py = schema.py();
50-
let validator = build_validator(schema.get_as_req(intern!(py, "schema"))?, config, build_context)?.0;
49+
let validator = build_validator(schema.get_as_req(intern!(py, "schema"))?, config, build_context)?;
5150
let name = format!("{}[{}]", $name, validator.get_name());
5251
Ok(Self {
5352
validator: Box::new(validator),
@@ -110,6 +109,10 @@ impl Validator for FunctionBeforeValidator {
110109
&self.name
111110
}
112111

112+
fn ask(&self, question: &str) -> bool {
113+
self.validator.ask(question)
114+
}
115+
113116
fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> {
114117
self.validator.complete(build_context)
115118
}
@@ -143,6 +146,10 @@ impl Validator for FunctionAfterValidator {
143146
&self.name
144147
}
145148

149+
fn ask(&self, question: &str) -> bool {
150+
self.validator.ask(question)
151+
}
152+
146153
fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> {
147154
self.validator.complete(build_context)
148155
}
@@ -157,18 +164,14 @@ pub struct FunctionPlainValidator {
157164
impl FunctionPlainValidator {
158165
pub fn build(schema: &PyDict, config: Option<&PyDict>) -> PyResult<CombinedValidator> {
159166
let py = schema.py();
160-
if schema.get_item(intern!(py, "schema")).is_some() {
161-
py_error!("Plain functions should not include a sub-schema")
162-
} else {
163-
Ok(Self {
164-
func: schema.get_as_req::<&PyAny>(intern!(py, "function"))?.into_py(py),
165-
config: match config {
166-
Some(c) => c.into(),
167-
None => py.None(),
168-
},
169-
}
170-
.into())
167+
Ok(Self {
168+
func: schema.get_as_req::<&PyAny>(intern!(py, "function"))?.into_py(py),
169+
config: match config {
170+
Some(c) => c.into(),
171+
None => py.None(),
172+
},
171173
}
174+
.into())
172175
}
173176
}
174177

@@ -236,6 +239,10 @@ impl Validator for FunctionWrapValidator {
236239
&self.name
237240
}
238241

242+
fn ask(&self, question: &str) -> bool {
243+
self.validator.ask(question)
244+
}
245+
239246
fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> {
240247
self.validator.complete(build_context)
241248
}

src/validators/list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ macro_rules! generic_list_like_build {
2828
) -> PyResult<CombinedValidator> {
2929
let py = schema.py();
3030
let item_validator = match schema.get_item(pyo3::intern!(py, "items_schema")) {
31-
Some(d) => Some(Box::new(build_validator(d, config, build_context)?.0)),
31+
Some(d) => Some(Box::new(build_validator(d, config, build_context)?)),
3232
None => None,
3333
};
3434
let inner_name = item_validator.as_ref().map(|v| v.get_name()).unwrap_or("any");

src/validators/mod.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl SchemaValidator {
6868
let schema = schema_obj.as_ref(py);
6969

7070
let mut build_context = BuildContext::default();
71-
let (mut validator, _) = build_validator(schema, config, &mut build_context)?;
71+
let mut validator = build_validator(schema, config, &mut build_context)?;
7272
validator.complete(&build_context)?;
7373
let slots = build_context.into_slots()?;
7474
let title = validator.get_name().into_py(py);
@@ -211,7 +211,7 @@ impl SchemaValidator {
211211

212212
let mut build_context = BuildContext::default();
213213
let validator = match build_validator(self_schema, None, &mut build_context) {
214-
Ok((v, _)) => v,
214+
Ok(v) => v,
215215
Err(err) => return Err(SchemaError::new_err(format!("Error building self-schema:\n {}", err))),
216216
};
217217
Ok(Self {
@@ -255,7 +255,7 @@ fn build_single_validator<'a, T: BuildValidator>(
255255
schema_dict: &'a PyDict,
256256
config: Option<&'a PyDict>,
257257
build_context: &mut BuildContext,
258-
) -> PyResult<(CombinedValidator, &'a PyDict)> {
258+
) -> PyResult<CombinedValidator> {
259259
let py = schema_dict.py();
260260
let val: CombinedValidator = if let Some(schema_ref) = schema_dict.get_as::<String>(intern!(py, "ref"))? {
261261
let slot_id = build_context.prepare_slot(schema_ref)?;
@@ -269,7 +269,7 @@ fn build_single_validator<'a, T: BuildValidator>(
269269
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?
270270
};
271271

272-
Ok((val, schema_dict))
272+
Ok(val)
273273
}
274274

275275
// macro to build the match statement for validator selection
@@ -290,7 +290,7 @@ pub fn build_validator<'a>(
290290
schema: &'a PyAny,
291291
config: Option<&'a PyDict>,
292292
build_context: &mut BuildContext,
293-
) -> PyResult<(CombinedValidator, &'a PyDict)> {
293+
) -> PyResult<CombinedValidator> {
294294
let py = schema.py();
295295
let dict: &PyDict = match schema.cast_as() {
296296
Ok(s) => s,
@@ -485,6 +485,14 @@ pub trait Validator: Send + Sync + Clone + Debug {
485485
/// this is used in the error location in unions, and in the top level message in `ValidationError`
486486
fn get_name(&self) -> &str;
487487

488+
/// allows validators to ask specific questions of sub-validators in a general way, could be extended
489+
/// to do more, validators which don't know the question and have sub-validators
490+
/// should return the result them in an `...iter().all(|v| v.ask(question))` way, ONLY
491+
/// if they return the value of the sub-validator, e.g. functions, unions
492+
fn ask(&self, _question: &str) -> bool {
493+
false
494+
}
495+
488496
/// this method must be implemented for any validator which holds references to other validators,
489497
/// it is used by `RecursiveRefValidator` to set its name
490498
fn complete(&mut self, _build_context: &BuildContext) -> PyResult<()> {

src/validators/model_class.rs

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@ use pyo3::prelude::*;
88
use pyo3::types::{PyDict, PyTuple, PyType};
99
use pyo3::{ffi, intern};
1010

11-
use crate::build_tools::{py_error, SchemaDict};
11+
use crate::build_tools::SchemaDict;
1212
use crate::errors::{ErrorKind, ValError, ValResult};
1313
use crate::input::Input;
1414
use crate::recursion_guard::RecursionGuard;
1515

16-
use super::typed_dict::TypedDictValidator;
1716
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1817

1918
#[derive(Debug, Clone)]
2019
pub struct ModelClassValidator {
2120
strict: bool,
2221
revalidate: bool,
23-
validator: TypedDictValidator,
22+
validator: Box<CombinedValidator>,
2423
class: Py<PyType>,
2524
name: String,
25+
expect_fields_set: bool,
2626
}
2727

2828
impl BuildValidator for ModelClassValidator {
@@ -39,27 +39,21 @@ impl BuildValidator for ModelClassValidator {
3939

4040
let class: &PyType = schema.get_as_req(intern!(py, "class_type"))?;
4141
let sub_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?;
42-
let (comb_validator, td_schema) = build_validator(sub_schema, config, build_context)?;
42+
let validator = build_validator(sub_schema, config, build_context)?;
4343

44-
if !td_schema.get_as(intern!(py, "return_fields_set"))?.unwrap_or(false) {
45-
return py_error!("model-class inner schema should have 'return_fields_set' set to True");
46-
}
47-
48-
let validator = match comb_validator {
49-
CombinedValidator::TypedDict(tdv) => tdv,
50-
_ => return py_error!("Wrong validator type, expected 'typed-dict' validator"),
51-
};
44+
let expect_fields_set = validator.ask("return_fields_set");
5245

5346
Ok(Self {
5447
// we don't use is_strict here since we don't want validation to be strict in this case if
5548
// `config.strict` is set, only if this specific field is strict
5649
strict: schema.get_as(intern!(py, "strict"))?.unwrap_or(false),
5750
revalidate: config.get_as(intern!(py, "revalidate_models"))?.unwrap_or(false),
58-
validator,
51+
validator: Box::new(validator),
5952
class: class.into(),
6053
// Get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
6154
// which is not what we want here
6255
name: class.getattr(intern!(py, "__name__"))?.extract()?,
56+
expect_fields_set,
6357
}
6458
.into())
6559
}
@@ -79,9 +73,13 @@ impl Validator for ModelClassValidator {
7973
if self.revalidate {
8074
let fields_set = input.get_attr(intern!(py, "__fields_set__"));
8175
let output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
82-
let (model_dict, validation_fields_set): (&PyAny, &PyAny) = output.extract(py)?;
83-
let fields_set = fields_set.unwrap_or(validation_fields_set);
84-
Ok(self.create_class(py, model_dict, fields_set)?)
76+
if self.expect_fields_set {
77+
let (model_dict, validation_fields_set): (&PyAny, &PyAny) = output.extract(py)?;
78+
let fields_set = fields_set.unwrap_or(validation_fields_set);
79+
Ok(self.create_class(py, model_dict, Some(fields_set))?)
80+
} else {
81+
Ok(self.create_class(py, output.as_ref(py), fields_set)?)
82+
}
8583
} else {
8684
Ok(input.to_object(py))
8785
}
@@ -94,8 +92,12 @@ impl Validator for ModelClassValidator {
9492
))
9593
} else {
9694
let output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
97-
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;
98-
Ok(self.create_class(py, model_dict, fields_set)?)
95+
if self.expect_fields_set {
96+
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;
97+
Ok(self.create_class(py, model_dict, Some(fields_set))?)
98+
} else {
99+
Ok(self.create_class(py, output.as_ref(py), None)?)
100+
}
99101
}
100102
}
101103

@@ -105,7 +107,7 @@ impl Validator for ModelClassValidator {
105107
}
106108

107109
impl ModelClassValidator {
108-
fn create_class(&self, py: Python, model_dict: &PyAny, fields_set: &PyAny) -> PyResult<PyObject> {
110+
fn create_class(&self, py: Python, model_dict: &PyAny, fields_set: Option<&PyAny>) -> PyResult<PyObject> {
109111
// based on the following but with the second argument of new_func set to an empty tuple as required
110112
// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/pyclass_init.rs#L35-L77
111113
let args = PyTuple::empty(py);
@@ -126,7 +128,9 @@ impl ModelClassValidator {
126128

127129
let instance_ref = instance.as_ref(py);
128130
force_setattr(py, instance_ref, intern!(py, "__dict__"), model_dict)?;
129-
force_setattr(py, instance_ref, intern!(py, "__fields_set__"), fields_set)?;
131+
if let Some(fields_set) = fields_set {
132+
force_setattr(py, instance_ref, intern!(py, "__fields_set__"), fields_set)?;
133+
}
130134

131135
Ok(instance)
132136
}

src/validators/nullable.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl BuildValidator for NullableValidator {
2424
build_context: &mut BuildContext,
2525
) -> PyResult<CombinedValidator> {
2626
let schema: &PyAny = schema.get_as_req(intern!(schema.py(), "schema"))?;
27-
let validator = Box::new(build_validator(schema, config, build_context)?.0);
27+
let validator = Box::new(build_validator(schema, config, build_context)?);
2828
let name = format!("{}[{}]", Self::EXPECTED_TYPE, validator.get_name());
2929
Ok(Self { validator, name }.into())
3030
}

src/validators/tuple.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,15 @@ impl TuplePositionalValidator {
9393
let items: &PyList = schema.get_as_req(intern!(py, "items_schema"))?;
9494
let validators: Vec<CombinedValidator> = items
9595
.iter()
96-
.map(|item| build_validator(item, config, build_context).map(|result| result.0))
96+
.map(|item| build_validator(item, config, build_context))
9797
.collect::<PyResult<Vec<CombinedValidator>>>()?;
9898

9999
let descr = validators.iter().map(|v| v.get_name()).collect::<Vec<_>>().join(", ");
100100
Ok(Self {
101101
strict: is_strict(schema, config)?,
102102
items_validators: validators,
103103
extra_validator: match schema.get_item(intern!(py, "extra_schema")) {
104-
Some(v) => Some(Box::new(build_validator(v, config, build_context)?.0)),
104+
Some(v) => Some(Box::new(build_validator(v, config, build_context)?)),
105105
None => None,
106106
},
107107
name: format!("tuple[{}]", descr),

src/validators/typed_dict.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl BuildValidator for TypedDictValidator {
9494
let extra_validator = match schema.get_item(intern!(py, "extra_validator")) {
9595
Some(v) => {
9696
if check_extra && !forbid_extra {
97-
Some(Box::new(build_validator(v, config, build_context)?.0))
97+
Some(Box::new(build_validator(v, config, build_context)?))
9898
} else {
9999
return py_error!("extra_validator can only be used if extra_behavior=allow");
100100
}
@@ -175,7 +175,7 @@ impl BuildValidator for TypedDictValidator {
175175
lookup_key,
176176
name_pystring: PyString::intern(py, field_name).into(),
177177
validator: match build_validator(schema, config, build_context) {
178-
Ok((v, _)) => v,
178+
Ok(v) => v,
179179
Err(err) => return py_error!("Field \"{}\":\n {}", field_name, err),
180180
},
181181
required,
@@ -376,6 +376,14 @@ impl Validator for TypedDictValidator {
376376
Self::EXPECTED_TYPE
377377
}
378378

379+
fn ask(&self, question: &str) -> bool {
380+
if question == "return_fields_set" {
381+
self.return_fields_set
382+
} else {
383+
false
384+
}
385+
}
386+
379387
fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> {
380388
self.fields
381389
.iter_mut()

0 commit comments

Comments
 (0)