Skip to content

Commit 46379ac

Browse files
authored
tidy up tagged_union_schema (#1333)
1 parent 9507a28 commit 46379ac

File tree

3 files changed

+16
-19
lines changed

3 files changed

+16
-19
lines changed

python/pydantic_core/core_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,8 +2466,8 @@ class TaggedUnionSchema(TypedDict, total=False):
24662466

24672467

24682468
def tagged_union_schema(
2469-
choices: Dict[Hashable, CoreSchema],
2470-
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Hashable],
2469+
choices: Dict[Any, CoreSchema],
2470+
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Any],
24712471
*,
24722472
custom_error_type: str | None = None,
24732473
custom_error_message: str | None = None,

src/validators/literal.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Validator for things inside of a typing.Literal[]
22
// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
33
use core::fmt::Debug;
4-
use std::cmp::Ordering;
54

65
use pyo3::prelude::*;
76
use pyo3::types::{PyDict, PyInt, PyList};
@@ -35,7 +34,7 @@ pub struct LiteralLookup<T: Debug> {
3534
// Catch all for hashable types like Enum and bytes (the latter only because it is seldom used)
3635
expected_py_dict: Option<Py<PyDict>>,
3736
// Catch all for unhashable types like list
38-
expected_py_list: Option<Py<PyList>>,
37+
expected_py_values: Option<Vec<(Py<PyAny>, usize)>>,
3938

4039
pub values: Vec<T>,
4140
}
@@ -46,7 +45,7 @@ impl<T: Debug> LiteralLookup<T> {
4645
let mut expected_int = AHashMap::new();
4746
let mut expected_str: AHashMap<String, usize> = AHashMap::new();
4847
let expected_py_dict = PyDict::new_bound(py);
49-
let expected_py_list = PyList::empty_bound(py);
48+
let mut expected_py_values = Vec::new();
5049
let mut values = Vec::new();
5150
for (k, v) in expected {
5251
let id = values.len();
@@ -71,7 +70,7 @@ impl<T: Debug> LiteralLookup<T> {
7170
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
7271
expected_str.insert(str.to_string(), id);
7372
} else if expected_py_dict.set_item(&k, id).is_err() {
74-
expected_py_list.append((&k, id))?;
73+
expected_py_values.push((k.as_unbound().clone_ref(py), id));
7574
}
7675
}
7776

@@ -92,9 +91,9 @@ impl<T: Debug> LiteralLookup<T> {
9291
true => None,
9392
false => Some(expected_py_dict.into()),
9493
},
95-
expected_py_list: match expected_py_list.is_empty() {
94+
expected_py_values: match expected_py_values.is_empty() {
9695
true => None,
97-
false => Some(expected_py_list.into()),
96+
false => Some(expected_py_values),
9897
},
9998
values,
10099
})
@@ -143,23 +142,23 @@ impl<T: Debug> LiteralLookup<T> {
143142
}
144143
}
145144
}
145+
// cache py_input if needed, since we might need it for multiple lookups
146+
let mut py_input = None;
146147
if let Some(expected_py_dict) = &self.expected_py_dict {
148+
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
147149
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
148150
// inputs will produce a TypeError, which in this case we just want to treat equivalently
149151
// to a failed lookup
150-
if let Ok(Some(v)) = expected_py_dict.bind(py).get_item(input) {
152+
if let Ok(Some(v)) = expected_py_dict.bind(py).get_item(&*py_input) {
151153
let id: usize = v.extract().unwrap();
152154
return Ok(Some((input, &self.values[id])));
153155
}
154156
};
155-
if let Some(expected_py_list) = &self.expected_py_list {
156-
for item in expected_py_list.bind(py) {
157-
let (k, id): (Bound<PyAny>, usize) = item.extract()?;
158-
if k.compare(input.to_object(py).bind(py))
159-
.unwrap_or(Ordering::Less)
160-
.is_eq()
161-
{
162-
return Ok(Some((input, &self.values[id])));
157+
if let Some(expected_py_values) = &self.expected_py_values {
158+
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
159+
for (k, id) in expected_py_values {
160+
if k.bind(py).eq(&*py_input).unwrap_or(false) {
161+
return Ok(Some((input, &self.values[*id])));
163162
}
164163
}
165164
};

src/validators/union.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,9 @@ impl BuildValidator for TaggedUnionValidator {
344344
let mut tags_repr = String::with_capacity(50);
345345
let mut descr = String::with_capacity(50);
346346
let mut first = true;
347-
let mut discriminators = Vec::with_capacity(choices.len());
348347
let schema_choices: Bound<PyDict> = schema.get_as_req(intern!(py, "choices"))?;
349348
let mut lookup_map = Vec::with_capacity(choices.len());
350349
for (choice_key, choice_schema) in schema_choices {
351-
discriminators.push(choice_key.clone());
352350
let validator = build_validator(&choice_schema, config, definitions)?;
353351
let tag_repr = choice_key.repr()?.to_string();
354352
if first {

0 commit comments

Comments
 (0)