Skip to content

Commit eeb51e8

Browse files
authored
Properly coerce fractions as int (#1757)
1 parent 37ec6e7 commit eeb51e8

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

src/input/input_python.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::str::from_utf8;
33
use pyo3::intern;
44
use pyo3::prelude::*;
55

6+
use pyo3::sync::GILOnceCell;
67
use pyo3::types::PyType;
78
use pyo3::types::{
89
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
@@ -30,7 +31,8 @@ use super::input_abstract::ValMatch;
3031
use super::return_enums::EitherComplex;
3132
use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch};
3233
use super::shared::{
33-
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
34+
decimal_as_int, float_as_int, fraction_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float,
35+
str_as_int,
3436
};
3537
use super::Arguments;
3638
use super::ConsumeIterator;
@@ -45,6 +47,20 @@ use super::{
4547
Input,
4648
};
4749

50+
static FRACTION_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();
51+
52+
pub fn get_fraction_type(py: Python) -> &Bound<'_, PyType> {
53+
FRACTION_TYPE
54+
.get_or_init(py, || {
55+
py.import("fractions")
56+
.and_then(|fractions_module| fractions_module.getattr("Fraction"))
57+
.unwrap()
58+
.extract()
59+
.unwrap()
60+
})
61+
.bind(py)
62+
}
63+
4864
pub(crate) fn downcast_python_input<'py, T: PyTypeCheck>(input: &(impl Input<'py> + ?Sized)) -> Option<&Bound<'py, T>> {
4965
input.as_python().and_then(|any| any.downcast::<T>().ok())
5066
}
@@ -269,6 +285,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
269285
float_as_int(self, self.extract::<f64>()?)
270286
} else if let Ok(decimal) = self.validate_decimal(true, self.py()) {
271287
decimal_as_int(self, &decimal.into_inner())
288+
} else if self.is_instance(get_fraction_type(self.py()))? {
289+
fraction_as_int(self)
272290
} else if let Ok(float) = self.extract::<f64>() {
273291
float_as_int(self, float)
274292
} else if let Some(enum_val) = maybe_as_enum(self) {

src/input/shared.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,23 @@ pub fn decimal_as_int<'py>(
227227
}
228228
Ok(EitherInt::Py(numerator))
229229
}
230+
231+
pub fn fraction_as_int<'py>(input: &Bound<'py, PyAny>) -> ValResult<EitherInt<'py>> {
232+
#[cfg(Py_3_12)]
233+
let is_integer = input.call_method0("is_integer")?.extract::<bool>()?;
234+
#[cfg(not(Py_3_12))]
235+
let is_integer = input.getattr("denominator")?.extract::<i64>().map_or(false, |d| d == 1);
236+
237+
if is_integer {
238+
#[cfg(Py_3_11)]
239+
let as_int = input.call_method0("__int__");
240+
#[cfg(not(Py_3_11))]
241+
let as_int = input.call_method0("__trunc__");
242+
match as_int {
243+
Ok(i) => Ok(EitherInt::Py(i.as_any().to_owned())),
244+
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntType, input)),
245+
}
246+
} else {
247+
Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input))
248+
}
249+
}

tests/validators/test_int.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import re
33
from decimal import Decimal
4+
from fractions import Fraction
45
from typing import Any
56

67
import pytest
@@ -132,13 +133,22 @@ def test_int_py_and_json(py_and_json: PyAndJson, input_value, expected):
132133
(-i64_max + 1, -i64_max + 1),
133134
(i64_max * 2, i64_max * 2),
134135
(-i64_max * 2, -i64_max * 2),
136+
(Fraction(10_935_244_710_974_505), 10_935_244_710_974_505), # https://github.com/pydantic/pydantic/issues/12063
137+
pytest.param(
138+
Fraction(1, 2),
139+
Err(
140+
'Input should be a valid integer, got a number with a fractional part '
141+
'[type=int_from_float, input_value=Fraction(1, 2), input_type=Fraction]'
142+
),
143+
id='fraction-remainder',
144+
),
135145
pytest.param(
136146
1.00000000001,
137147
Err(
138148
'Input should be a valid integer, got a number with a fractional part '
139149
'[type=int_from_float, input_value=1.00000000001, input_type=float]'
140150
),
141-
id='decimal-remainder',
151+
id='float-remainder',
142152
),
143153
pytest.param(
144154
Decimal('1.001'),

0 commit comments

Comments
 (0)