Skip to content

Commit 5240970

Browse files
authored
refactor(query): use bigint to handle the fallback of decimal op overflow (#16215)
* refactor(query): use bigint to handle fallback overflow * refactor(query): use bigint to handle fallback overflow
1 parent aed29dc commit 5240970

File tree

6 files changed

+151
-19
lines changed

6 files changed

+151
-19
lines changed

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/query/expression/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ log = { workspace = true }
4444
match-template = { workspace = true }
4545
memchr = { version = "2", default-features = false }
4646
micromarshal = "0.5.0"
47+
num-bigint = "0.4.6"
4748
num-traits = "0.2.15"
4849
ordered-float = { workspace = true, features = ["serde", "rand", "borsh"] }
4950
rand = { workspace = true }

src/query/expression/src/types/decimal.rs

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::cmp::Ordering;
1516
use std::fmt::Debug;
1617
use std::marker::PhantomData;
1718
use std::ops::Range;
@@ -25,9 +26,13 @@ use databend_common_io::display_decimal_128;
2526
use databend_common_io::display_decimal_256;
2627
use enum_as_inner::EnumAsInner;
2728
use ethnum::i256;
29+
use ethnum::u256;
2830
use ethnum::AsI256;
2931
use itertools::Itertools;
32+
use num_bigint::BigInt;
33+
use num_traits::FromBytes;
3034
use num_traits::NumCast;
35+
use num_traits::ToPrimitive;
3136
use serde::Deserialize;
3237
use serde::Serialize;
3338

@@ -356,7 +361,7 @@ pub trait Decimal:
356361
fn checked_mul(self, rhs: Self) -> Option<Self>;
357362
fn checked_rem(self, rhs: Self) -> Option<Self>;
358363

359-
fn do_round_div(self, rhs: Self, mul: Self) -> Option<Self>;
364+
fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option<Self>;
360365

361366
// mul two decimals and return a decimal with rounding option
362367
fn do_round_mul(self, rhs: Self, shift_scale: u32) -> Option<Self>;
@@ -368,6 +373,7 @@ pub trait Decimal:
368373

369374
fn from_float(value: f64) -> Self;
370375
fn from_i128<U: Into<i128>>(value: U) -> Self;
376+
fn from_bigint(value: BigInt) -> Option<Self>;
371377

372378
fn de_binary(bytes: &mut &[u8]) -> Self;
373379
fn display(self, scale: u8) -> String;
@@ -471,7 +477,8 @@ impl Decimal for i128 {
471477
}
472478
}
473479

