From 6aac326ac1540c1ae2546ab06e115368d8427297 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Fri, 25 Oct 2024 08:09:36 +0100 Subject: [PATCH 1/7] fix enum --- src/validators/enum_.rs | 14 ++++++++++++-- tests/validators/test_enums.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 73589d806..78fdd0924 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -1,10 +1,11 @@ +use std::collections::HashSet; // Validator for Enums, so named because "enum" is a reserved keyword in Rust. use std::marker::PhantomData; use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType}; +use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyTuple, PyType}; use crate::build_tools::{is_strict, py_schema_err}; use crate::errors::{ErrorType, ValError, ValResult}; @@ -33,7 +34,7 @@ impl BuildValidator for BuildEnumValidator { let py = schema.py(); let value_str = intern!(py, "value"); - let expected: Vec<(Bound<'_, PyAny>, PyObject)> = members + let mut expected: Vec<(Bound<'_, PyAny>, PyObject)> = members .iter() .map(|v| Ok((v.getattr(value_str)?, v.into()))) .collect::>()?; @@ -43,6 +44,15 @@ impl BuildValidator for BuildEnumValidator { .map(|(k, _)| k.repr()?.extract()) .collect::>()?; + let mut addition = vec![]; + for (k, v) in &expected { + if let Ok(ss) = k.downcast::() { + let list = ss.to_list(); + addition.push((list.into_any(), v.clone())); + } + } + expected.append(&mut addition); + let class: Bound = schema.get_as_req(intern!(py, "cls"))?; let class_repr = class_repr(schema, &class)?; diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index 83e286417..faba0b1dd 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -262,6 +262,22 @@ class MyEnum(Enum): assert v.validate_python([2]) is MyEnum.b +def test_plain_enum_tuple(): + from pydantic import RootModel + + class MyEnum(Enum): + a = 1, 2 + b = 2, 3 + + assert MyEnum((1, 2)) is MyEnum.a + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python((1, 2)) is MyEnum.a + assert v.validate_python((2, 3)) is MyEnum.b + serialised = RootModel[MyEnum](MyEnum.a).model_dump_json() + parsed = RootModel[MyEnum].model_validate_json(serialised) + assert parsed.root is MyEnum.a + + def test_plain_enum_empty(): class MyEnum(Enum): pass From 578c89367e74e0f7c08875699213c677d1bb7581 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Fri, 25 Oct 2024 08:21:11 +0100 Subject: [PATCH 2/7] fix test --- src/validators/enum_.rs | 1 - tests/validators/test_enums.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 78fdd0924..9bbe9aee8 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -1,4 +1,3 @@ -use std::collections::HashSet; // Validator for Enums, so named because "enum" is a reserved keyword in Rust. use std::marker::PhantomData; diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index faba0b1dd..41cd28ea2 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -263,8 +263,6 @@ class MyEnum(Enum): def test_plain_enum_tuple(): - from pydantic import RootModel - class MyEnum(Enum): a = 1, 2 b = 2, 3 @@ -273,9 +271,7 @@ class MyEnum(Enum): v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) assert v.validate_python((1, 2)) is MyEnum.a assert v.validate_python((2, 3)) is MyEnum.b - serialised = RootModel[MyEnum](MyEnum.a).model_dump_json() - parsed = RootModel[MyEnum].model_validate_json(serialised) - assert parsed.root is MyEnum.a + assert v.validate_json('[1, 2]') is MyEnum.a def test_plain_enum_empty(): From dd7f7ca9e36b3c32f768024e0c76a587da6556b7 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Fri, 25 Oct 2024 08:09:36 +0100 Subject: [PATCH 3/7] fix enum --- src/validators/enum_.rs | 14 ++++++++++++-- tests/validators/test_enums.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 73589d806..78fdd0924 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -1,10 +1,11 @@ +use std::collections::HashSet; // Validator for Enums, so named because "enum" is a reserved keyword in Rust. use std::marker::PhantomData; use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType}; +use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyTuple, PyType}; use crate::build_tools::{is_strict, py_schema_err}; use crate::errors::{ErrorType, ValError, ValResult}; @@ -33,7 +34,7 @@ impl BuildValidator for BuildEnumValidator { let py = schema.py(); let value_str = intern!(py, "value"); - let expected: Vec<(Bound<'_, PyAny>, PyObject)> = members + let mut expected: Vec<(Bound<'_, PyAny>, PyObject)> = members .iter() .map(|v| Ok((v.getattr(value_str)?, v.into()))) .collect::>()?; @@ -43,6 +44,15 @@ impl BuildValidator for BuildEnumValidator { .map(|(k, _)| k.repr()?.extract()) .collect::>()?; + let mut addition = vec![]; + for (k, v) in &expected { + if let Ok(ss) = k.downcast::() { + let list = ss.to_list(); + addition.push((list.into_any(), v.clone())); + } + } + expected.append(&mut addition); + let class: Bound = schema.get_as_req(intern!(py, "cls"))?; let class_repr = class_repr(schema, &class)?; diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index 83e286417..faba0b1dd 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -262,6 +262,22 @@ class MyEnum(Enum): assert v.validate_python([2]) is MyEnum.b +def test_plain_enum_tuple(): + from pydantic import RootModel + + class MyEnum(Enum): + a = 1, 2 + b = 2, 3 + + assert MyEnum((1, 2)) is MyEnum.a + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python((1, 2)) is MyEnum.a + assert v.validate_python((2, 3)) is MyEnum.b + serialised = RootModel[MyEnum](MyEnum.a).model_dump_json() + parsed = RootModel[MyEnum].model_validate_json(serialised) + assert parsed.root is MyEnum.a + + def test_plain_enum_empty(): class MyEnum(Enum): pass From 85e8e09053f33e7136fb69b88451482c12e11959 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Fri, 25 Oct 2024 08:21:11 +0100 Subject: [PATCH 4/7] fix test --- src/validators/enum_.rs | 1 - tests/validators/test_enums.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 78fdd0924..9bbe9aee8 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -1,4 +1,3 @@ -use std::collections::HashSet; // Validator for Enums, so named because "enum" is a reserved keyword in Rust. use std::marker::PhantomData; diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index faba0b1dd..41cd28ea2 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -263,8 +263,6 @@ class MyEnum(Enum): def test_plain_enum_tuple(): - from pydantic import RootModel - class MyEnum(Enum): a = 1, 2 b = 2, 3 @@ -273,9 +271,7 @@ class MyEnum(Enum): v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) assert v.validate_python((1, 2)) is MyEnum.a assert v.validate_python((2, 3)) is MyEnum.b - serialised = RootModel[MyEnum](MyEnum.a).model_dump_json() - parsed = RootModel[MyEnum].model_validate_json(serialised) - assert parsed.root is MyEnum.a + assert v.validate_json('[1, 2]') is MyEnum.a def test_plain_enum_empty(): From 48de7c6b4f8678672be150b20fe9252d8b93eda1 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Fri, 8 Nov 2024 08:22:55 +0000 Subject: [PATCH 5/7] validate values against their json form --- src/serializers/config.rs | 11 +++++++- src/serializers/mod.rs | 2 +- src/validators/enum_.rs | 58 ++++++++++++++++++++++++++++----------- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/src/serializers/config.rs b/src/serializers/config.rs index 13a833176..ac9d364d1 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -14,7 +14,7 @@ use crate::tools::SchemaDict; use super::errors::py_err_se_err; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[allow(clippy::struct_field_names)] pub(crate) struct SerializationConfig { pub timedelta_mode: TimedeltaMode, @@ -57,6 +57,15 @@ macro_rules! serialization_mode { $($variant,)* } + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + $(Self::$variant => write!(f, $value),)* + } + + } + } + impl FromStr for $name { type Err = PyErr; diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 1a0405e2c..b209414c1 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -9,7 +9,7 @@ use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; pub(crate) use config::BytesMode; -use config::SerializationConfig; +pub(crate) use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; use extra::{CollectWarnings, SerRecursionState, WarningsMode}; pub(crate) use extra::{DuckTypingSerMode, Extra, SerMode, SerializationState}; diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 9bbe9aee8..21efc3b6b 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -4,15 +4,17 @@ use std::marker::PhantomData; use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyTuple, PyType}; +use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType}; use crate::build_tools::{is_strict, py_schema_err}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; +use crate::serializers::{to_jsonable_python, SerializationConfig}; use crate::tools::{safe_repr, SchemaDict}; use super::is_instance::class_repr; use super::literal::{expected_repr_name, LiteralLookup}; +use super::InputType; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator}; #[derive(Debug, Clone)] @@ -33,36 +35,55 @@ impl BuildValidator for BuildEnumValidator { let py = schema.py(); let value_str = intern!(py, "value"); - let mut expected: Vec<(Bound<'_, PyAny>, PyObject)> = members + let expected_py: Vec<(Bound<'_, PyAny>, PyObject)> = members .iter() .map(|v| Ok((v.getattr(value_str)?, v.into()))) .collect::>()?; + let ser_config = SerializationConfig::from_config(config).unwrap_or_default(); + let expected_json: Vec<(Bound<'_, PyAny>, PyObject)> = members + .iter() + .map(|v| { + Ok(( + to_jsonable_python( + py, + &v.getattr(value_str)?, + None, + None, + false, + false, + false, + &ser_config.timedelta_mode.to_string(), + &ser_config.bytes_mode.to_string(), + &ser_config.inf_nan_mode.to_string(), + false, + None, + true, + None, + )? + .into_bound(py), + v.into(), + )) + }) + .collect::>()?; - let repr_args: Vec = expected + let repr_args: Vec = expected_py .iter() .map(|(k, _)| k.repr()?.extract()) .collect::>()?; - let mut addition = vec![]; - for (k, v) in &expected { - if let Ok(ss) = k.downcast::() { - let list = ss.to_list(); - addition.push((list.into_any(), v.clone())); - } - } - expected.append(&mut addition); - let class: Bound = schema.get_as_req(intern!(py, "cls"))?; let class_repr = class_repr(schema, &class)?; - let lookup = LiteralLookup::new(py, expected.into_iter())?; + let py_lookup = LiteralLookup::new(py, expected_py.into_iter())?; + let json_lookup = LiteralLookup::new(py, expected_json.into_iter())?; macro_rules! build { ($vv:ty, $name_prefix:literal) => { EnumValidator { phantom: PhantomData::<$vv>, class: class.clone().into(), - lookup, + py_lookup, + json_lookup, missing: schema.get_as(intern!(py, "missing"))?, expected_repr: expected_repr_name(repr_args, "").0, strict: is_strict(schema, config)?, @@ -96,7 +117,8 @@ pub trait EnumValidateValue: std::fmt::Debug + Clone + Send + Sync { pub struct EnumValidator { phantom: PhantomData, class: Py, - lookup: LiteralLookup, + py_lookup: LiteralLookup, + json_lookup: LiteralLookup, missing: Option, expected_repr: String, strict: bool, @@ -129,7 +151,11 @@ impl Validator for EnumValidator { state.floor_exactness(Exactness::Lax); - if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? { + let lookup = match state.extra().input_type { + InputType::Json => &self.json_lookup, + _ => &self.py_lookup, + }; + if let Some(v) = T::validate_value(py, input, lookup, strict)? { return Ok(v); } else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) { return Ok(res); From a1abde79ac865e44e48d752ffc83ab6bb7878da4 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 12 Nov 2024 07:43:08 +0000 Subject: [PATCH 6/7] add tests --- tests/validators/test_enums.py | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index 41cd28ea2..9855e4b25 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -1,3 +1,4 @@ +import datetime import re import sys from decimal import Decimal @@ -274,6 +275,45 @@ class MyEnum(Enum): assert v.validate_json('[1, 2]') is MyEnum.a +def test_plain_enum_datetime(): + class MyEnum(Enum): + a = datetime.datetime.fromisoformat('2024-01-01T00:00:00Z') + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python(datetime.datetime.fromisoformat('2024-01-01T00:00:00Z')) is MyEnum.a + assert v.validate_json('"2024-01-01T00:00:00Z"') is MyEnum.a + + +def test_plain_enum_complex(): + class MyEnum(Enum): + a = complex(1, 2) + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python(complex(1, 2)) is MyEnum.a + assert v.validate_json('"1+2j"') is MyEnum.a + + +def test_plain_enum_identical_serialized_form(): + class MyEnum(Enum): + tuple_ = 1, 2 + list_ = [1, 2] + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python((1, 2)) is MyEnum.tuple_ + assert v.validate_python([1, 2]) is MyEnum.list_ + assert v.validate_json('[1,2]') is MyEnum.tuple_ + + # Change the order of `a` and `b` in MyEnum2; validate_json should pick [1, 2] this time + class MyEnum2(Enum): + list_ = [1, 2] + tuple_ = 1, 2 + + v = SchemaValidator(core_schema.enum_schema(MyEnum2, list(MyEnum2.__members__.values()))) + assert v.validate_python((1, 2)) is MyEnum2.tuple_ + assert v.validate_python([1, 2]) is MyEnum2.list_ + assert v.validate_json('[1,2]') is MyEnum2.list_ + + def test_plain_enum_empty(): class MyEnum(Enum): pass From 06ac56d241675245a9135e5e55dbfa2bc6085429 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 12 Nov 2024 07:50:44 +0000 Subject: [PATCH 7/7] fix tests --- tests/validators/test_enums.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index 9855e4b25..fc3a06403 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -277,11 +277,11 @@ class MyEnum(Enum): def test_plain_enum_datetime(): class MyEnum(Enum): - a = datetime.datetime.fromisoformat('2024-01-01T00:00:00Z') + a = datetime.datetime.fromisoformat('2024-01-01T00:00:00') v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) - assert v.validate_python(datetime.datetime.fromisoformat('2024-01-01T00:00:00Z')) is MyEnum.a - assert v.validate_json('"2024-01-01T00:00:00Z"') is MyEnum.a + assert v.validate_python(datetime.datetime.fromisoformat('2024-01-01T00:00:00')) is MyEnum.a + assert v.validate_json('"2024-01-01T00:00:00"') is MyEnum.a def test_plain_enum_complex():