Skip to content

Commit d30d268

Browse files
authored
Consistent use of PyString and to_str (#210)
* Consistent use of PyString and to_str * tests for invalid unicode containing unpaired surrogates
1 parent b33dc3e commit d30d268

File tree

13 files changed

+100
-45
lines changed

13 files changed

+100
-45
lines changed

src/errors/line_error.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ impl<'a> ValError<'a> {
4343
Self::LineErrors(vec![ValLineError::new_with_loc(kind, input, loc)])
4444
}
4545

46+
pub fn new_custom_input(kind: ErrorKind, input_value: InputValue<'a>) -> ValError<'a> {
47+
Self::LineErrors(vec![ValLineError::new_custom_input(kind, input_value)])
48+
}
49+
4650
/// helper function to call with_outer on line items if applicable
4751
pub fn with_outer_location(self, loc_item: LocItem) -> Self {
4852
match self {
@@ -90,6 +94,14 @@ impl<'a> ValLineError<'a> {
9094
}
9195
}
9296

97+
pub fn new_custom_input(kind: ErrorKind, input_value: InputValue<'a>) -> ValLineError<'a> {
98+
Self {
99+
kind,
100+
input_value,
101+
location: Location::default(),
102+
}
103+
}
104+
93105
/// location is stored reversed so it's quicker to add "outer" items as that's what we always do
94106
/// hence `push` here instead of `insert`
95107
pub fn with_outer_location(mut self, loc_item: LocItem) -> Self {

src/errors/value_exception.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ impl PydanticValueError {
4646
if let Some(ref context) = self.context {
4747
for item in context.as_ref(py).items().iter() {
4848
let (key, value): (&PyString, &PyAny) = item.extract()?;
49-
if let Ok(value_str) = value.extract::<&PyString>() {
50-
message = message.replace(&format!("{{{}}}", key.to_str()?), value_str.to_str()?);
49+
if let Ok(py_str) = value.cast_as::<PyString>() {
50+
message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?);
5151
} else if let Ok(value_int) = value.extract::<i64>() {
5252
message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string());
5353
} else {

src/input/input_python.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ use super::datetime::{
1818
};
1919
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_int};
2020
use super::{
21-
repr_string, EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericListLike, GenericMapping, Input,
22-
PyArgs,
21+
py_string_str, repr_string, EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericListLike,
22+
GenericMapping, Input, PyArgs,
2323
};
2424

2525
impl<'a> Input<'a> for PyAny {
2626
fn as_loc_item(&self) -> LocItem {
27-
if let Ok(key_str) = self.extract::<String>() {
28-
key_str.into()
27+
if let Ok(py_str) = self.cast_as::<PyString>() {
28+
py_str.to_string_lossy().as_ref().into()
2929
} else if let Ok(key_int) = self.extract::<usize>() {
3030
key_int.into()
3131
} else {
@@ -140,8 +140,8 @@ impl<'a> Input<'a> for PyAny {
140140
if let Ok(py_bytes) = self.cast_as::<PyBytes>() {
141141
Ok(py_bytes.into())
142142
} else if let Ok(py_str) = self.cast_as::<PyString>() {
143-
let string = py_str.to_string_lossy().to_string();
144-
Ok(string.into_bytes().into())
143+
let str = py_string_str(py_str)?;
144+
Ok(str.as_bytes().into())
145145
} else if let Ok(py_byte_array) = self.cast_as::<PyByteArray>() {
146146
Ok(py_byte_array.to_vec().into())
147147
} else {
@@ -369,7 +369,8 @@ impl<'a> Input<'a> for PyAny {
369369
Err(ValError::new(ErrorKind::DateType, self))
370370
} else if let Ok(date) = self.cast_as::<PyDate>() {
371371
Ok(date.into())
372-
} else if let Ok(str) = self.extract::<String>() {
372+
} else if let Ok(py_str) = self.cast_as::<PyString>() {
373+
let str = py_string_str(py_str)?;
373374
bytes_as_date(self, str.as_bytes())
374375
} else if let Ok(py_bytes) = self.cast_as::<PyBytes>() {
375376
bytes_as_date(self, py_bytes.as_bytes())
@@ -389,7 +390,8 @@ impl<'a> Input<'a> for PyAny {
389390
fn lax_time(&self) -> ValResult<EitherTime> {
390391
if let Ok(time) = self.cast_as::<PyTime>() {
391392
Ok(time.into())
392-
} else if let Ok(str) = self.extract::<String>() {
393+
} else if let Ok(py_str) = self.cast_as::<PyString>() {
394+
let str = py_string_str(py_str)?;
393395
bytes_as_time(self, str.as_bytes())
394396
} else if let Ok(py_bytes) = self.cast_as::<PyBytes>() {
395397
bytes_as_time(self, py_bytes.as_bytes())
@@ -415,7 +417,8 @@ impl<'a> Input<'a> for PyAny {
415417
fn lax_datetime(&self) -> ValResult<EitherDateTime> {
416418
if let Ok(dt) = self.cast_as::<PyDateTime>() {
417419
Ok(dt.into())
418-
} else if let Ok(str) = self.extract::<String>() {
420+
} else if let Ok(py_str) = self.cast_as::<PyString>() {
421+
let str = py_string_str(py_str)?;
419422
bytes_as_datetime(self, str.as_bytes())
420423
} else if let Ok(py_bytes) = self.cast_as::<PyBytes>() {
421424
bytes_as_datetime(self, py_bytes.as_bytes())
@@ -444,7 +447,8 @@ impl<'a> Input<'a> for PyAny {
444447
if let Ok(dt) = self.cast_as::<PyDelta>() {
445448
Ok(dt.into())
446449
} else if let Ok(py_str) = self.cast_as::<PyString>() {
447-
bytes_as_timedelta(self, py_str.to_string_lossy().as_bytes())
450+
let str = py_string_str(py_str)?;
451+
bytes_as_timedelta(self, str.as_bytes())
448452
} else if let Ok(py_bytes) = self.cast_as::<PyBytes>() {
449453
bytes_as_timedelta(self, py_bytes.as_bytes())
450454
} else if let Ok(int) = self.extract::<i64>() {
@@ -518,7 +522,8 @@ fn from_attributes_applicable(obj: &PyAny) -> bool {
518522
/// Utility for extracting a string from a PyAny, if possible.
519523
fn maybe_as_string(v: &PyAny, unicode_error: ErrorKind) -> ValResult<Option<Cow<str>>> {
520524
if let Ok(py_string) = v.cast_as::<PyString>() {
521-
Ok(Some(py_string.to_string_lossy()))
525+
let str = py_string_str(py_string)?;
526+
Ok(Some(Cow::Borrowed(str)))
522527
} else if let Ok(bytes) = v.cast_as::<PyBytes>() {
523528
match from_utf8(bytes.as_bytes()) {
524529
Ok(s) => Ok(Some(Cow::Owned(s.to_string()))),

src/input/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub use datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1212
pub use input_abstract::Input;
1313
pub use parse_json::{JsonInput, JsonObject};
1414
pub use return_enums::{
15-
EitherBytes, EitherString, GenericArguments, GenericListLike, GenericMapping, JsonArgs, PyArgs,
15+
py_string_str, EitherBytes, EitherString, GenericArguments, GenericListLike, GenericMapping, JsonArgs, PyArgs,
1616
};
1717

1818
pub fn repr_string(v: &PyAny) -> PyResult<String> {

src/input/return_enums.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::borrow::Cow;
33
use pyo3::prelude::*;
44
use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple};
55

6-
use crate::errors::{ErrorKind, ValError, ValLineError, ValResult};
6+
use crate::errors::{ErrorKind, InputValue, ValError, ValLineError, ValResult};
77
use crate::recursion_guard::RecursionGuard;
88
use crate::validators::{CombinedValidator, Extra, Validator};
99

@@ -212,10 +212,10 @@ pub enum EitherString<'a> {
212212
}
213213

214214
impl<'a> EitherString<'a> {
215-
pub fn as_cow(&self) -> Cow<str> {
215+
pub fn as_cow(&self) -> ValResult<'a, Cow<str>> {
216216
match self {
217-
Self::Cow(data) => data.clone(),
218-
Self::Py(py_str) => py_str.to_string_lossy(),
217+
Self::Cow(data) => Ok(data.clone()),
218+
Self::Py(py_str) => Ok(Cow::Borrowed(py_string_str(py_str)?)),
219219
}
220220
}
221221

@@ -251,6 +251,12 @@ impl<'a> IntoPy<PyObject> for EitherString<'a> {
251251
}
252252
}
253253

254+
pub fn py_string_str(py_str: &PyString) -> ValResult<&str> {
255+
py_str.to_str().map_err(|_| {
256+
ValError::new_custom_input(ErrorKind::StrUnicode, InputValue::PyObject(py_str.into_py(py_str.py())))
257+
})
258+
}
259+
254260
#[cfg_attr(debug_assertions, derive(Debug))]
255261
pub enum EitherBytes<'a> {
256262
Cow(Cow<'a, [u8]>),

src/lookup_key.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,9 @@ impl LookupKey {
6464
};
6565
let mut locs: Vec<Path> = if first.cast_as::<PyString>().is_ok() {
6666
// list of strings rather than list of lists
67-
vec![Self::path_choice(py, list)?]
67+
vec![Self::path_choice(list)?]
6868
} else {
69-
list.iter()
70-
.map(|obj| Self::path_choice(py, obj))
71-
.collect::<PyResult<_>>()?
69+
list.iter().map(Self::path_choice).collect::<PyResult<_>>()?
7270
};
7371

7472
if let Some(alt_alias) = alt_alias {
@@ -82,12 +80,12 @@ impl LookupKey {
8280
LookupKey::Simple(key.to_string(), py_string!(py, key))
8381
}
8482

85-
fn path_choice(py: Python, obj: &PyAny) -> PyResult<Path> {
83+
fn path_choice(obj: &PyAny) -> PyResult<Path> {
8684
let path = obj
8785
.extract::<&PyList>()?
8886
.iter()
8987
.enumerate()
90-
.map(|(index, obj)| PathItem::from_py(py, index, obj))
88+
.map(|(index, obj)| PathItem::from_py(index, obj))
9189
.collect::<PyResult<Path>>()?;
9290

9391
if path.is_empty() {
@@ -237,10 +235,10 @@ fn path_to_string(path: &Path) -> String {
237235
}
238236

239237
impl PathItem {
240-
pub fn from_py(py: Python, index: usize, obj: &PyAny) -> PyResult<Self> {
241-
if let Ok(str_key) = obj.extract::<String>() {
242-
let py_str_key = py_string!(py, &str_key);
243-
Ok(Self::S(str_key, py_str_key))
238+
pub fn from_py(index: usize, obj: &PyAny) -> PyResult<Self> {
239+
if let Ok(py_str_key) = obj.cast_as::<PyString>() {
240+
let str_key = py_str_key.to_str()?.to_string();
241+
Ok(Self::S(str_key, py_str_key.into()))
244242
} else {
245243
let int_key = obj.extract::<usize>()?;
246244
if index == 0 {

src/validators/arguments.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ impl Validator for ArgumentsValidator {
281281
}
282282
Err(err) => return Err(err),
283283
};
284-
if !used_kwargs.contains(either_str.as_cow().as_ref()) {
284+
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
285285
match self.var_kwargs_validator {
286286
Some(ref validator) => match validator.validate(py, value, extra, slots, recursion_guard) {
287287
Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?,

src/validators/literal.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::hash::BuildHasherDefault;
22

33
use pyo3::intern;
44
use pyo3::prelude::*;
5-
use pyo3::types::{PyDict, PyList};
5+
use pyo3::types::{PyDict, PyList, PyString};
66

77
use ahash::AHashSet;
88

@@ -29,8 +29,8 @@ impl BuildValidator for LiteralBuilder {
2929
return py_error!(r#""expected" should have length > 0"#);
3030
} else if expected.len() == 1 {
3131
let first = expected.get_item(0)?;
32-
if let Ok(str) = first.extract::<String>() {
33-
return Ok(LiteralSingleStringValidator::new(str).into());
32+
if let Ok(py_str) = first.cast_as::<PyString>() {
33+
return Ok(LiteralSingleStringValidator::new(py_str.to_str()?.to_string()).into());
3434
}
3535
if let Ok(int) = first.extract::<i64>() {
3636
return Ok(LiteralSingleIntValidator::new(int).into());
@@ -72,7 +72,7 @@ impl Validator for LiteralSingleStringValidator {
7272
_recursion_guard: &'s mut RecursionGuard,
7373
) -> ValResult<'data, PyObject> {
7474
let either_str = input.strict_str()?;
75-
if either_str.as_cow().as_ref() == self.expected.as_str() {
75+
if either_str.as_cow()?.as_ref() == self.expected.as_str() {
7676
Ok(input.to_object(py))
7777
} else {
7878
Err(ValError::new(
@@ -166,7 +166,7 @@ impl Validator for LiteralMultipleStringsValidator {
166166
_recursion_guard: &'s mut RecursionGuard,
167167
) -> ValResult<'data, PyObject> {
168168
let either_str = input.strict_str()?;
169-
if self.expected.contains(either_str.as_cow().as_ref()) {
169+
if self.expected.contains(either_str.as_cow()?.as_ref()) {
170170
Ok(input.to_object(py))
171171
} else {
172172
Err(ValError::new(
@@ -255,8 +255,8 @@ impl LiteralGeneralValidator {
255255
repr_args.push(item.repr()?.extract()?);
256256
if let Ok(int) = item.extract::<i64>() {
257257
expected_int.insert(int);
258-
} else if let Ok(str) = item.extract::<String>() {
259-
expected_str.insert(str);
258+
} else if let Ok(py_str) = item.cast_as::<PyString>() {
259+
expected_str.insert(py_str.to_str()?.to_string());
260260
} else {
261261
expected_py.append(item)?;
262262
}
@@ -291,7 +291,7 @@ impl Validator for LiteralGeneralValidator {
291291
}
292292
if !self.expected_str.is_empty() {
293293
if let Ok(either_str) = input.strict_str() {
294-
if self.expected_str.contains(either_str.as_cow().as_ref()) {
294+
if self.expected_str.contains(either_str.as_cow()?.as_ref()) {
295295
return Ok(input.to_object(py));
296296
}
297297
}

src/validators/string.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl Validator for StrConstrainedValidator {
9090
_recursion_guard: &'s mut RecursionGuard,
9191
) -> ValResult<'data, PyObject> {
9292
let either_str = input.validate_str(extra.strict.unwrap_or(self.strict))?;
93-
let cow = either_str.as_cow();
93+
let cow = either_str.as_cow()?;
9494
let mut str = cow.as_ref();
9595
if let Some(min_length) = self.min_length {
9696
if str.len() < min_length {

src/validators/typed_dict.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ impl Validator for TypedDictValidator {
313313
}
314314
Err(err) => return Err(err),
315315
};
316-
if used_keys.contains(either_str.as_cow().as_ref()) {
316+
if used_keys.contains(either_str.as_cow()?.as_ref()) {
317317
continue;
318318
}
319319

0 commit comments

Comments
 (0)