Skip to content

Commit a6f3972

Browse files
authored
Coerce validator constraints to their valid type (#1661)
1 parent 0761adb commit a6f3972

File tree

8 files changed

+119
-37
lines changed

8 files changed

+119
-37
lines changed

src/validators/date.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
use pyo3::exceptions::PyValueError;
12
use pyo3::intern;
23
use pyo3::prelude::*;
3-
use pyo3::types::{PyDate, PyDict, PyString};
4+
use pyo3::types::{PyDict, PyString};
45
use speedate::{Date, Time};
56
use strum::EnumMessage;
67

@@ -175,9 +176,14 @@ impl DateConstraints {
175176
}
176177
}
177178

178-
fn convert_pydate(schema: &Bound<'_, PyDict>, field: &Bound<'_, PyString>) -> PyResult<Option<Date>> {
179-
match schema.get_as::<Bound<'_, PyDate>>(field)? {
180-
Some(date) => Ok(Some(EitherDate::Py(date).as_raw()?)),
179+
fn convert_pydate(schema: &Bound<'_, PyDict>, key: &Bound<'_, PyString>) -> PyResult<Option<Date>> {
180+
match schema.get_as::<Bound<'_, PyAny>>(key)? {
181+
Some(value) => match value.validate_date(false) {
182+
Ok(v) => Ok(Some(v.into_inner().as_raw()?)),
183+
Err(_) => Err(PyValueError::new_err(format!(
184+
"'{key}' must be coercible to a date instance",
185+
))),
186+
},
181187
None => Ok(None),
182188
}
183189
}

src/validators/datetime.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
use pyo3::exceptions::PyValueError;
12
use pyo3::intern;
23
use pyo3::prelude::*;
34
use pyo3::sync::GILOnceCell;
45
use pyo3::types::{PyDict, PyString};
5-
use speedate::{DateTime, Time};
6+
use speedate::{DateTime, MicrosecondsPrecisionOverflowBehavior, Time};
67
use std::cmp::Ordering;
78
use strum::EnumMessage;
89

@@ -208,9 +209,14 @@ impl DateTimeConstraints {
208209
}
209210
}
210211

211-
fn py_datetime_as_datetime(schema: &Bound<'_, PyDict>, field: &Bound<'_, PyString>) -> PyResult<Option<DateTime>> {
212-
match schema.get_as(field)? {
213-
Some(dt) => Ok(Some(EitherDateTime::Py(dt).as_raw()?)),
212+
fn py_datetime_as_datetime(schema: &Bound<'_, PyDict>, key: &Bound<'_, PyString>) -> PyResult<Option<DateTime>> {
213+
match schema.get_as::<Bound<'_, PyAny>>(key)? {
214+
Some(value) => match value.validate_datetime(false, MicrosecondsPrecisionOverflowBehavior::Truncate) {
215+
Ok(v) => Ok(Some(v.into_inner().as_raw()?)),
216+
Err(_) => Err(PyValueError::new_err(format!(
217+
"'{key}' must be coercible to a datetime instance",
218+
))),
219+
},
214220
None => Ok(None),
215221
}
216222
}

src/validators/decimal.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use pyo3::exceptions::{PyTypeError, PyValueError};
22
use pyo3::intern;
33
use pyo3::sync::GILOnceCell;
4-
use pyo3::types::{IntoPyDict, PyDict, PyTuple, PyType};
4+
use pyo3::types::{IntoPyDict, PyDict, PyString, PyTuple, PyType};
55
use pyo3::{prelude::*, PyTypeInfo};
66

77
use crate::build_tools::{is_strict, schema_or_config_same};
@@ -28,6 +28,18 @@ pub fn get_decimal_type(py: Python) -> &Bound<'_, PyType> {
2828
.bind(py)
2929
}
3030

