Skip to content

Commit 4ef1da0

Browse files
authored
feat: add on_error (#200)
* feat: add `on_error` * use single quote * no need for owned string * add error message check * add default_value
1 parent c0d1631 commit 4ef1da0

File tree

3 files changed

+195
-20
lines changed

3 files changed

+195
-20
lines changed

pydantic_core/_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ class TypedDictField(TypedDict, total=False):
123123
required: bool
124124
default: Any
125125
default_factory: Callable[[], Any]
126+
on_error: Literal['raise', 'omit', 'fallback_on_default'] # default: 'raise'
126127
alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
127128

128129

src/validators/typed_dict.rs

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::borrow::Cow;
2+
13
use pyo3::prelude::*;
24
use pyo3::types::{PyDict, PyFunction, PyList, PySet, PyString};
35
use pyo3::{intern, PyTypeInfo};
@@ -13,17 +15,37 @@ use crate::SchemaError;
1315

1416
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1517

18+
#[derive(Debug, Clone)]
19+
enum OnError {
20+
Raise,
21+
Omit,
22+
FallbackOnDefault,
23+
}
24+
1625
#[derive(Debug, Clone)]
1726
struct TypedDictField {
1827
name: String,
1928
lookup_key: LookupKey,
2029
name_pystring: Py<PyString>,
2130
required: bool,
31+
on_error: OnError,
2232
default: Option<PyObject>,
2333
default_factory: Option<PyObject>,
2434
validator: CombinedValidator,
2535
}
2636

37+
impl TypedDictField {
38+
fn default_value(&self, py: Python) -> PyResult<Option<Cow<PyObject>>> {
39+
if let Some(ref default) = self.default {
40+
Ok(Some(Cow::Borrowed(default)))
41+
} else if let Some(ref default_factory) = self.default_factory {
42+
Ok(Some(Cow::Owned(default_factory.call0(py)?)))
43+
} else {
44+
Ok(None)
45+
}
46+
}
47+
}
48+
2749
#[derive(Debug, Clone)]
2850
pub struct TypedDictValidator {
2951
fields: Vec<TypedDictField>,
@@ -88,7 +110,7 @@ impl BuildValidator for TypedDictValidator {
88110
let field_name: &str = key.extract()?;
89111
let schema: &PyAny = field_info
90112
.get_as_req(intern!(py, "schema"))
91-
.map_err(|err| SchemaError::new_err(format!("Field \"{}\":\n {}", field_name, err)))?;
113+
.map_err(|err| SchemaError::new_err(format!("Field '{}':\n {}", field_name, err)))?;
92114

93115
let (default, default_factory) = match (
94116
field_info.get_as(intern!(py, "default"))?,
@@ -107,6 +129,47 @@ impl BuildValidator for TypedDictValidator {
107129
}
108130
None => LookupKey::from_string(py, field_name),
109131
};
132+
133+
let required = match field_info.get_as::<bool>(intern!(py, "required"))? {
134+
Some(required) => {
135+
if required && (default.is_some() || default_factory.is_some()) {
136+
return py_error!("Field '{}': a required field cannot have a default value", field_name);
137+
}
138+
required
139+
}
140+
None => full,
141+
};
142+
143+
let on_error = match field_info.get_as::<&str>(intern!(py, "on_error"))? {
144+
Some(on_error) => match on_error {
145+
"raise" => OnError::Raise,
146+
"omit" => {
147+
if required {
148+
return py_error!(
149+
"Field '{}': 'on_error = {}' cannot be set for required fields",
150+
field_name,
151+
on_error
152+
);
153+
}
154+
155+
OnError::Omit
156+
}
157+
"fallback_on_default" => {
158+
if default.is_none() && default_factory.is_none() {
159+
return py_error!(
160+
"Field '{}': 'on_error = {}' requires a `default` or `default_factory`",
161+
field_name,
162+
on_error
163+
);
164+
}
165+
166+
OnError::FallbackOnDefault
167+
}
168+
_ => unreachable!(),
169+
},
170+
None => OnError::Raise,
171+
};
172+
110173
fields.push(TypedDictField {
111174
name: field_name.to_string(),
112175
lookup_key,
@@ -115,17 +178,10 @@ impl BuildValidator for TypedDictValidator {
115178
Ok((v, _)) => v,
116179
Err(err) => return py_error!("Field \"{}\":\n {}", field_name, err),
117180
},
118-
required: match field_info.get_as::<bool>(intern!(py, "required"))? {
119-
Some(required) => {
120-
if required && (default.is_some() || default_factory.is_some()) {
121-
return py_error!("Field \"{}\": a required field cannot have a default value", field_name);
122-
}
123-
required
124-
}
125-
None => full,
126-
},
181+
required,
127182
default,
128183
default_factory,
184+
on_error,
129185
});
130186
}
131187
Ok(Self {
@@ -210,17 +266,23 @@ impl Validator for TypedDictValidator {
210266
fs.push(field.name_pystring.clone_ref(py));
211267
}
212268
}
213-
Err(ValError::LineErrors(line_errors)) => {
214-
for err in line_errors {
215-
errors.push(err.with_outer_location(field.name.clone().into()));
269+
Err(ValError::LineErrors(line_errors)) => match field.on_error {
270+
OnError::Raise => {
271+
for err in line_errors {
272+
errors.push(err.with_outer_location(field.name.clone().into()));
273+
}
216274
}
217-
}
275+
OnError::Omit => continue,
276+
OnError::FallbackOnDefault => {
277+
if let Some(default_value) = field.default_value(py)? {
278+
output_dict.set_item(&field.name_pystring, default_value.as_ref())?;
279+
}
280+
}
281+
},
218282
Err(err) => return Err(err),
219283
}
220-
} else if let Some(ref default) = field.default {
221-
output_dict.set_item(&field.name_pystring, default)?;
222-
} else if let Some(ref default_factory) = field.default_factory {
223-
output_dict.set_item(&field.name_pystring, default_factory.call0(py)?)?;
284+
} else if let Some(default_value) = field.default_value(py)? {
285+
output_dict.set_item(&field.name_pystring, default_value.as_ref())?
224286
} else if !field.required {
225287
continue;
226288
} else {

tests/validators/test_typed_dict.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def test_all_optional_fields_with_required_fields():
449449

450450
def test_field_required_and_default():
451451
"""A field cannot be required and have a default value"""
452-
with pytest.raises(SchemaError, match='Field "x": a required field cannot have a default value'):
452+
with pytest.raises(SchemaError, match="Field 'x': a required field cannot have a default value"):
453453
SchemaValidator(
454454
{'type': 'typed-dict', 'fields': {'x': {'schema': {'type': 'str'}, 'required': True, 'default': 'pika'}}}
455455
)
@@ -1077,7 +1077,7 @@ def test_default_and_default_factory():
10771077

10781078
def test_field_required_and_default_factory():
10791079
"""A field cannot be required and have a default factory"""
1080-
with pytest.raises(SchemaError, match='Field "x": a required field cannot have a default value'):
1080+
with pytest.raises(SchemaError, match="Field 'x': a required field cannot have a default value"):
10811081
SchemaValidator(
10821082
{
10831083
'type': 'typed-dict',
@@ -1099,3 +1099,115 @@ def test_bad_default_factory(default_factory, error_message):
10991099
)
11001100
with pytest.raises(TypeError, match=re.escape(error_message)):
11011101
v.validate_python({})
1102+
1103+
1104+
class TestOnError:
1105+
def test_on_error_bad_name(self):
1106+
with pytest.raises(SchemaError, match="Input should be one of: 'raise', 'omit', 'fallback_on_default'"):
1107+
SchemaValidator({'type': 'typed-dict', 'fields': {'x': {'schema': {'type': 'str'}, 'on_error': 'rais'}}})
1108+
1109+
def test_on_error_bad_omit(self):
1110+
with pytest.raises(SchemaError, match="Field 'x': 'on_error = omit' cannot be set for required fields"):
1111+
SchemaValidator({'type': 'typed-dict', 'fields': {'x': {'schema': {'type': 'str'}, 'on_error': 'omit'}}})
1112+
1113+
def test_on_error_bad_fallback_on_default(self):
1114+
with pytest.raises(
1115+
SchemaError, match="Field 'x': 'on_error = fallback_on_default' requires a `default` or `default_factory`"
1116+
):
1117+
SchemaValidator(
1118+
{'type': 'typed-dict', 'fields': {'x': {'schema': {'type': 'str'}, 'on_error': 'fallback_on_default'}}}
1119+
)
1120+
1121+
def test_on_error_raise_by_default(self, py_and_json: PyAndJson):
1122+
v = py_and_json({'type': 'typed-dict', 'fields': {'x': {'schema': {'type': 'str'}}}})
1123+
assert v.validate_test({'x': 'foo'}) == {'x': 'foo'}
1124+
with pytest.raises(ValidationError) as exc_info:
1125+
v.validate_test({'x': ['foo']})
1126+
assert exc_info.value.errors() == [
1127+
{'input_value': ['foo'], 'kind': 'str_type', 'loc': ['x'], 'message': 'Input should be a valid string'}
1128+
]
1129+
1130+
def test_on_error_raise_explicit(self, py_and_json: PyAndJson):
1131+
v = py_and_json({'type': 'typed-dict', 'fields': {'x': {'schema': {'type': 'str'}, 'on_error': 'raise'}}})
1132+
assert v.validate_test({'x': 'foo'}) == {'x': 'foo'}
1133+
with pytest.raises(ValidationError) as exc_info:
1134+
v.validate_test({'x': ['foo']})
1135+
assert exc_info.value.errors() == [
1136+
{'input_value': ['foo'], 'kind': 'str_type', 'loc': ['x'], 'message': 'Input should be a valid string'}
1137+
]
1138+
1139+
def test_on_error_omit(self, py_and_json: PyAndJson):
1140+
v = py_and_json(
1141+
{'type': 'typed-dict', 'fields': {'x': {'schema': {'type': 'str'}, 'on_error': 'omit', 'required': False}}}
1142+
)
1143+
assert v.validate_test({'x': 'foo'}) == {'x': 'foo'}
1144+
assert v.validate_test({}) == {}
1145+
assert v.validate_test({'x': ['foo']}) == {}
1146+
1147+
def test_on_error_omit_with_default(self, py_and_json: PyAndJson):
1148+
v = py_and_json(
1149+
{
1150+
'type': 'typed-dict',
1151+
'fields': {'x': {'schema': {'type': 'str'}, 'on_error': 'omit', 'default': 'pika', 'required': False}},
1152+
}
1153+
)
1154+
assert v.validate_test({'x': 'foo'}) == {'x': 'foo'}
1155+
assert v.validate_test({}) == {'x': 'pika'}
1156+
assert v.validate_test({'x': ['foo']}) == {}
1157+
1158+
def test_on_error_fallback_on_default(self, py_and_json: PyAndJson):
1159+
v = py_and_json(
1160+
{
1161+
'type': 'typed-dict',
1162+
'fields': {'x': {'schema': {'type': 'str'}, 'on_error': 'fallback_on_default', 'default': 'pika'}},
1163+
}
1164+
)
1165+
assert v.validate_test({'x': 'foo'}) == {'x': 'foo'}
1166+
assert v.validate_test({'x': ['foo']}) == {'x': 'pika'}
1167+
1168+
def test_on_error_fallback_on_default_factory(self, py_and_json: PyAndJson):
1169+
v = py_and_json(
1170+
{
1171+
'type': 'typed-dict',
1172+
'fields': {
1173+
'x': {
1174+
'schema': {'type': 'str'},
1175+
'on_error': 'fallback_on_default',
1176+
'default_factory': lambda: 'pika',
1177+
}
1178+
},
1179+
}
1180+
)
1181+
assert v.validate_test({'x': 'foo'}) == {'x': 'foo'}
1182+
assert v.validate_test({'x': ['foo']}) == {'x': 'pika'}
1183+
1184+
def test_wrap_on_error(self, py_and_json: PyAndJson):
1185+
def wrap_function(input_value, *, validator, **kwargs):
1186+
try:
1187+
return validator(input_value)
1188+
except ValidationError:
1189+
if isinstance(input_value, list):
1190+
return str(len(input_value))
1191+
else:
1192+
return repr(input_value)
1193+
1194+
v = py_and_json(
1195+
{
1196+
'type': 'typed-dict',
1197+
'fields': {
1198+
'x': {
1199+
'schema': {
1200+
'type': 'function',
1201+
'mode': 'wrap',
1202+
'function': wrap_function,
1203+
'schema': {'type': 'str'},
1204+
},
1205+
'on_error': 'raise',
1206+
}
1207+
},
1208+
}
1209+
)
1210+
assert v.validate_test({'x': 'foo'}) == {'x': 'foo'}
1211+
assert v.validate_test({'x': ['foo']}) == {'x': '1'}
1212+
assert v.validate_test({'x': ['foo', 'bar']}) == {'x': '2'}
1213+
assert v.validate_test({'x': {'a': 'b'}}) == {'x': "{'a': 'b'}"}

0 commit comments

Comments
 (0)