Merged

Commits (32):
f25b499  [Variant] Support variant to `Decimal32/64/128/256` (liamzwbao, Oct 3, 2025)
7a32191  Simplify logic (liamzwbao, Oct 4, 2025)
02d29de  Using macro to generalize (liamzwbao, Oct 4, 2025)
f498db5  Support i256 and Decimal256 (liamzwbao, Oct 4, 2025)
964e45a  Simplify decimal builders (liamzwbao, Oct 4, 2025)
43d579d  Merge branch 'main' into issue-8477-variant-to-arrow-decimal (liamzwbao, Oct 4, 2025)
6f39a2a  fmt (liamzwbao, Oct 4, 2025)
8f0f53c  Add comment (liamzwbao, Oct 4, 2025)
d88fd7f  assert precision and scale in tests (liamzwbao, Oct 6, 2025)
522b26a  Merge branch 'main' into issue-8477-variant-to-arrow-decimal (liamzwbao, Oct 6, 2025)
9b6d0e1  address comments (liamzwbao, Oct 6, 2025)
54237fe  add more overflow cases and valid cases that will overflow in current… (liamzwbao, Oct 8, 2025)
e0b18da  WIP (liamzwbao, Oct 8, 2025)
1f19580  Refactor common logic (liamzwbao, Oct 8, 2025)
c163a91  Refactor common logic (liamzwbao, Oct 8, 2025)
274a028  Refactor common logic (liamzwbao, Oct 8, 2025)
a7cdd33  Use rescale_decimal for variant decimal scaling (liamzwbao, Oct 8, 2025)
94d60c0  Fix clippy (liamzwbao, Oct 8, 2025)
338defe  Merge branch 'main' into issue-8477-variant-to-arrow-decimal (liamzwbao, Oct 9, 2025)
e1febf6  Address comments (liamzwbao, Oct 9, 2025)
51648fd  Move rescale_decimal into variant-compute (liamzwbao, Oct 10, 2025)
a48bbf4  Revert changes in arrow-cast (liamzwbao, Oct 10, 2025)
cb2576c  Fix doc (liamzwbao, Oct 10, 2025)
5ffab93  Merge branch 'main' into issue-8477-variant-to-arrow-decimal (liamzwbao, Oct 10, 2025)
ef62474  Return value instead of fn (liamzwbao, Oct 10, 2025)
21a83ed  Merge branch 'main' into issue-8477-variant-to-arrow-decimal (liamzwbao, Oct 10, 2025)
25e4aa9  Fix large scale reduction case (liamzwbao, Oct 10, 2025)
539d73f  Reuse DecimalCast (liamzwbao, Oct 10, 2025)
dfe9960  Merge branch 'main' into issue-8477-variant-to-arrow-decimal (liamzwbao, Oct 14, 2025)
cfc8580  Use trait VariantDecimalType (liamzwbao, Oct 15, 2025)
9ed0d7a  Add doc (liamzwbao, Oct 15, 2025)
0567cb6  Refactor tests to use `into` (liamzwbao, Oct 15, 2025)
221 changes: 152 additions & 69 deletions arrow-cast/src/cast/decimal.rs
@@ -19,17 +19,23 @@ use crate::cast::*;

/// A utility trait that provides checked conversions between
/// decimal types inspired by [`NumCast`]
pub(crate) trait DecimalCast: Sized {
pub trait DecimalCast: Sized {
/// Convert the decimal to an i32
fn to_i32(self) -> Option<i32>;

/// Convert the decimal to an i64
fn to_i64(self) -> Option<i64>;

/// Convert the decimal to an i128
fn to_i128(self) -> Option<i128>;

/// Convert the decimal to an i256
fn to_i256(self) -> Option<i256>;

/// Convert a decimal from a decimal
fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;

/// Convert a decimal from a f64
fn from_f64(n: f64) -> Option<Self>;
}

@@ -139,6 +145,133 @@ impl DecimalCast for i256 {
}
}

/// Build a rescale function from (input_precision, input_scale) to (output_precision, output_scale)
/// returning a closure `Fn(I::Native) -> Result<O::Native, ArrowError>` that performs the conversion.
pub fn rescale_decimal<I, O>(
Review comment (Contributor):
This refactor seems a bit "backward" to me, which probably causes the benchmark regressions:

  • Original code was dispatching to two methods (convert_to_smaller_scale_decimal and convert_to_bigger_or_equal_scale_decimal) from two locations (cast_decimal_to_decimal and cast_decimal_to_decimal_same_type). This avoided some branching in the inner cast loop, because the branch on direction of scale change is taken outside the loop.
  • New code pushes everything down into this new rescale_decimal method, which not only requires the introduction of a new is_infallible_cast helper method, but also leaves the two convert_to_xxx_scale_decimal methods with virtually identical bodies. At that point we may as well eliminate those helpers entirely and avoid the code bloat... but the helpers probably existed for a reason (to hoist at least some branches out of the inner loop).
  • The new code also allocates errors that get downgraded to empty options, where the original code upgraded empty options to errors. Arrow errors allocate strings, so that's a meaningful difference.

I wonder if we should instead do:

• rework convert_to_smaller_scale_decimal and convert_to_bigger_or_equal_scale_decimal
  • no longer take array or cast_options as input
  • return Ok((f, is_infallible_cast)), which corresponds to the return type Result<(impl Fn(I::Native) -> Option<O::Native>, bool), ArrowError>
• define a new generic apply_decimal_cast function helper (a rough sketch follows the snippets below)
  • it takes as input array, cast_options and the (impl Fn, bool) pair produced by a convert_to_xxx_scale_decimal helper
  • it handles the three ways of applying f to an array
• rework cast_decimal_to_decimal and cast_decimal_to_decimal_same_type to call those functions (see below)
• rescale_decimal would be the single-row equivalent of cast_decimal_to_decimal, returning Option<O::Native>
• The decimal builder's constructor calls validate_decimal_precision_and_scale and fails on error, so we don't need to validate on a per-row basis.
cast_decimal_to_decimal:
let array: PrimitiveArray<O> = if input_scale > output_scale {
    let (f, is_infallible_cast) = convert_to_smaller_scale_decimal(...)?;
    apply_decimal_cast(array, cast_options, f, is_infallible_cast)?
} else {
    let (f, is_infallible_cast) = convert_to_bigger_or_equal_scale_decimal(...)?;
    apply_decimal_cast(array, cast_options, f, is_infallible_cast)?
};
rescale_decimal:
if input_scale > output_scale {
    let (f, _) = convert_to_smaller_scale_decimal(...)?;
    f(integer)
} else {
    let (f, _) = convert_to_bigger_or_equal_scale_decimal(...)?;
    f(integer)
}
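For concreteness, a rough sketch of what the suggested apply_decimal_cast helper could look like; the name, signature, and error handling are extrapolated from this comment and from the existing per-array branches below, not code in this PR:

fn apply_decimal_cast<I, O>(
    array: &PrimitiveArray<I>,
    cast_options: &CastOptions,
    output_precision: u8,
    output_scale: i8,
    f: impl Fn(I::Native) -> Option<O::Native>,
    is_infallible_cast: bool,
) -> Result<PrimitiveArray<O>, ArrowError>
where
    I: DecimalType,
    O: DecimalType,
    I::Native: DecimalCast + ArrowNativeTypeOp,
    O::Native: DecimalCast + ArrowNativeTypeOp,
{
    // Hypothetical helper sketched from the suggestion above; not part of this PR.
    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
    Ok(if is_infallible_cast {
        // the result is guaranteed to fit into the target type, so unwrap is safe
        array.unary(|x| f(x).unwrap())
    } else if cast_options.safe {
        // overflowing values become nulls instead of errors
        array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
    } else {
        // overflowing values produce an error describing the offending input
        array.try_unary(|x| {
            f(x).ok_or_else(|| error(x)).and_then(|v| {
                O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
            })
        })?
    })
}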

input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
) -> impl Fn(I::Native) -> Result<O::Native, ArrowError>
where
I: DecimalType,
O: DecimalType,
I::Native: DecimalCast,
O::Native: DecimalCast,
{
let delta_scale = output_scale - input_scale;

// Determine if the cast is infallible based on precision/scale math
let is_infallible_cast =
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);

// Build a single mode once and use a thin closure that calls into it
enum RescaleMode<I, O> {
SameScale,
Up { mul: O },
Down { div: I, half: I, half_neg: I },
Invalid,
}

let mode = if delta_scale == 0 {
RescaleMode::SameScale
} else if delta_scale > 0 {
match O::Native::from_decimal(10_i128).and_then(|t| t.pow_checked(delta_scale as u32).ok())
{
Some(mul) => RescaleMode::Up { mul },
None => RescaleMode::Invalid,
}
} else {
match I::Native::from_decimal(10_i128)
.and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as u32).ok())
{
Some(div) => {
let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
let half_neg = half.neg_wrapping();
RescaleMode::Down {
div,
half,
half_neg,
}
}
None => RescaleMode::Invalid,
}
};

let f = move |x: I::Native| {
match &mode {
RescaleMode::SameScale => O::Native::from_decimal(x),
RescaleMode::Up { mul } => {
O::Native::from_decimal(x).and_then(|x| x.mul_checked(*mul).ok())
}
RescaleMode::Down {
div,
half,
half_neg,
} => {
// div is >= 10 and so this cannot overflow
let d = x.div_wrapping(*div);
let r = x.mod_wrapping(*div);

// Round result
let adjusted = match x >= I::Native::ZERO {
true if r >= *half => d.add_wrapping(I::Native::ONE),
false if r <= *half_neg => d.sub_wrapping(I::Native::ONE),
_ => d,
};
O::Native::from_decimal(adjusted)
}
RescaleMode::Invalid => None,
}
};

let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);

move |x| {
if is_infallible_cast {
f(x).ok_or_else(|| error(x))
} else {
f(x).ok_or_else(|| error(x)).and_then(|v| {
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
})
}
}
}
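A minimal usage sketch of the new public function (an assumed call pattern for illustration, not a test from this PR): the closure is built once per precision/scale pair and then applied to individual unscaled values.

use arrow::compute::rescale_decimal;
use arrow::datatypes::{Decimal64Type, Decimal128Type};

// Rescale 1.23 stored as Decimal128(5, 2) into Decimal64(10, 4): 123 -> 12300.
let rescale = rescale_decimal::<Decimal128Type, Decimal64Type>(5, 2, 10, 4);
assert_eq!(rescale(123_i128).unwrap(), 12_300_i64);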

/// Returns true if casting from (input_precision, input_scale) to
/// (output_precision, output_scale) is infallible based on precision/scale math.
fn is_infallible_decimal_cast(
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
) -> bool {
let delta_scale = output_scale - input_scale;
let input_precision_i8 = input_precision as i8;
let output_precision_i8 = output_precision as i8;
if delta_scale >= 0 {
// if the gain in precision (digits) is greater than the multiplication due to scaling
// every number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then an increase of scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
// needs to provide at least 8 digits precision
input_precision_i8 + delta_scale <= output_precision_i8
} else {
// if the reduction of the input number through scaling (dividing) is greater
// than a possible precision loss (plus potential increase via rounding)
// every input number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then a decrease of the scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
// The rounding may add an additional digit, so for the cast to be infallible,
// the output type needs to have at least 3 digits of precision.
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
input_precision_i8 + delta_scale < output_precision_i8
}
}
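A few concrete instances of the precision/scale reasoning above (illustrative assertions, not tests from this PR):

// Decimal(5, 0) -> Decimal(8, 3): 5 + 3 <= 8, so every rescaled value fits.
assert!(is_infallible_decimal_cast(5, 0, 8, 3));
// Decimal(5, 3) -> Decimal(3, 0): 5 - 3 = 2 < 3, rounding adds at most one digit and still fits.
assert!(is_infallible_decimal_cast(5, 3, 3, 0));
// Decimal(5, 3) -> Decimal(2, 0): 99.999 rounds to 100, which needs 3 digits, so the cast can fail.
assert!(!is_infallible_decimal_cast(5, 3, 2, 0));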

pub(crate) fn cast_decimal_to_decimal_error<I, O>(
output_precision: u8,
output_scale: i8,
@@ -174,55 +307,20 @@ where
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
let delta_scale = input_scale - output_scale;
// if the reduction of the input number through scaling (dividing) is greater
// than a possible precision loss (plus potential increase via rounding)
// every input number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then and decrease the scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
// The rounding may add an additional digit, so the cast to be infallible,
// the output type needs to have at least 3 digits of precision.
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);

let div = I::Native::from_decimal(10_i128)
.unwrap()
.pow_checked(delta_scale as u32)?;

let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
let half_neg = half.neg_wrapping();

let f = |x: I::Native| {
// div is >= 10 and so this cannot overflow
let d = x.div_wrapping(div);
let r = x.mod_wrapping(div);
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
Review comment (Contributor):
Why did this move? It used to get called only for infallible casts, now it gets called for all casts?

Reply (liamzwbao, Contributor Author, Oct 9, 2025):
I think this could be an improvement since it validates the output precision and scale before performing operations on the array. However, these checks don’t affect correctness, because the same validation is performed again when creating the output decimal array here.

The main benefit is that it can fail early on invalid operations and avoid unnecessary work on the array, but it does add some overhead for valid operations since the conditions are checked twice.

So to be consistent, I think we should either add or remove this check across all branches.

let is_infallible_cast =
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);

// Round result
let adjusted = match x >= I::Native::ZERO {
true if r >= half => d.add_wrapping(I::Native::ONE),
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
_ => d,
};
O::Native::from_decimal(adjusted)
};
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale);

Ok(if is_infallible_cast {
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed
// to fit into the target type
array.unary(g)
// unwrapping is safe since the result is guaranteed to fit into the target type
array.unary(|x| f(x).unwrap())
} else if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
array.unary_opt(|x| f(x).ok())
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x)).and_then(|v| {
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
})
})?
array.try_unary(f)?
})
}

@@ -240,35 +338,20 @@
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
let delta_scale = output_scale - input_scale;
let mul = O::Native::from_decimal(10_i128)
.unwrap()
.pow_checked(delta_scale as u32)?;

// if the gain in precision (digits) is greater than the multiplication due to scaling
// every number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then an increase of scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
// needs to provide at least 8 digits precision
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;

let is_infallible_cast =
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale);

Ok(if is_infallible_cast {
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
// unwrapping is safe since the result is guaranteed to fit into the target type
let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
array.unary(f)
array.unary(|x| f(x).unwrap())
} else if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
array.unary_opt(|x| f(x).ok())
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x)).and_then(|v| {
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
})
})?
array.try_unary(f)?
})
}

2 changes: 2 additions & 0 deletions arrow-cast/src/cast/mod.rs
@@ -67,6 +67,8 @@ use arrow_schema::*;
use arrow_select::take::take;
use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive};

pub use decimal::{DecimalCast, rescale_decimal};

/// CastOptions provides a way to override the default cast behaviors
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CastOptions<'a> {
72 changes: 68 additions & 4 deletions parquet-variant-compute/src/type_conversion.rs
@@ -17,8 +17,14 @@

//! Module for transforming a typed arrow `Array` to `VariantArray`.

use arrow::datatypes::{self, ArrowPrimitiveType};
use parquet_variant::Variant;
use arrow::{
compute::{DecimalCast, rescale_decimal},
datatypes::{
self, ArrowPrimitiveType, Decimal32Type, Decimal64Type, Decimal128Type, DecimalType,
},
};
use arrow_schema::ArrowError;
use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16};

/// Options for controlling the behavior of `cast_to_variant_with_options`.
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -61,6 +67,63 @@ impl_primitive_from_variant!(datatypes::Float16Type, as_f16);
impl_primitive_from_variant!(datatypes::Float32Type, as_f32);
impl_primitive_from_variant!(datatypes::Float64Type, as_f64);

pub(crate) fn variant_to_unscaled_decimal<O>(
variant: &Variant<'_, '_>,
precision: u8,
scale: i8,
) -> Option<O::Native>
where
O: DecimalType,
O::Native: DecimalCast,
{
let maybe_rescaled = match variant {
Variant::Int8(i) => {
rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)(
*i as i32,
)
}
Variant::Int16(i) => {
rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)(
*i as i32,
)
}
Variant::Int32(i) => {
rescale_decimal::<Decimal32Type, O>(VariantDecimal4::MAX_PRECISION, 0, precision, scale)(
*i,
)
}
Variant::Int64(i) => {
rescale_decimal::<Decimal64Type, O>(VariantDecimal8::MAX_PRECISION, 0, precision, scale)(
*i,
)
}
Variant::Decimal4(d) => rescale_decimal::<Decimal32Type, O>(
VariantDecimal4::MAX_PRECISION,
d.scale() as i8,
precision,
scale,
)(d.integer()),
Variant::Decimal8(d) => rescale_decimal::<Decimal64Type, O>(
VariantDecimal8::MAX_PRECISION,
d.scale() as i8,
precision,
scale,
)(d.integer()),
Variant::Decimal16(d) => rescale_decimal::<Decimal128Type, O>(
VariantDecimal16::MAX_PRECISION,
d.scale() as i8,
precision,
scale,
)(d.integer()),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Invalid variant type: {:?}",
variant
))),
};

maybe_rescaled.ok()
}
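As an illustration of the intended behavior (a sketch, not a test from this PR, and assuming VariantDecimal4::try_new(integer, scale) as the constructor): a Variant::Decimal4 holding 1.23 converted to a Decimal128 with precision 10 and scale 4 yields the unscaled value 12300.

use arrow::datatypes::Decimal128Type;
use parquet_variant::{Variant, VariantDecimal4};

// Assumed constructor: VariantDecimal4::try_new(unscaled integer, scale).
let variant = Variant::Decimal4(VariantDecimal4::try_new(123, 2).unwrap()); // 1.23
let unscaled = variant_to_unscaled_decimal::<Decimal128Type>(&variant, 10, 4);
assert_eq!(unscaled, Some(12_300_i128)); // 1.2300 represented at scale 4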

/// Convert the value at a specific index in the given array into a `Variant`.
macro_rules! non_generic_conversion_single_value {
($array:expr, $cast_fn:expr, $index:expr) => {{
@@ -109,8 +172,9 @@ macro_rules! decimal_to_variant_decimal {
let (v, scale) = if *$scale < 0 {
// For negative scale, we need to multiply the value by 10^|scale|
// For example: 123 with scale -2 becomes 12300 with scale 0
let multiplier = <$value_type>::pow(10, (-*$scale) as u32);
(<$value_type>::checked_mul($v, multiplier), 0u8)
let v =
<$value_type>::checked_pow(10, (-*$scale) as u32).and_then(|m| m.checked_mul($v));
(v, 0u8)
} else {
(Some($v), *$scale as u8)
};
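A quick illustration of why the switch to checked_pow matters (assumed numeric behavior, not a test in this PR): with a Decimal32 input and a large negative scale, 10_i32.pow(12) would overflow, while checked_pow reports the overflow instead of wrapping or panicking.

// 10^12 does not fit in an i32, so the old pow-based multiplier would overflow here.
assert_eq!(10_i32.checked_pow(12), None);
// For representable scales the behavior is unchanged: 123 with scale -2 becomes 12300.
assert_eq!(10_i32.checked_pow(2).and_then(|m| m.checked_mul(123)), Some(12_300));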