474-
fn do_round_div(self, rhs: Self, mul: Self) -> Option<Self> {
480+
fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option<Self> {
481+
let mul = i256::e(mul_scale);
475482
if self.is_negative() == rhs.is_negative() {
476483
let res = (i256::from(self) * i256::from(mul) + i256::from(rhs) / 2) / i256::from(rhs);
477484
Some(*res.low())
@@ -535,6 +542,10 @@ impl Decimal for i128 {
535542
value.into()
536543
}
537544

545+
fn from_bigint(value: BigInt) -> Option<Self> {
546+
value.to_i128()
547+
}
548+
538549
fn de_binary(bytes: &mut &[u8]) -> Self {
539550
let bs: [u8; std::mem::size_of::<Self>()] =
540551
bytes[0..std::mem::size_of::<Self>()].try_into().unwrap();
@@ -687,19 +698,48 @@ impl Decimal for i256 {
687698

688699
fn do_round_mul(self, rhs: Self, shift_scale: u32) -> Option<Self> {
689700
let div = i256::e(shift_scale);
690-
if self.is_negative() == rhs.is_negative() {
701+
let ret: Option<i256> = if self.is_negative() == rhs.is_negative() {
691702
self.checked_mul(rhs).map(|x| (x + div / 2) / div)
692703
} else {
693704
self.checked_mul(rhs).map(|x| (x - div / 2) / div)
694-
}
705+
};
706+
707+
ret.or_else(|| {
708+
let a = BigInt::from_le_bytes(&self.to_le_bytes());
709+
let b = BigInt::from_le_bytes(&rhs.to_le_bytes());
710+
let div = BigInt::from(10).pow(shift_scale);
711+
if self.is_negative() == rhs.is_negative() {
712+
Self::from_bigint((a * b + div.clone() / 2) / div)
713+
} else {
714+
Self::from_bigint((a * b - div.clone() / 2) / div)
715+
}
716+
})
695717
}
696718

697-
fn do_round_div(self, rhs: Self, mul: Self) -> Option<Self> {
698-
if self.is_negative() == rhs.is_negative() {
719+
fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option<Self> {
720+
let fallback = || {
721+
let a = BigInt::from_le_bytes(&self.to_le_bytes());
722+
let b = BigInt::from_le_bytes(&rhs.to_le_bytes());
723+
let mul = BigInt::from(10).pow(mul_scale);
724+
if self.is_negative() == rhs.is_negative() {
725+
Self::from_bigint((a * mul + b.clone() / 2) / b)
726+
} else {
727+
Self::from_bigint((a * mul - b.clone() / 2) / b)
728+
}
729+
};
730+
731+
if mul_scale >= MAX_DECIMAL256_PRECISION as _ {
732+
return fallback();
733+
}
734+
735+
let mul = i256::e(mul_scale);
736+
let ret: Option<i256> = if self.is_negative() == rhs.is_negative() {
699737
self.checked_mul(mul).map(|x| (x + rhs / 2) / rhs)
700738
} else {
701739
self.checked_mul(mul).map(|x| (x - rhs / 2) / rhs)
702-
}
740+
};
741+
742+
ret.or_else(fallback)
703743
}
704744

705745
fn min_for_precision(to_precision: u8) -> Self {
@@ -725,6 +765,32 @@ impl Decimal for i256 {
725765
i256::from(value.into())
726766
}
727767

768+
fn from_bigint(value: BigInt) -> Option<Self> {
769+
let mut ret: u256 = u256::ZERO;
770+
let mut bits = 0;
771+
772+
for i in value.iter_u64_digits() {
773+
if bits >= 256 {
774+
return None;
775+
}
776+
ret |= u256::from(i) << bits;
777+
bits += 64;
778+
}
779+
780+
match value.sign() {
781+
num_bigint::Sign::Plus => i256::try_from(ret).ok(),
782+
num_bigint::Sign::NoSign => Some(i256::ZERO),
783+
num_bigint::Sign::Minus => {
784+
let m: u256 = u256::ONE << 255;
785+
match ret.cmp(&m) {
786+
Ordering::Less => Some(-i256::try_from(ret).unwrap()),
787+
Ordering::Equal => Some(i256::MIN),
788+
Ordering::Greater => None,
789+
}
790+
}
791+
}
792+
}
793+
728794
fn de_binary(bytes: &mut &[u8]) -> Self {
729795
let bs: [u8; std::mem::size_of::<Self>()] =
730796
bytes[0..std::mem::size_of::<Self>()].try_into().unwrap();
@@ -947,10 +1013,9 @@ impl DecimalDataType {
9471013
let l = a.leading_digits() + b.leading_digits();
9481014
precision = l + scale;
9491015
} else if is_divide {
950-
let l = a.leading_digits() + b.scale();
951-
scale = a.scale().max((a.scale() + 6).min(12));
952-
// P = L + S
953-
precision = l + scale;
1016+
scale = a.scale().max((a.scale() + 6).min(12)); // scale must be >= a.sale()
1017+
let l = a.leading_digits() + b.scale(); // l must be >= a.leading_digits()
1018+
precision = l + scale; // so precision must be >= a.precision()
9541019
} else if is_plus_minus {
9551020
scale = std::cmp::max(a.scale(), b.scale());
9561021
// for addition/subtraction, we add 1 to the width to ensure we don't overflow
@@ -984,8 +1049,18 @@ impl DecimalDataType {
9841049
result_type,
9851050
))
9861051
} else if is_divide {
987-
let (a, b) = Self::div_common_type(a, b, result_type.size())?;
988-
Ok((a, b, result_type))
1052+
let p = precision.max(a.precision()).max(b.precision());
1053+
Ok((
1054+
Self::from_size(DecimalSize {
1055+
precision: p,
1056+
scale: a.scale(),
1057+
})?,
1058+
Self::from_size(DecimalSize {
1059+
precision: p,
1060+
scale: b.scale(),
1061+
})?,
1062+
result_type,
1063+
))
9891064
} else {
9901065
Ok((result_type, result_type, result_type))
9911066
}

src/query/expression/tests/it/decimal.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use databend_common_expression::types::decimal::DecimalSize;
2121
use databend_common_expression::types::DataType;
2222
use databend_common_expression::types::DecimalDataType;
2323
use databend_common_expression::types::NumberDataType;
24+
use ethnum::i256;
25+
use num_bigint::BigInt;
2426
use pretty_assertions::assert_eq;
2527

2628
#[test]
@@ -168,3 +170,51 @@ fn test_float_to_128() {
168170
assert_eq!(r, b);
169171
}
170172
}
173+
174+
#[test]
175+
fn test_from_bigint() {
176+
let cases = vec![
177+
("0", 0i128),
178+
("12345", 12345i128),
179+
("-1", -1i128),
180+
("-170141183460469231731687303715884105728", i128::MIN),
181+
("170141183460469231731687303715884105727", i128::MAX),
182+
];
183+
184+
for (a, b) in cases {
185+
let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap();
186+
assert_eq!(i128::from_bigint(r), Some(b));
187+
}
188+
189+
let cases = vec![
190+
("0".to_string(), i256::ZERO),
191+
("12345".to_string(), i256::from(12345)),
192+
("-1".to_string(), i256::from(-1)),
193+
(
194+
"12".repeat(25),
195+
i256::from_str_radix(&"12".repeat(25), 10).unwrap(),
196+
),
197+
(
198+
"1".repeat(26),
199+
i256::from_str_radix(&"1".repeat(26), 10).unwrap(),
200+
),
201+
(i256::MIN.to_string(), i256::MIN),
202+
(i256::MAX.to_string(), i256::MAX),
203+
];
204+
205+
for (a, b) in cases {
206+
let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap();
207+
assert_eq!(i256::from_bigint(r), Some(b));
208+
}
209+
210+
let cases = vec![
211+
("1".repeat(78), None),
212+
("12".repeat(78), None),
213+
("234".repeat(78), None),
214+
];
215+
216+
for (a, b) in cases {
217+
let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap();
218+
assert_eq!(i256::from_bigint(r), b);
219+
}
220+
}

src/query/functions/src/scalars/decimal/arithmetic.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,7 @@ macro_rules! binary_decimal {
8989
let scale_b = $right.scale();
9090

9191
// Note: the result scale is always larger than the left scale
92-
let scale_mul = scale_b + $size.scale - scale_a;
93-
let multiplier = T::e(scale_mul as u32);
92+
let scale_mul = (scale_b + $size.scale - scale_a) as u32;
9493
let func = |a: T, b: T, result: &mut Vec<T>, ctx: &mut EvalContext| {
9594
// We are using round div here which follow snowflake's behavior: https://docs.snowflake.com/sql-reference/operators-arithmetic
9695
// For example:
@@ -102,7 +101,7 @@ macro_rules! binary_decimal {
102101
ctx.set_error(result.len(), "divided by zero");
103102
result.push(one);
104103
} else {
105-
match a.do_round_div(b, multiplier) {
104+
match a.do_round_div(b, scale_mul) {
106105
Some(t) => result.push(t),
107106
None => {
108107
ctx.set_error(

tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,13 @@ SELECT CAST(987654321.34 AS DECIMAL(20, 2)) / CAST(1.23 AS DECIMAL(6, 2)) AS res
295295
----
296296
802970992.95934959
297297

298+
query IIIIII
299+
select 3.33 a , ('3.' || repeat('3', 72))::Decimal(76, 72) b, a / b, a * b, (-a) /b, (-a) * b
300+
----
301+
3.33 3.333333333333333333333333333333333333333333333333333333333333333333333333 0.99900000 11.099999999999999999999999999999999999999999999999999999999999999999999999 -0.99900000 -11.099999999999999999999999999999999999999999999999999999999999999999999999
302+
303+
statement error
304+
select (repeat('9', 38) || '.3')::Decimal(76, 72) a, a * a
298305

299306
query I
300307
SELECT CAST(987654321.34 AS DECIMAL(76, 2)) / CAST(1.23 AS DECIMAL(76, 2)) AS result;

0 commit comments

Comments
 (0)