12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
+ use std:: cmp:: Ordering ;
15
16
use std:: fmt:: Debug ;
16
17
use std:: marker:: PhantomData ;
17
18
use std:: ops:: Range ;
@@ -25,9 +26,13 @@ use databend_common_io::display_decimal_128;
25
26
use databend_common_io:: display_decimal_256;
26
27
use enum_as_inner:: EnumAsInner ;
27
28
use ethnum:: i256;
29
+ use ethnum:: u256;
28
30
use ethnum:: AsI256 ;
29
31
use itertools:: Itertools ;
32
+ use num_bigint:: BigInt ;
33
+ use num_traits:: FromBytes ;
30
34
use num_traits:: NumCast ;
35
+ use num_traits:: ToPrimitive ;
31
36
use serde:: Deserialize ;
32
37
use serde:: Serialize ;
33
38
@@ -356,7 +361,7 @@ pub trait Decimal:
356
361
fn checked_mul ( self , rhs : Self ) -> Option < Self > ;
357
362
fn checked_rem ( self , rhs : Self ) -> Option < Self > ;
358
363
359
- fn do_round_div ( self , rhs : Self , mul : Self ) -> Option < Self > ;
364
+ fn do_round_div ( self , rhs : Self , mul_scale : u32 ) -> Option < Self > ;
360
365
361
366
// mul two decimals and return a decimal with rounding option
362
367
fn do_round_mul ( self , rhs : Self , shift_scale : u32 ) -> Option < Self > ;
@@ -368,6 +373,7 @@ pub trait Decimal:
368
373
369
374
fn from_float ( value : f64 ) -> Self ;
370
375
fn from_i128 < U : Into < i128 > > ( value : U ) -> Self ;
376
+ fn from_bigint ( value : BigInt ) -> Option < Self > ;
371
377
372
378
fn de_binary ( bytes : & mut & [ u8 ] ) -> Self ;
373
379
fn display ( self , scale : u8 ) -> String ;
@@ -471,7 +477,8 @@ impl Decimal for i128 {
471
477
}
472
478
}
473
479
474
- fn do_round_div ( self , rhs : Self , mul : Self ) -> Option < Self > {
480
+ fn do_round_div ( self , rhs : Self , mul_scale : u32 ) -> Option < Self > {
481
+ let mul = i256:: e ( mul_scale) ;
475
482
if self . is_negative ( ) == rhs. is_negative ( ) {
476
483
let res = ( i256:: from ( self ) * i256:: from ( mul) + i256:: from ( rhs) / 2 ) / i256:: from ( rhs) ;
477
484
Some ( * res. low ( ) )
@@ -535,6 +542,10 @@ impl Decimal for i128 {
535
542
value. into ( )
536
543
}
537
544
545
+ fn from_bigint ( value : BigInt ) -> Option < Self > {
546
+ value. to_i128 ( )
547
+ }
548
+
538
549
fn de_binary ( bytes : & mut & [ u8 ] ) -> Self {
539
550
let bs: [ u8 ; std:: mem:: size_of :: < Self > ( ) ] =
540
551
bytes[ 0 ..std:: mem:: size_of :: < Self > ( ) ] . try_into ( ) . unwrap ( ) ;
@@ -687,19 +698,48 @@ impl Decimal for i256 {
687
698
688
699
fn do_round_mul ( self , rhs : Self , shift_scale : u32 ) -> Option < Self > {
689
700
let div = i256:: e ( shift_scale) ;
690
- if self . is_negative ( ) == rhs. is_negative ( ) {
701
+ let ret : Option < i256 > = if self . is_negative ( ) == rhs. is_negative ( ) {
691
702
self . checked_mul ( rhs) . map ( |x| ( x + div / 2 ) / div)
692
703
} else {
693
704
self . checked_mul ( rhs) . map ( |x| ( x - div / 2 ) / div)
694
- }
705
+ } ;
706
+
707
+ ret. or_else ( || {
708
+ let a = BigInt :: from_le_bytes ( & self . to_le_bytes ( ) ) ;
709
+ let b = BigInt :: from_le_bytes ( & rhs. to_le_bytes ( ) ) ;
710
+ let div = BigInt :: from ( 10 ) . pow ( shift_scale) ;
711
+ if self . is_negative ( ) == rhs. is_negative ( ) {
712
+ Self :: from_bigint ( ( a * b + div. clone ( ) / 2 ) / div)
713
+ } else {
714
+ Self :: from_bigint ( ( a * b - div. clone ( ) / 2 ) / div)
715
+ }
716
+ } )
695
717
}
696
718
697
- fn do_round_div ( self , rhs : Self , mul : Self ) -> Option < Self > {
698
- if self . is_negative ( ) == rhs. is_negative ( ) {
719
+ fn do_round_div ( self , rhs : Self , mul_scale : u32 ) -> Option < Self > {
720
+ let fallback = || {
721
+ let a = BigInt :: from_le_bytes ( & self . to_le_bytes ( ) ) ;
722
+ let b = BigInt :: from_le_bytes ( & rhs. to_le_bytes ( ) ) ;
723
+ let mul = BigInt :: from ( 10 ) . pow ( mul_scale) ;
724
+ if self . is_negative ( ) == rhs. is_negative ( ) {
725
+ Self :: from_bigint ( ( a * mul + b. clone ( ) / 2 ) / b)
726
+ } else {
727
+ Self :: from_bigint ( ( a * mul - b. clone ( ) / 2 ) / b)
728
+ }
729
+ } ;
730
+
731
+ if mul_scale >= MAX_DECIMAL256_PRECISION as _ {
732
+ return fallback ( ) ;
733
+ }
734
+
735
+ let mul = i256:: e ( mul_scale) ;
736
+ let ret: Option < i256 > = if self . is_negative ( ) == rhs. is_negative ( ) {
699
737
self . checked_mul ( mul) . map ( |x| ( x + rhs / 2 ) / rhs)
700
738
} else {
701
739
self . checked_mul ( mul) . map ( |x| ( x - rhs / 2 ) / rhs)
702
- }
740
+ } ;
741
+
742
+ ret. or_else ( fallback)
703
743
}
704
744
705
745
fn min_for_precision ( to_precision : u8 ) -> Self {
@@ -725,6 +765,32 @@ impl Decimal for i256 {
725
765
i256:: from ( value. into ( ) )
726
766
}
727
767
768
+ fn from_bigint ( value : BigInt ) -> Option < Self > {
769
+ let mut ret: u256 = u256:: ZERO ;
770
+ let mut bits = 0 ;
771
+
772
+ for i in value. iter_u64_digits ( ) {
773
+ if bits >= 256 {
774
+ return None ;
775
+ }
776
+ ret |= u256:: from ( i) << bits;
777
+ bits += 64 ;
778
+ }
779
+
780
+ match value. sign ( ) {
781
+ num_bigint:: Sign :: Plus => i256:: try_from ( ret) . ok ( ) ,
782
+ num_bigint:: Sign :: NoSign => Some ( i256:: ZERO ) ,
783
+ num_bigint:: Sign :: Minus => {
784
+ let m: u256 = u256:: ONE << 255 ;
785
+ match ret. cmp ( & m) {
786
+ Ordering :: Less => Some ( -i256:: try_from ( ret) . unwrap ( ) ) ,
787
+ Ordering :: Equal => Some ( i256:: MIN ) ,
788
+ Ordering :: Greater => None ,
789
+ }
790
+ }
791
+ }
792
+ }
793
+
728
794
fn de_binary ( bytes : & mut & [ u8 ] ) -> Self {
729
795
let bs: [ u8 ; std:: mem:: size_of :: < Self > ( ) ] =
730
796
bytes[ 0 ..std:: mem:: size_of :: < Self > ( ) ] . try_into ( ) . unwrap ( ) ;
@@ -947,10 +1013,9 @@ impl DecimalDataType {
947
1013
let l = a. leading_digits ( ) + b. leading_digits ( ) ;
948
1014
precision = l + scale;
949
1015
} else if is_divide {
950
- let l = a. leading_digits ( ) + b. scale ( ) ;
951
- scale = a. scale ( ) . max ( ( a. scale ( ) + 6 ) . min ( 12 ) ) ;
952
- // P = L + S
953
- precision = l + scale;
1016
+ scale = a. scale ( ) . max ( ( a. scale ( ) + 6 ) . min ( 12 ) ) ; // scale must be >= a.sale()
1017
+ let l = a. leading_digits ( ) + b. scale ( ) ; // l must be >= a.leading_digits()
1018
+ precision = l + scale; // so precision must be >= a.precision()
954
1019
} else if is_plus_minus {
955
1020
scale = std:: cmp:: max ( a. scale ( ) , b. scale ( ) ) ;
956
1021
// for addition/subtraction, we add 1 to the width to ensure we don't overflow
@@ -984,8 +1049,18 @@ impl DecimalDataType {
984
1049
result_type,
985
1050
) )
986
1051
} else if is_divide {
987
- let ( a, b) = Self :: div_common_type ( a, b, result_type. size ( ) ) ?;
988
- Ok ( ( a, b, result_type) )
1052
+ let p = precision. max ( a. precision ( ) ) . max ( b. precision ( ) ) ;
1053
+ Ok ( (
1054
+ Self :: from_size ( DecimalSize {
1055
+ precision : p,
1056
+ scale : a. scale ( ) ,
1057
+ } ) ?,
1058
+ Self :: from_size ( DecimalSize {
1059
+ precision : p,
1060
+ scale : b. scale ( ) ,
1061
+ } ) ?,
1062
+ result_type,
1063
+ ) )
989
1064
} else {
990
1065
Ok ( ( result_type, result_type, result_type) )
991
1066
}
0 commit comments