-
Notifications
You must be signed in to change notification settings - Fork 1k
[Variant] Support variant to Decimal32/64/128/256
#8552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
f25b499
7a32191
02d29de
f498db5
964e45a
43d579d
6f39a2a
8f0f53c
d88fd7f
522b26a
9b6d0e1
54237fe
e0b18da
1f19580
c163a91
274a028
a7cdd33
94d60c0
338defe
e1febf6
51648fd
a48bbf4
cb2576c
5ffab93
ef62474
21a83ed
25e4aa9
539d73f
dfe9960
cfc8580
9ed0d7a
0567cb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,17 +19,23 @@ use crate::cast::*; | |
|
||
/// A utility trait that provides checked conversions between | ||
/// decimal types inspired by [`NumCast`] | ||
pub(crate) trait DecimalCast: Sized { | ||
pub trait DecimalCast: Sized { | ||
/// Convert the decimal to an i32 | ||
fn to_i32(self) -> Option<i32>; | ||
|
||
/// Convert the decimal to an i64 | ||
fn to_i64(self) -> Option<i64>; | ||
|
||
/// Convert the decimal to an i128 | ||
fn to_i128(self) -> Option<i128>; | ||
|
||
/// Convert the decimal to an i256 | ||
fn to_i256(self) -> Option<i256>; | ||
|
||
/// Convert a decimal from a decimal | ||
fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>; | ||
|
||
/// Convert a decimal from a f64 | ||
fn from_f64(n: f64) -> Option<Self>; | ||
} | ||
|
||
|
@@ -139,6 +145,133 @@ impl DecimalCast for i256 { | |
} | ||
} | ||
|
||
/// Build a rescale function from (input_precision, input_scale) to (output_precision, output_scale) | ||
/// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the conversion. | ||
pub fn rescale_decimal<I, O>( | ||
input_precision: u8, | ||
input_scale: i8, | ||
output_precision: u8, | ||
output_scale: i8, | ||
) -> impl Fn(I::Native) -> Result<O::Native, ArrowError> | ||
where | ||
I: DecimalType, | ||
O: DecimalType, | ||
I::Native: DecimalCast, | ||
O::Native: DecimalCast, | ||
{ | ||
let delta_scale = output_scale - input_scale; | ||
|
||
// Determine if the cast is infallible based on precision/scale math | ||
let is_infallible_cast = | ||
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale); | ||
|
||
// Build a single mode once and use a thin closure that calls into it | ||
enum RescaleMode<I, O> { | ||
SameScale, | ||
Up { mul: O }, | ||
Down { div: I, half: I, half_neg: I }, | ||
Invalid, | ||
} | ||
|
||
let mode = if delta_scale == 0 { | ||
RescaleMode::SameScale | ||
} else if delta_scale > 0 { | ||
match O::Native::from_decimal(10_i128).and_then(|t| t.pow_checked(delta_scale as u32).ok()) | ||
{ | ||
Some(mul) => RescaleMode::Up { mul }, | ||
None => RescaleMode::Invalid, | ||
} | ||
} else { | ||
match I::Native::from_decimal(10_i128) | ||
.and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as u32).ok()) | ||
{ | ||
Some(div) => { | ||
let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); | ||
let half_neg = half.neg_wrapping(); | ||
RescaleMode::Down { | ||
div, | ||
half, | ||
half_neg, | ||
} | ||
} | ||
None => RescaleMode::Invalid, | ||
} | ||
}; | ||
|
||
let f = move |x: I::Native| { | ||
match &mode { | ||
RescaleMode::SameScale => O::Native::from_decimal(x), | ||
RescaleMode::Up { mul } => { | ||
O::Native::from_decimal(x).and_then(|x| x.mul_checked(*mul).ok()) | ||
} | ||
RescaleMode::Down { | ||
div, | ||
half, | ||
half_neg, | ||
} => { | ||
// div is >= 10 and so this cannot overflow | ||
let d = x.div_wrapping(*div); | ||
let r = x.mod_wrapping(*div); | ||
|
||
// Round result | ||
let adjusted = match x >= I::Native::ZERO { | ||
true if r >= *half => d.add_wrapping(I::Native::ONE), | ||
false if r <= *half_neg => d.sub_wrapping(I::Native::ONE), | ||
_ => d, | ||
}; | ||
O::Native::from_decimal(adjusted) | ||
} | ||
RescaleMode::Invalid => None, | ||
} | ||
}; | ||
|
||
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); | ||
|
||
move |x| { | ||
if is_infallible_cast { | ||
f(x).ok_or_else(|| error(x)) | ||
} else { | ||
f(x).ok_or_else(|| error(x)).and_then(|v| { | ||
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) | ||
}) | ||
} | ||
} | ||
} | ||
|
||
/// Returns true if casting from (input_precision, input_scale) to | ||
/// (output_precision, output_scale) is infallible based on precision/scale math. | ||
fn is_infallible_decimal_cast( | ||
input_precision: u8, | ||
input_scale: i8, | ||
output_precision: u8, | ||
output_scale: i8, | ||
) -> bool { | ||
let delta_scale = output_scale - input_scale; | ||
let input_precision_i8 = input_precision as i8; | ||
let output_precision_i8 = output_precision as i8; | ||
liamzwbao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
if delta_scale >= 0 { | ||
// if the gain in precision (digits) is greater than the multiplication due to scaling | ||
// every number will fit into the output type | ||
// Example: If we are starting with any number of precision 5 [xxxxx], | ||
// then an increase of scale by 3 will have the following effect on the representation: | ||
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type | ||
// needs to provide at least 8 digits precision | ||
input_precision_i8 + delta_scale <= output_precision_i8 | ||
} else { | ||
// if the reduction of the input number through scaling (dividing) is greater | ||
// than a possible precision loss (plus potential increase via rounding) | ||
// every input number will fit into the output type | ||
// Example: If we are starting with any number of precision 5 [xxxxx], | ||
// then and decrease the scale by 3 will have the following effect on the representation: | ||
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding). | ||
// The rounding may add an additional digit, so the cast to be infallible, | ||
liamzwbao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
// the output type needs to have at least 3 digits of precision. | ||
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: | ||
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible | ||
input_precision_i8 + delta_scale < output_precision_i8 | ||
} | ||
} | ||
|
||
pub(crate) fn cast_decimal_to_decimal_error<I, O>( | ||
output_precision: u8, | ||
output_scale: i8, | ||
|
@@ -174,55 +307,20 @@ where | |
I::Native: DecimalCast + ArrowNativeTypeOp, | ||
O::Native: DecimalCast + ArrowNativeTypeOp, | ||
{ | ||
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); | ||
let delta_scale = input_scale - output_scale; | ||
// if the reduction of the input number through scaling (dividing) is greater | ||
// than a possible precision loss (plus potential increase via rounding) | ||
// every input number will fit into the output type | ||
// Example: If we are starting with any number of precision 5 [xxxxx], | ||
// then and decrease the scale by 3 will have the following effect on the representation: | ||
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding). | ||
// The rounding may add an additional digit, so the cast to be infallible, | ||
// the output type needs to have at least 3 digits of precision. | ||
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: | ||
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible | ||
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); | ||
|
||
let div = I::Native::from_decimal(10_i128) | ||
.unwrap() | ||
.pow_checked(delta_scale as u32)?; | ||
|
||
let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); | ||
let half_neg = half.neg_wrapping(); | ||
|
||
let f = |x: I::Native| { | ||
// div is >= 10 and so this cannot overflow | ||
let d = x.div_wrapping(div); | ||
let r = x.mod_wrapping(div); | ||
// make sure we don't perform calculations that don't make sense w/o validation | ||
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; | ||
|
||
let is_infallible_cast = | ||
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale); | ||
|
||
// Round result | ||
let adjusted = match x >= I::Native::ZERO { | ||
true if r >= half => d.add_wrapping(I::Native::ONE), | ||
false if r <= half_neg => d.sub_wrapping(I::Native::ONE), | ||
_ => d, | ||
}; | ||
O::Native::from_decimal(adjusted) | ||
}; | ||
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale); | ||
|
||
Ok(if is_infallible_cast { | ||
// make sure we don't perform calculations that don't make sense w/o validation | ||
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; | ||
let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed | ||
// to fit into the target type | ||
array.unary(g) | ||
// unwrapping is safe since the result is guaranteed to fit into the target type | ||
array.unary(|x| f(x).unwrap()) | ||
} else if cast_options.safe { | ||
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) | ||
array.unary_opt(|x| f(x).ok()) | ||
} else { | ||
array.try_unary(|x| { | ||
f(x).ok_or_else(|| error(x)).and_then(|v| { | ||
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) | ||
}) | ||
})? | ||
array.try_unary(f)? | ||
}) | ||
} | ||
|
||
|
@@ -240,35 +338,20 @@ where | |
I::Native: DecimalCast + ArrowNativeTypeOp, | ||
O::Native: DecimalCast + ArrowNativeTypeOp, | ||
{ | ||
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); | ||
let delta_scale = output_scale - input_scale; | ||
let mul = O::Native::from_decimal(10_i128) | ||
.unwrap() | ||
.pow_checked(delta_scale as u32)?; | ||
|
||
// if the gain in precision (digits) is greater than the multiplication due to scaling | ||
// every number will fit into the output type | ||
// Example: If we are starting with any number of precision 5 [xxxxx], | ||
// then an increase of scale by 3 will have the following effect on the representation: | ||
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type | ||
// needs to provide at least 8 digits precision | ||
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); | ||
let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); | ||
// make sure we don't perform calculations that don't make sense w/o validation | ||
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; | ||
|
||
let is_infallible_cast = | ||
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale); | ||
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale); | ||
|
||
Ok(if is_infallible_cast { | ||
// make sure we don't perform calculations that don't make sense w/o validation | ||
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; | ||
// unwrapping is safe since the result is guaranteed to fit into the target type | ||
let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul); | ||
array.unary(f) | ||
array.unary(|x| f(x).unwrap()) | ||
} else if cast_options.safe { | ||
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) | ||
array.unary_opt(|x| f(x).ok()) | ||
} else { | ||
array.try_unary(|x| { | ||
f(x).ok_or_else(|| error(x)).and_then(|v| { | ||
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) | ||
}) | ||
})? | ||
array.try_unary(f)? | ||
}) | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This refactor seems a bit "backward" to me, which probably causes the benchmark regressions:
convert_to_smaller_scale_decimal
andconvert_to_bigger_or_equal_scale_decimal
) from two locations (cast_decimal_to_decimal
andcast_decimal_to_decimal_same_type
). This avoided some branching in the inner cast loop, because the branch on direction of scale change is taken outside the loop.rescale_decimal
method, which not only requires the introduction of a newis_infallible_cast
helper method, but also leaves the twoconvert_to_xxx_scale_decimal
methods with virtually identical bodies. At that point we may as well eliminate those helpers entirely and avoid the code bloat... but the helpers probably existed for a reason (to hoist at least some branches out of the inner loop).I wonder if we should instead do:
convert_to_smaller_scale_decimal
andconvert_to_bigger_or_equal_scale_decimal
array
orcast_options
as inputOk((f, is_infallible_cast)
which corresponds to the return typeResult<(impl Fn(I::Native) -> Option<O::Native>, bool), ArrowError>
apply_decimal_cast
function helperarray
,cast_options
and the(impl Fn, bool)
pair produced by aconvert_to_xxx_scale_decimal
helperf
to an arraycast_decimal_to_decimal
andcast_decimal_to_decimal_same_type
to call those functions (see below)rescale_decimal
would be the single-row equivalent ofcast_decimal_to_decimal
, returningOption<O::Native>
validate_decimal_precision_and_scale
and fails on error, so we don't need to validate on a per-row basis.cast_decimal_to_decimal
rescale_decimal