Skip to content

Commit cfd5da7

Browse files
changes to some coercions (#208)
* stop coercing `set / frozenset` to `list / tuple` * add `dict_key` and `dict_value` * add iterator support * more tests * use PyList directly for lax_list * catch errors in generator evaluation * fix mypy and more tests * remove unused dict_items code * make format Co-authored-by: Samuel Colvin <s@muelcolvin.com>
1 parent 42a465a commit cfd5da7

File tree

8 files changed

+332
-43
lines changed

8 files changed

+332
-43
lines changed

src/errors/kinds.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ pub enum ErrorKind {
119119
error: String,
120120
},
121121
// ---------------------
122+
// generic list-list errors
123+
#[strum(message = "Error iterating over object")]
124+
IterationError,
125+
// ---------------------
122126
// list errors
123127
#[strum(message = "Input should be a valid list/array")]
124128
ListType,

src/input/_pyo3_dict.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// TODO: remove this file once a new pyo3 version is released
2+
// with https://github.com/PyO3/pyo3/pull/2358
3+
4+
use pyo3::{ffi, pyobject_native_type_core, PyAny};
5+
6+
/// Represents a Python `dict_keys`.
7+
#[cfg(not(PyPy))]
8+
#[repr(transparent)]
9+
pub struct PyDictKeys(PyAny);
10+
11+
#[cfg(not(PyPy))]
12+
pyobject_native_type_core!(
13+
PyDictKeys,
14+
ffi::PyDictKeys_Type,
15+
#checkfunction=ffi::PyDictKeys_Check
16+
);
17+
18+
/// Represents a Python `dict_values`.
19+
#[cfg(not(PyPy))]
20+
#[repr(transparent)]
21+
pub struct PyDictValues(PyAny);
22+
23+
#[cfg(not(PyPy))]
24+
pyobject_native_type_core!(
25+
PyDictValues,
26+
ffi::PyDictValues_Type,
27+
#checkfunction=ffi::PyDictValues_Check
28+
);

src/input/input_python.rs

Lines changed: 107 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ use std::str::from_utf8;
44
use pyo3::exceptions::PyAttributeError;
55
use pyo3::prelude::*;
66
use pyo3::types::{
7-
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyList, PyMapping, PySequence,
8-
PySet, PyString, PyTime, PyTuple, PyType,
7+
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping,
8+
PySequence, PySet, PyString, PyTime, PyTuple, PyType,
99
};
1010
use pyo3::{intern, AsPyPointer};
1111

1212
use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};
1313

