Skip to content

Commit 6769140

Browse files
Fix parsing int from large decimals (#948)
Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com>
1 parent 2d9df49 commit 6769140

File tree

6 files changed

+97
-46
lines changed

6 files changed

+97
-46
lines changed

src/input/input_abstract.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,17 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
152152
self.strict_float()
153153
}
154154

155-
fn validate_decimal(&'a self, strict: bool, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
155+
fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> {
156156
if strict {
157-
self.strict_decimal(decimal_type)
157+
self.strict_decimal(py)
158158
} else {
159-
self.lax_decimal(decimal_type)
159+
self.lax_decimal(py)
160160
}
161161
}
162-
fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny>;
162+
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny>;
163163
#[cfg_attr(has_no_coverage, no_coverage)]
164-
fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
165-
self.strict_decimal(decimal_type)
164+
fn lax_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
165+
self.strict_decimal(py)
166166
}
167167

168168
fn validate_dict(&'a self, strict: bool) -> ValResult<GenericMapping<'a>> {

src/input/input_json.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::borrow::Cow;
22

33
use pyo3::prelude::*;
4-
use pyo3::types::{PyDict, PyString, PyType};
4+
use pyo3::types::{PyDict, PyString};
55
use speedate::MicrosecondsPrecisionOverflowBehavior;
66
use strum::EnumMessage;
77

@@ -178,13 +178,12 @@ impl<'a> Input<'a> for JsonInput {
178178
}
179179
}
180180

181-
fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
182-
let py = decimal_type.py();
181+
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
183182
match self {
184-
JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, decimal_type),
183+
JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py),
185184

186185
JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => {
187-
create_decimal(self.to_object(py).into_ref(py), self, decimal_type)
186+
create_decimal(self.to_object(py).into_ref(py), self, py)
188187
}
189188
_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
190189
}
@@ -439,9 +438,8 @@ impl<'a> Input<'a> for String {
439438
str_as_float(self, self)
440439
}
441440

442-
fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
443-
let py = decimal_type.py();
444-
create_decimal(self.to_object(py).into_ref(py), self, decimal_type)
441+
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
442+
create_decimal(self.to_object(py).into_ref(py), self, py)
445443
}
446444

447445
#[cfg_attr(has_no_coverage, no_coverage)]

src/input/input_python.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;
1313

1414
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
1515
use crate::tools::{extract_i64, safe_repr};
16-
use crate::validators::decimal::create_decimal;
16+
use crate::validators::decimal::{create_decimal, get_decimal_type};
1717
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};
1818

1919
use super::datetime::{
2020
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
2121
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
2222
EitherTime,
2323
};
24-
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
24+
use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
2525
use super::{
2626
py_string_str, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
2727
GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
@@ -324,6 +324,10 @@ impl<'a> Input<'a> for PyAny {
324324
} else if PyInt::is_type_of(self) {
325325
// force to an int to upcast to a pure python int to maintain current behaviour
326326
EitherInt::upcast(self)
327+
} else if PyFloat::is_exact_type_of(self) {
328+
float_as_int(self, self.extract::<f64>()?)
329+
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
330+
decimal_as_int(self.py(), self, decimal)
327331
} else if let Ok(float) = self.extract::<f64>() {
328332
float_as_int(self, float)
329333
} else {
@@ -367,15 +371,17 @@ impl<'a> Input<'a> for PyAny {
367371
}
368372
}
369373

370-
fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
374+
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
375+
let decimal_type_obj: Py<PyType> = get_decimal_type(py);
376+
let decimal_type = decimal_type_obj.as_ref(py);
371377
// Fast path for existing decimal objects
372378
if self.is_exact_instance(decimal_type) {
373379
return Ok(self);
374380
}
375381

376382
// Try subclasses of decimals, they will be upcast to Decimal
377383
if self.is_instance(decimal_type)? {
378-
return create_decimal(self, self, decimal_type);
384+
return create_decimal(self, self, py);
379385
}
380386

381387
Err(ValError::new(
@@ -387,20 +393,22 @@ impl<'a> Input<'a> for PyAny {
387393
))
388394
}
389395

390-
fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
396+
fn lax_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
397+
let decimal_type_obj: Py<PyType> = get_decimal_type(py);
398+
let decimal_type = decimal_type_obj.as_ref(py);
391399
// Fast path for existing decimal objects
392400
if self.is_exact_instance(decimal_type) {
393401
return Ok(self);
394402
}
395403

396404
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>()) {
397405
// checking isinstance for str / int / bool is fast compared to decimal / float
398-
create_decimal(self, self, decimal_type)
406+
create_decimal(self, self, py)
399407
} else if self.is_instance(decimal_type)? {
400408
// upcast subclasses to decimal
401-
return create_decimal(self, self, decimal_type);
409+
return create_decimal(self, self, py);
402410
} else if self.is_instance_of::<PyFloat>() {
403-
create_decimal(self.str()?, self, decimal_type)
411+
create_decimal(self.str()?, self, py)
404412
} else {
405413
Err(ValError::new(ErrorTypeDefaults::DecimalType, self))
406414
}