31+
fn validate_as_decimal(py: Python, schema: &Bound<'_, PyDict>, key: &str) -> PyResult<Option<Py<PyAny>>> {
32+
match schema.get_as::<Bound<'_, PyAny>>(&PyString::new(py, key))? {
33+
Some(value) => match value.validate_decimal(false, py) {
34+
Ok(v) => Ok(Some(v.into_inner().unbind())),
35+
Err(_) => Err(PyValueError::new_err(format!(
36+
"'{key}' must be coercible to a Decimal instance",
37+
))),
38+
},
39+
None => Ok(None),
40+
}
41+
}
42+
3143
#[derive(Debug, Clone)]
3244
pub struct DecimalValidator {
3345
strict: bool,
@@ -50,6 +62,7 @@ impl BuildValidator for DecimalValidator {
5062
_definitions: &mut DefinitionsBuilder<CombinedValidator>,
5163
) -> PyResult<CombinedValidator> {
5264
let py = schema.py();
65+
5366
let allow_inf_nan = schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(false);
5467
let decimal_places = schema.get_as(intern!(py, "decimal_places"))?;
5568
let max_digits = schema.get_as(intern!(py, "max_digits"))?;
@@ -58,16 +71,17 @@ impl BuildValidator for DecimalValidator {
5871
"allow_inf_nan=True cannot be used with max_digits or decimal_places",
5972
));
6073
}
74+
6175
Ok(Self {
6276
strict: is_strict(schema, config)?,
6377
allow_inf_nan,
6478
check_digits: decimal_places.is_some() || max_digits.is_some(),
6579
decimal_places,
66-
multiple_of: schema.get_as(intern!(py, "multiple_of"))?,
67-
le: schema.get_as(intern!(py, "le"))?,
68-
lt: schema.get_as(intern!(py, "lt"))?,
69-
ge: schema.get_as(intern!(py, "ge"))?,
70-
gt: schema.get_as(intern!(py, "gt"))?,
80+
multiple_of: validate_as_decimal(py, schema, "multiple_of")?,
81+
le: validate_as_decimal(py, schema, "le")?,
82+
lt: validate_as_decimal(py, schema, "lt")?,
83+
ge: validate_as_decimal(py, schema, "ge")?,
84+
gt: validate_as_decimal(py, schema, "gt")?,
7185
max_digits,
7286
}
7387
.into())

src/validators/int.rs

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use num_bigint::BigInt;
2+
use pyo3::exceptions::PyValueError;
23
use pyo3::intern;
34
use pyo3::prelude::*;
4-
use pyo3::types::PyDict;
5+
use pyo3::types::{PyDict, PyString};
56
use pyo3::IntoPyObjectExt;
67

78
use crate::build_tools::is_strict;
@@ -11,6 +12,23 @@ use crate::tools::SchemaDict;
1112

1213
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
1314