14+
#[cfg(not(PyPy))]
15+
use super::_pyo3_dict::{PyDictKeys, PyDictValues};
1416
use super::datetime::{
1517
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
1618
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
@@ -22,6 +24,25 @@ use super::{
2224
GenericMapping, Input, PyArgs,
2325
};
2426

27+
#[cfg(not(PyPy))]
28+
macro_rules! extract_gen_dict {
29+
($type:ty, $obj:ident) => {{
30+
let map_err = |_| ValError::new(ErrorKind::IterationError, $obj);
31+
if let Ok(iterator) = $obj.cast_as::<PyIterator>() {
32+
let vec = iterator.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
33+
Some(<$type>::new($obj.py(), vec))
34+
} else if let Ok(dict_keys) = $obj.cast_as::<PyDictKeys>() {
35+
let vec = dict_keys.iter()?.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
36+
Some(<$type>::new($obj.py(), vec))
37+
} else if let Ok(dict_values) = $obj.cast_as::<PyDictValues>() {
38+
let vec = dict_values.iter()?.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
39+
Some(<$type>::new($obj.py(), vec))
40+
} else {
41+
None
42+
}
43+
}};
44+
}
45+
2546
impl<'a> Input<'a> for PyAny {
2647
fn as_loc_item(&self) -> LocItem {
2748
if let Ok(py_str) = self.cast_as::<PyString>() {
@@ -261,15 +282,30 @@ impl<'a> Input<'a> for PyAny {
261282
}
262283
}
263284

285+
#[cfg(not(PyPy))]
264286
fn lax_list(&'a self) -> ValResult<GenericListLike<'a>> {
265287
if let Ok(list) = self.cast_as::<PyList>() {
266288
Ok(list.into())
267289
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
268290
Ok(tuple.into())
269-
} else if let Ok(set) = self.cast_as::<PySet>() {
270-
Ok(set.into())
271-
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
272-
Ok(frozen_set.into())
291+
} else if let Some(list) = extract_gen_dict!(PyList, self) {
292+
Ok(list.into())
293+
} else {
294+
Err(ValError::new(ErrorKind::ListType, self))
295+
}
296+
}
297+
298+
#[cfg(PyPy)]
299+
fn lax_list(&'a self) -> ValResult<GenericListLike<'a>> {
300+
if let Ok(list) = self.cast_as::<PyList>() {
301+
Ok(list.into())
302+
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
303+
Ok(tuple.into())
304+
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
305+
let vec = iterator
306+
.collect::<PyResult<Vec<_>>>()
307+
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
308+
Ok(PyList::new(self.py(), vec).into())
273309
} else {
274310
Err(ValError::new(ErrorKind::ListType, self))
275311
}
@@ -283,15 +319,30 @@ impl<'a> Input<'a> for PyAny {
283319
}
284320
}
285321

322+
#[cfg(not(PyPy))]
286323
fn lax_tuple(&'a self) -> ValResult<GenericListLike<'a>> {
287324
if let Ok(tuple) = self.cast_as::<PyTuple>() {
288325
Ok(tuple.into())
289326
} else if let Ok(list) = self.cast_as::<PyList>() {
290327
Ok(list.into())
291-
} else if let Ok(set) = self.cast_as::<PySet>() {
292-
Ok(set.into())
293-
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
294-
Ok(frozen_set.into())
328+
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
329+
Ok(tuple.into())
330+
} else {
331+
Err(ValError::new(ErrorKind::TupleType, self))
332+
}
333+
}
334+
335+
#[cfg(PyPy)]
336+
fn lax_tuple(&'a self) -> ValResult<GenericListLike<'a>> {
337+
if let Ok(tuple) = self.cast_as::<PyTuple>() {
338+
Ok(tuple.into())
339+
} else if let Ok(list) = self.cast_as::<PyList>() {
340+
Ok(list.into())
341+
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
342+
let vec = iterator
343+
.collect::<PyResult<Vec<_>>>()
344+
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
345+
Ok(PyTuple::new(self.py(), vec).into())
295346
} else {
296347
Err(ValError::new(ErrorKind::TupleType, self))
297348
}
@@ -305,6 +356,24 @@ impl<'a> Input<'a> for PyAny {
305356
}
306357
}
307358

359+
#[cfg(not(PyPy))]
360+
fn lax_set(&'a self) -> ValResult<GenericListLike<'a>> {
361+
if let Ok(set) = self.cast_as::<PySet>() {
362+
Ok(set.into())
363+
} else if let Ok(list) = self.cast_as::<PyList>() {
364+
Ok(list.into())
365+
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
366+
Ok(tuple.into())
367+
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
368+
Ok(frozen_set.into())
369+
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
370+
Ok(tuple.into())
371+
} else {
372+
Err(ValError::new(ErrorKind::SetType, self))
373+
}
374+
}
375+
376+
#[cfg(PyPy)]
308377
fn lax_set(&'a self) -> ValResult<GenericListLike<'a>> {
309378
if let Ok(set) = self.cast_as::<PySet>() {
310379
Ok(set.into())
@@ -314,6 +383,11 @@ impl<'a> Input<'a> for PyAny {
314383
Ok(tuple.into())
315384
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
316385
Ok(frozen_set.into())
386+
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
387+
let vec = iterator
388+
.collect::<PyResult<Vec<_>>>()
389+
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
390+
Ok(PyTuple::new(self.py(), vec).into())
317391
} else {
318392
Err(ValError::new(ErrorKind::SetType, self))
319393
}
@@ -327,6 +401,24 @@ impl<'a> Input<'a> for PyAny {
327401
}
328402
}
329403

404+
#[cfg(not(PyPy))]
405+
fn lax_frozenset(&'a self) -> ValResult<GenericListLike<'a>> {
406+
if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
407+
Ok(frozen_set.into())
408+
} else if let Ok(set) = self.cast_as::<PySet>() {
409+
Ok(set.into())
410+
} else if let Ok(list) = self.cast_as::<PyList>() {
411+
Ok(list.into())
412+
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
413+
Ok(tuple.into())
414+
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
415+
Ok(tuple.into())
416+
} else {
417+
Err(ValError::new(ErrorKind::FrozenSetType, self))
418+
}
419+
}
420+
421+
#[cfg(PyPy)]
330422
fn lax_frozenset(&'a self) -> ValResult<GenericListLike<'a>> {
331423
if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
332424
Ok(frozen_set.into())
@@ -336,6 +428,11 @@ impl<'a> Input<'a> for PyAny {
336428
Ok(list.into())
337429
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
338430
Ok(tuple.into())
431+
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
432+
let vec = iterator
433+
.collect::<PyResult<Vec<_>>>()
434+
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
435+
Ok(PyTuple::new(self.py(), vec).into())
339436
} else {
340437
Err(ValError::new(ErrorKind::FrozenSetType, self))
341438
}

src/input/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use pyo3::prelude::*;
22

3+
#[cfg(not(PyPy))]
4+
mod _pyo3_dict;
35
mod datetime;
46
mod input_abstract;
57
mod input_json;

tests/validators/test_frozenset.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import platform
12
import re
23
from typing import Any, Dict
34

@@ -63,19 +64,35 @@ def test_frozenset_no_validators_both(py_and_json: PyAndJson, input_value, expec
6364
@pytest.mark.parametrize(
6465
'input_value,expected',
6566
[
66-
({1, 2, 3}, {1, 2, 3}),
67+
({1, 2, 3}, frozenset({1, 2, 3})),
6768
(frozenset(), frozenset()),
68-
([1, 2, 3, 2, 3], {1, 2, 3}),
69+
([1, 2, 3, 2, 3], frozenset({1, 2, 3})),
6970
([], frozenset()),
70-
((1, 2, 3, 2, 3), {1, 2, 3}),
71+
((1, 2, 3, 2, 3), frozenset({1, 2, 3})),
7172
((), frozenset()),
72-
(frozenset([1, 2, 3, 2, 3]), {1, 2, 3}),
73+
(frozenset([1, 2, 3, 2, 3]), frozenset({1, 2, 3})),
74+
pytest.param(
75+
{1: 10, 2: 20, '3': '30'}.keys(),
76+
frozenset({1, 2, 3}),
77+
marks=pytest.mark.skipif(
78+
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
79+
),
80+
),
81+
pytest.param(
82+
{1: 10, 2: 20, '3': '30'}.values(),
83+
frozenset({10, 20, 30}),
84+
marks=pytest.mark.skipif(
85+
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
86+
),
87+
),
88+
({1: 10, 2: 20, '3': '30'}, Err('Input should be a valid frozenset [kind=frozen_set_type,')),
89+
# https://github.com/samuelcolvin/pydantic-core/issues/211
90+
({1: 10, 2: 20, '3': '30'}.items(), Err('Input should be a valid frozenset [kind=frozen_set_type,')),
91+
((x for x in [1, 2, '3']), frozenset({1, 2, 3})),
7392
({'abc'}, Err('0\n Input should be a valid integer')),
7493
({1, 2, 'wrong'}, Err('Input should be a valid integer')),
7594
({1: 2}, Err('1 validation error for frozenset[int]\n Input should be a valid frozenset')),
7695
('abc', Err('Input should be a valid frozenset')),
77-
# Technically correct, but does anyone actually need this? I think needs a new type in pyo3
78-
pytest.param({1: 10, 2: 20, 3: 30}.keys(), {1, 2, 3}, marks=pytest.mark.xfail(raises=ValidationError)),
7996
],
8097
)
8198
def test_frozenset_ints_python(input_value, expected):
@@ -89,7 +106,10 @@ def test_frozenset_ints_python(input_value, expected):
89106
assert isinstance(output, frozenset)
90107

91108

92-
@pytest.mark.parametrize('input_value,expected', [([1, 2.5, '3'], {1, 2.5, '3'}), ([(1, 2), (3, 4)], {(1, 2), (3, 4)})])
109+
@pytest.mark.parametrize(
110+
'input_value,expected',
111+
[(frozenset([1, 2.5, '3']), {1, 2.5, '3'}), ([1, 2.5, '3'], {1, 2.5, '3'}), ([(1, 2), (3, 4)], {(1, 2), (3, 4)})],
112+
)
93113
def test_frozenset_no_validators_python(input_value, expected):
94114
v = SchemaValidator({'type': 'frozenset'})
95115
output = v.validate_python(input_value)
@@ -216,3 +236,20 @@ def test_repr():
216236
'strict:true,item_validator:None,size_range:Some((Some(42),None)),name:"frozenset[any]"'
217237
'}))'
218238
)
239+
240+
241+
def test_generator_error():
242+
def gen(error: bool):
243+
yield 1
244+
yield 2
245+
if error:
246+
raise RuntimeError('error')
247+
yield 3
248+
249+
v = SchemaValidator({'type': 'frozenset', 'items_schema': 'int'})
250+
r = v.validate_python(gen(False))
251+
assert r == {1, 2, 3}
252+
assert isinstance(r, frozenset)
253+
254+
with pytest.raises(ValidationError, match=r'Error iterating over object \[kind=iteration_error,'):
255+
v.validate_python(gen(True))

0 commit comments

Comments
 (0)