src/input/shared.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use num_bigint::BigInt;
2+
use pyo3::{intern, PyAny, Python};
23

34
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
45
use crate::input::EitherInt;
@@ -136,3 +137,16 @@ pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a,
136137
Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input))
137138
}
138139
}
140+
141+
pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a PyAny) -> ValResult<'a, EitherInt<'a>> {
142+
if !decimal.call_method0(intern!(py, "is_finite"))?.extract::<bool>()? {
143+
return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input));
144+
}
145+
let (numerator, denominator) = decimal
146+
.call_method0(intern!(py, "as_integer_ratio"))?
147+
.extract::<(&PyAny, &PyAny)>()?;
148+
if denominator.extract::<i64>().map_or(true, |d| d != 1) {
149+
return Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input));
150+
}
151+
Ok(EitherInt::Py(numerator))
152+
}

src/validators/decimal.rs

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use pyo3::exceptions::{PyTypeError, PyValueError};
2+
use pyo3::sync::GILOnceCell;
23
use pyo3::types::{IntoPyDict, PyDict, PyTuple, PyType};
34
use pyo3::{intern, AsPyPointer};
45
use pyo3::{prelude::*, PyTypeInfo};
@@ -13,6 +14,21 @@ use crate::tools::SchemaDict;
1314

1415
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
1516

17+
static DECIMAL_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();
18+
19+
pub fn get_decimal_type(py: Python) -> Py<PyType> {
20+
DECIMAL_TYPE
21+
.get_or_init(py, || {
22+
py.import("decimal")
23+
.and_then(|decimal_module| decimal_module.getattr("Decimal"))
24+
.unwrap()
25+
.extract::<&PyType>()
26+
.unwrap()
27+
.into()
28+
})
29+
.clone()
30+
}
31+
1632
#[derive(Debug, Clone)]
1733
pub struct DecimalValidator {
1834
strict: bool,
@@ -25,7 +41,6 @@ pub struct DecimalValidator {
2541
gt: Option<Py<PyAny>>,
2642
max_digits: Option<u64>,
2743
decimal_places: Option<u64>,
28-
decimal_type: Py<PyType>,
2944
}
3045

3146
impl BuildValidator for DecimalValidator {
@@ -55,10 +70,6 @@ impl BuildValidator for DecimalValidator {
5570
ge: schema.get_as(intern!(py, "ge"))?,
5671
gt: schema.get_as(intern!(py, "gt"))?,
5772
max_digits,
58-
decimal_type: py
59-
.import(intern!(py, "decimal"))?
60-
.getattr(intern!(py, "Decimal"))?
61-
.extract()?,
6273
}
6374
.into())
6475
}
@@ -69,8 +80,7 @@ impl_py_gc_traverse!(DecimalValidator {
6980
le,
7081
lt,
7182
ge,
72-
gt,
73-
decimal_type
83+
gt
7484
});
7585