15+
fn validate_as_int(py: Python, schema: &Bound<'_, PyDict>, key: &str) -> PyResult<Option<Int>> {
16+
match schema.get_as::<Bound<'_, PyAny>>(&PyString::new(py, key))? {
17+
Some(value) => match value.validate_int(false) {
18+
Ok(v) => match v.into_inner().as_int() {
19+
Ok(v) => Ok(Some(v)),
20+
Err(_) => Err(PyValueError::new_err(format!(
21+
"'{key}' must be coercible to an integer"
22+
))),
23+
},
24+
Err(_) => Err(PyValueError::new_err(format!(
25+
"'{key}' must be coercible to an integer"
26+
))),
27+
},
28+
None => Ok(None),
29+
}
30+
}
31+
1432
#[derive(Debug, Clone)]
1533
pub struct IntValidator {
1634
strict: bool,
@@ -70,6 +88,21 @@ pub struct ConstrainedIntValidator {
7088
gt: Option<Int>,
7189
}
7290

91+
impl ConstrainedIntValidator {
92+
fn build(schema: &Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) -> PyResult<CombinedValidator> {
93+
let py = schema.py();
94+
Ok(Self {
95+
strict: is_strict(schema, config)?,
96+
multiple_of: validate_as_int(py, schema, "multiple_of")?,
97+
le: validate_as_int(py, schema, "le")?,
98+
lt: validate_as_int(py, schema, "lt")?,
99+
ge: validate_as_int(py, schema, "ge")?,
100+
gt: validate_as_int(py, schema, "gt")?,
101+
}
102+
.into())
103+
}
104+
}
105+
73106
impl_py_gc_traverse!(ConstrainedIntValidator {});
74107

75108
impl Validator for ConstrainedIntValidator {
@@ -144,18 +177,3 @@ impl Validator for ConstrainedIntValidator {
144177
"constrained-int"
145178
}
146179
}
147-
148-
impl ConstrainedIntValidator {
149-
fn build(schema: &Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) -> PyResult<CombinedValidator> {
150-
let py = schema.py();
151-
Ok(Self {
152-
strict: is_strict(schema, config)?,
153-
multiple_of: schema.get_as(intern!(py, "multiple_of"))?,
154-
le: schema.get_as(intern!(py, "le"))?,
155-
lt: schema.get_as(intern!(py, "lt"))?,
156-
ge: schema.get_as(intern!(py, "ge"))?,
157-
gt: schema.get_as(intern!(py, "gt"))?,
158-
}
159-
.into())
160-
}
161-
}

tests/validators/test_date.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
from ..conftest import Err, PyAndJson
1414

1515

16+
@pytest.mark.parametrize(
17+
'constraint',
18+
['le', 'lt', 'ge', 'gt'],
19+
)
20+
def test_constraints_schema_validation(constraint: str) -> None:
21+
with pytest.raises(SchemaError, match=f"'{constraint}' must be coercible to a date instance"):
22+
SchemaValidator(cs.date_schema(**{constraint: 'bad_value'}))
23+
24+
1625
@pytest.mark.parametrize(
1726
'input_value,expected',
1827
[

tests/validators/test_datetime.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
from ..conftest import Err, PyAndJson
1515

1616

17+
@pytest.mark.parametrize(
18+
'constraint',
19+
['le', 'lt', 'ge', 'gt'],
20+
)
21+
def test_constraints_schema_validation(constraint: str) -> None:
22+
with pytest.raises(SchemaError, match=f"'{constraint}' must be coercible to a datetime instance"):
23+
SchemaValidator(cs.datetime_schema(**{constraint: 'bad_value'}))
24+
25+
1726
@pytest.mark.parametrize(
1827
'input_value,expected',
1928
[

tests/validators/test_decimal.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from dirty_equals import FunctionCheck, IsStr
1111

12-
from pydantic_core import SchemaValidator, ValidationError, core_schema
12+
from pydantic_core import SchemaError, SchemaValidator, ValidationError
1313
from pydantic_core import core_schema as cs
1414

1515
from ..conftest import Err, PyAndJson, plain_repr
@@ -19,6 +19,17 @@ class DecimalSubclass(Decimal):
1919
pass
2020

2121

22+
# Note: there's another constraint validation (allow_inf_nan=True cannot be used with max_digits or decimal_places).
23+
# but it is tested in Pydantic:
24+
@pytest.mark.parametrize(
25+
'constraint',
26+
['multiple_of', 'le', 'lt', 'ge', 'gt'],
27+
)
28+
def test_constraints_schema_validation(constraint: str) -> None:
29+
with pytest.raises(SchemaError, match=f"'{constraint}' must be coercible to a Decimal instance"):
30+
SchemaValidator(cs.decimal_schema(**{constraint: 'bad_value'}))
31+
32+
2233
@pytest.mark.parametrize(
2334
'input_value,expected',
2435
[
@@ -487,20 +498,20 @@ def test_validate_max_digits_and_decimal_places_edge_case() -> None:
487498

488499

489500
def test_str_validation_w_strict() -> None:
490-
s = SchemaValidator(core_schema.decimal_schema(strict=True))
501+
s = SchemaValidator(cs.decimal_schema(strict=True))
491502

492503
with pytest.raises(ValidationError):
493504
assert s.validate_python('1.23')
494505

495506

496507
def test_str_validation_w_lax() -> None:
497-
s = SchemaValidator(core_schema.decimal_schema(strict=False))
508+
s = SchemaValidator(cs.decimal_schema(strict=False))
498509

499510
assert s.validate_python('1.23') == Decimal('1.23')
500511

501512

502513
def test_union_with_str_prefers_str() -> None:
503-
s = SchemaValidator(core_schema.union_schema([core_schema.decimal_schema(), core_schema.str_schema()]))
514+
s = SchemaValidator(cs.union_schema([cs.decimal_schema(), cs.str_schema()]))
504515

505516
assert s.validate_python('1.23') == '1.23'
506517
assert s.validate_python(1.23) == Decimal('1.23')

tests/validators/test_int.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,23 @@
66
import pytest
77
from dirty_equals import IsStr
88

9-
from pydantic_core import SchemaValidator, ValidationError, core_schema
9+
from pydantic_core import SchemaError, SchemaValidator, ValidationError
1010
from pydantic_core import core_schema as cs
1111

1212
from ..conftest import Err, PyAndJson, plain_repr
1313

1414
i64_max = 9_223_372_036_854_775_807
1515

1616

17+
@pytest.mark.parametrize(
18+
'constraint',
19+
['multiple_of', 'le', 'lt', 'ge', 'gt'],
20+
)
21+
def test_constraints_schema_validation(constraint: str) -> None:
22+
with pytest.raises(SchemaError, match=f"'{constraint}' must be coercible to an integer"):
23+
SchemaValidator(cs.int_schema(**{constraint: 'bad_value'}))
24+
25+
1726
@pytest.mark.parametrize(
1827
'input_value,expected',
1928
[
@@ -532,7 +541,7 @@ class PlainEnum(Enum):
532541

533542

534543
def test_allow_inf_nan_true_json() -> None:
535-
v = SchemaValidator(core_schema.int_schema(), config=core_schema.CoreConfig(allow_inf_nan=True))
544+
v = SchemaValidator(cs.int_schema(), config=cs.CoreConfig(allow_inf_nan=True))
536545

537546
assert v.validate_json('123') == 123
538547
with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'):
@@ -544,7 +553,7 @@ def test_allow_inf_nan_true_json() -> None:
544553

545554

546555
def test_allow_inf_nan_false_json() -> None:
547-
v = SchemaValidator(core_schema.int_schema(), config=core_schema.CoreConfig(allow_inf_nan=False))
556+
v = SchemaValidator(cs.int_schema(), config=cs.CoreConfig(allow_inf_nan=False))
548557

549558
assert v.validate_json('123') == 123
550559
with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'):

0 commit comments

Comments
 (0)