7686
impl Validator for DecimalValidator {
@@ -80,11 +90,7 @@ impl Validator for DecimalValidator {
8090
input: &'data impl Input<'data>,
8191
state: &mut ValidationState,
8292
) -> ValResult<'data, PyObject> {
83-
let decimal = input.validate_decimal(
84-
state.strict_or(self.strict),
85-
// Safety: self and py both outlive this call
86-
unsafe { py.from_borrowed_ptr(self.decimal_type.as_ptr()) },
87-
)?;
93+
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?;
8894

8995
if !self.allow_inf_nan || self.check_digits {
9096
if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? {
@@ -244,19 +250,23 @@ impl Validator for DecimalValidator {
244250
pub(crate) fn create_decimal<'a>(
245251
arg: &'a PyAny,
246252
input: &'a impl Input<'a>,
247-
decimal_type: &'a PyType,
253+
py: Python<'a>,
248254
) -> ValResult<'a, &'a PyAny> {
249-
decimal_type.call1((arg,)).map_err(|e| {
250-
let decimal_exception = match arg
251-
.py()
252-
.import("decimal")
253-
.and_then(|decimal_module| decimal_module.getattr("DecimalException"))
254-
{
255-
Ok(decimal_exception) => decimal_exception,
256-
Err(e) => return ValError::InternalErr(e),
257-
};
258-
handle_decimal_new_error(arg.py(), input.as_error_value(), e, decimal_exception)
259-
})
255+
let decimal_type_obj: Py<PyType> = get_decimal_type(py);
256+
decimal_type_obj
257+
.call1(py, (arg,))
258+
.map_err(|e| {
259+
let decimal_exception = match arg
260+
.py()
261+
.import("decimal")
262+
.and_then(|decimal_module| decimal_module.getattr("DecimalException"))
263+
{
264+
Ok(decimal_exception) => decimal_exception,
265+
Err(e) => return ValError::InternalErr(e),
266+
};
267+
handle_decimal_new_error(arg.py(), input.as_error_value(), e, decimal_exception)
268+
})
269+
.map(|v| v.into_ref(py))
260270
}
261271

262272
fn handle_decimal_new_error<'a>(

tests/validators/test_int.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,24 @@ def test_int_py_and_json(py_and_json: PyAndJson, input_value, expected):
5858
'input_value,expected',
5959
[
6060
(Decimal('1'), 1),
61+
(Decimal('1' + '0' * 1_000), int('1' + '0' * 1_000)), # a large decimal
6162
(Decimal('1.0'), 1),
63+
(1.0, 1),
6264
(i64_max, i64_max),
6365
(str(i64_max), i64_max),
6466
(str(i64_max * 2), i64_max * 2),
6567
(i64_max + 1, i64_max + 1),
6668
(-i64_max + 1, -i64_max + 1),
6769
(i64_max * 2, i64_max * 2),
6870
(-i64_max * 2, -i64_max * 2),
71+
pytest.param(
72+
1.00000000001,
73+
Err(
74+
'Input should be a valid integer, got a number with a fractional part '
75+
'[type=int_from_float, input_value=1.00000000001, input_type=float]'
76+
),
77+
id='decimal-remainder',
78+
),
6979
pytest.param(
7080
Decimal('1.001'),
7181
Err(
@@ -437,3 +447,14 @@ def test_int_subclass_constraint() -> None:
437447

438448
with pytest.raises(ValidationError, match='Input should be greater than 0'):
439449
v.validate_python(IntSubclass(0))
450+
451+
452+
class FloatSubclass(float):
453+
pass
454+
455+
456+
def test_float_subclass() -> None:
457+
v = SchemaValidator({'type': 'int'})
458+
v_lax = v.validate_python(FloatSubclass(1))
459+
assert v_lax == 1
460+
assert type(v_lax) == int

0 commit comments

Comments
 (0)