diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 1fcae9f66aca..907e61b09f7b 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -19,17 +19,23 @@ use crate::cast::*; /// A utility trait that provides checked conversions between /// decimal types inspired by [`NumCast`] -pub(crate) trait DecimalCast: Sized { +pub trait DecimalCast: Sized { + /// Convert the decimal to an `i32` fn to_i32(self) -> Option; + /// Convert the decimal to an `i64` fn to_i64(self) -> Option; + /// Convert the decimal to an `i128` fn to_i128(self) -> Option; + /// Convert the decimal to an `i256` fn to_i256(self) -> Option; + /// Convert another decimal value `n` into this decimal type fn from_decimal(n: T) -> Option; + /// Convert the `f64` value `n` into this decimal type fn from_f64(n: f64) -> Option; } diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index aa26d0c2f9d3..c825e2752d6c 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -67,6 +67,8 @@ use arrow_schema::*; use arrow_select::take::take; use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive}; +pub use decimal::DecimalCast; + /// CastOptions provides a way to override the default cast behaviors #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct CastOptions<'a> { diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 7851ccc735db..83ffc8f08dc3 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -17,8 +17,13 @@ //! Module for transforming a typed arrow `Array` to `VariantArray`.
-use arrow::datatypes::{self, ArrowPrimitiveType, ArrowTimestampType, Date32Type}; -use parquet_variant::Variant; +use arrow::array::ArrowNativeTypeOp; +use arrow::compute::DecimalCast; +use arrow::datatypes::{ + self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type, + DecimalType, +}; +use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16}; /// Options for controlling the behavior of `cast_to_variant_with_options`. #[derive(Debug, Clone, PartialEq, Eq)] @@ -82,7 +87,7 @@ impl_primitive_from_variant!(datatypes::Float64Type, as_f64); impl_primitive_from_variant!( datatypes::Date32Type, as_naive_date, - Date32Type::from_naive_date + datatypes::Date32Type::from_naive_date ); impl_timestamp_from_variant!( datatypes::TimestampMicrosecondType, @@ -109,6 +114,171 @@ impl_timestamp_from_variant!( |timestamp| Self::make_value(timestamp.naive_utc()) ); +/// Returns the unscaled integer representation for Arrow decimal type `O` +/// from a `Variant`. +/// +/// - `precision` and `scale` specify the target Arrow decimal parameters +/// - Integer variants (`Int8/16/32/64`) are treated as decimals with scale 0 +/// - Decimal variants (`Decimal4/8/16`) use their embedded scale and the maximum precision of their decimal width +/// +/// The value is rescaled to (`precision`, `scale`) using `rescale_decimal`; `None` is returned +/// if it cannot fit the requested precision or the variant is neither an integer nor a decimal.
+pub(crate) fn variant_to_unscaled_decimal( + variant: &Variant<'_, '_>, + precision: u8, + scale: i8, +) -> Option +where + O: DecimalType, + O::Native: DecimalCast, +{ + match variant { + Variant::Int8(i) => rescale_decimal::( + *i as i32, + VariantDecimal4::MAX_PRECISION, + 0, + precision, + scale, + ), + Variant::Int16(i) => rescale_decimal::( + *i as i32, + VariantDecimal4::MAX_PRECISION, + 0, + precision, + scale, + ), + Variant::Int32(i) => rescale_decimal::( + *i, + VariantDecimal4::MAX_PRECISION, + 0, + precision, + scale, + ), + Variant::Int64(i) => rescale_decimal::( + *i, + VariantDecimal8::MAX_PRECISION, + 0, + precision, + scale, + ), + Variant::Decimal4(d) => rescale_decimal::( + d.integer(), + VariantDecimal4::MAX_PRECISION, + d.scale() as i8, + precision, + scale, + ), + Variant::Decimal8(d) => rescale_decimal::( + d.integer(), + VariantDecimal8::MAX_PRECISION, + d.scale() as i8, + precision, + scale, + ), + Variant::Decimal16(d) => rescale_decimal::( + d.integer(), + VariantDecimal16::MAX_PRECISION, + d.scale() as i8, + precision, + scale, + ), + _ => None, + } +} + +/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale) +/// and return the scaled value if it fits the output precision. Similar to the implementation in +/// decimal.rs in arrow-cast. 
+pub(crate) fn rescale_decimal( + value: I::Native, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, +) -> Option +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast, + O::Native: DecimalCast, +{ + let delta_scale = i16::from(output_scale) - i16::from(input_scale); // widened to i16: the difference of two i8 scales can overflow i8 + + // Determine if the cast is infallible based on precision/scale math + let is_infallible_cast = + is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale); + + let scaled = if delta_scale == 0 { + O::Native::from_decimal(value) + } else if delta_scale > 0 { + let mul = O::Native::from_decimal(10_i128) + .and_then(|t| t.pow_checked(delta_scale as u32).ok())?; + O::Native::from_decimal(value).and_then(|x| x.mul_checked(mul).ok()) + } else { + // delta_scale is guaranteed to be < 0 here, and its magnitude may be larger than I::MAX_PRECISION. If so, the + // scale change divides out more digits than the input has precision and the result of the cast + // is always zero. For example, if we try to scale a decimal32 value down by 10 digits, the largest + // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values + // (e.g. 1/10000000000) or larger scale reductions (e.g. 999999999/10000000000000) produce even + // smaller results, which also round to zero. In that case, just return zero. 
+ let delta_scale = delta_scale.unsigned_abs() as usize; + let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale) else { + return Some(O::Native::ZERO); + }; + let div = max.add_wrapping(I::Native::ONE); + let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); + let half_neg = half.neg_wrapping(); + + // div is >= 10 and so this cannot overflow + let d = value.div_wrapping(div); + let r = value.mod_wrapping(div); + + // Round half away from zero + let adjusted = match value >= I::Native::ZERO { + true if r >= half => d.add_wrapping(I::Native::ONE), + false if r <= half_neg => d.sub_wrapping(I::Native::ONE), + _ => d, + }; + O::Native::from_decimal(adjusted) + }; + + scaled.filter(|v| is_infallible_cast || O::is_valid_decimal_precision(*v, output_precision)) +} + +/// Returns true if casting from (input_precision, input_scale) to +/// (output_precision, output_scale) is infallible based on precision/scale math. +fn is_infallible_decimal_cast( + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, +) -> bool { + let delta_scale = i16::from(output_scale) - i16::from(input_scale); // i16 math: i8 arithmetic can overflow for extreme scales + let input_precision = i16::from(input_precision); + let output_precision = i16::from(output_precision); + if delta_scale >= 0 { + // if the gain in precision (digits) is greater than the multiplication due to scaling + // every number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then an increase of scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type + // needs to provide at least 8 digits precision + input_precision + delta_scale <= output_precision + } else { + // if the reduction of the input number through scaling (dividing) is greater + // than a possible precision loss (plus potential increase via rounding) + // every input number will fit into the output type + // Example: If we are starting with any number of precision 5
[xxxxx], + // then a decrease of the scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). + // The rounding may add an additional digit, so for the cast to be infallible, + // the output type needs to have at least 3 digits of precision. + // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: + // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible + input_precision + delta_scale < output_precision + } +} + /// Convert the value at a specific index in the given array into a `Variant`. macro_rules! non_generic_conversion_single_value { ($array:expr, $cast_fn:expr, $index:expr) => {{ diff --git a/parquet-variant-compute/src/variant_get.rs b/parquet-variant-compute/src/variant_get.rs index 8ee489cfe583..fda41cc84a35 100644 --- a/parquet-variant-compute/src/variant_get.rs +++ b/parquet-variant-compute/src/variant_get.rs @@ -300,16 +300,21 @@ mod test { use crate::json_to_variant; use crate::variant_array::{ShreddedVariantFieldArray, StructArrayBuilder}; use arrow::array::{ - Array, ArrayRef, AsArray, BinaryViewArray, BooleanArray, Date32Array, Float32Array, - Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, StringArray, StructArray, + Array, ArrayRef, AsArray, BinaryViewArray, BooleanArray, Date32Array, Decimal32Array, + Decimal64Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int8Array, + Int16Array, Int32Array, Int64Array, StringArray, StructArray, }; use arrow::buffer::NullBuffer; use arrow::compute::CastOptions; use arrow::datatypes::DataType::{Int16, Int32, Int64}; + use arrow::datatypes::i256; use arrow_schema::DataType::{Boolean, Float32, Float64, Int8}; use arrow_schema::{DataType, Field, FieldRef, Fields, TimeUnit}; use chrono::DateTime; - use parquet_variant::{EMPTY_VARIANT_METADATA_BYTES, Variant, VariantPath}; + use parquet_variant::{ + EMPTY_VARIANT_METADATA_BYTES, Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16,
+ VariantDecimalType, VariantPath, + }; fn single_variant_get_test(input_json: &str, path: VariantPath, expected_json: &str) { // Create input array from JSON string @@ -2842,4 +2847,521 @@ mod test { Arc::new(struct_array) } + + #[test] + fn get_decimal32_rescaled_to_scale2() { + // Build unshredded variant values with different scales + let mut builder = crate::VariantArrayBuilder::new(5); + builder.append_variant(VariantDecimal4::try_new(1234, 2).unwrap().into()); // 12.34 + builder.append_variant(VariantDecimal4::try_new(1234, 3).unwrap().into()); // 1.234 + builder.append_variant(VariantDecimal4::try_new(1234, 0).unwrap().into()); // 1234 + builder.append_null(); + builder.append_variant( + VariantDecimal8::try_new((VariantDecimal4::MAX_UNSCALED_VALUE as i64) + 1, 3) + .unwrap() + .into(), + ); // one past the Decimal4 range, but fits Decimal32 once rescaled to scale 2 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal32(9, 2), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 9); + assert_eq!(result.scale(), 2); + assert_eq!(result.value(0), 1234); + assert_eq!(result.value(1), 123); // 1.234 -> 1.23 + assert_eq!(result.value(2), 123400); // 1234 -> 1234.00 + assert!(result.is_null(3)); + assert_eq!( + result.value(4), + VariantDecimal4::MAX_UNSCALED_VALUE / 10 + 1 + ); // should not be null as the final result fits into Decimal32 + } + + #[test] + fn get_decimal32_scale_down_rounding() { + let mut builder = crate::VariantArrayBuilder::new(7); + builder.append_variant(VariantDecimal4::try_new(1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal4::try_new(1245, 0).unwrap().into()); + builder.append_variant(VariantDecimal4::try_new(-1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal4::try_new(-1245, 0).unwrap().into()); +
builder.append_variant(VariantDecimal4::try_new(1235, 2).unwrap().into()); // 12.35 rounded down to 10 for scale -1 + builder.append_variant(VariantDecimal4::try_new(1235, 3).unwrap().into()); // 1.235 rounded down to 0 for scale -1 + builder.append_variant(VariantDecimal4::try_new(5235, 3).unwrap().into()); // 5.235 rounded up to 10 for scale -1 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal32(9, -1), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 9); + assert_eq!(result.scale(), -1); + assert_eq!(result.value(0), 124); // 1235 -> 123.5 tens, the tie rounds away from zero + assert_eq!(result.value(1), 125); + assert_eq!(result.value(2), -124); // -1235 -> -123.5 tens, the tie rounds away from zero + assert_eq!(result.value(3), -125); + assert_eq!(result.value(4), 1); + assert!(result.is_valid(5)); + assert_eq!(result.value(5), 0); + assert_eq!(result.value(6), 1); + } + + #[test] + fn get_decimal32_large_scale_reduction() { + let mut builder = crate::VariantArrayBuilder::new(2); + builder.append_variant( + VariantDecimal4::try_new(-VariantDecimal4::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + builder.append_variant( + VariantDecimal4::try_new(VariantDecimal4::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal32(9, -9), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 9); + assert_eq!(result.scale(), -9); + assert_eq!(result.value(0), -1); + assert_eq!(result.value(1), 1); + + let field = Field::new("result", DataType::Decimal32(9, -10), true); + let options =
GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 9); + assert_eq!(result.scale(), -10); + assert!(result.is_valid(0)); + assert_eq!(result.value(0), 0); + assert!(result.is_valid(1)); + assert_eq!(result.value(1), 0); + } + + #[test] + fn get_decimal32_precision_overflow_safe() { + // Exceed Decimal32 after scaling and rounding + let mut builder = crate::VariantArrayBuilder::new(2); + builder.append_variant( + VariantDecimal4::try_new(VariantDecimal4::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + builder.append_variant( + VariantDecimal4::try_new(VariantDecimal4::MAX_UNSCALED_VALUE, 9) + .unwrap() + .into(), + ); // rounds up to 1.00 at scale 2, which overflows precision 2 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal32(2, 2), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert!(result.is_null(0)); // scaling up by 10^2 cannot fit precision 2 + assert!(result.is_null(1)); // should overflow because 1.00 does not fit into precision (2) + } + + #[test] + fn get_decimal32_precision_overflow_unsafe_errors() { + let mut builder = crate::VariantArrayBuilder::new(1); + builder.append_variant( + VariantDecimal4::try_new(VariantDecimal4::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal32(9, 2), true); + let cast_options = CastOptions { + safe: false, // a failed cast must surface as an error, not a null + ..Default::default() + }; + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(field))) + .with_cast_options(cast_options); + let err = variant_get(&variant_array, options).unwrap_err(); + + assert!( + err.to_string().contains( + "Failed to cast to
Decimal32(precision=9, scale=2) from variant Decimal4" + ) + ); + } + + #[test] + fn get_decimal64_rescaled_to_scale2() { + let mut builder = crate::VariantArrayBuilder::new(5); + builder.append_variant(VariantDecimal8::try_new(1234, 2).unwrap().into()); // 12.34 + builder.append_variant(VariantDecimal8::try_new(1234, 3).unwrap().into()); // 1.234 + builder.append_variant(VariantDecimal8::try_new(1234, 0).unwrap().into()); // 1234 + builder.append_null(); + builder.append_variant( + VariantDecimal16::try_new((VariantDecimal8::MAX_UNSCALED_VALUE as i128) + 1, 3) + .unwrap() + .into(), + ); // one past the Decimal8 range, but fits Decimal64 once rescaled to scale 2 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal64(18, 2), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 18); + assert_eq!(result.scale(), 2); + assert_eq!(result.value(0), 1234); + assert_eq!(result.value(1), 123); // 1.234 -> 1.23 + assert_eq!(result.value(2), 123400); // 1234 -> 1234.00 + assert!(result.is_null(3)); + assert_eq!( + result.value(4), + VariantDecimal8::MAX_UNSCALED_VALUE / 10 + 1 + ); // should not be null as the final result fits into Decimal64 + } + + #[test] + fn get_decimal64_scale_down_rounding() { + let mut builder = crate::VariantArrayBuilder::new(7); + builder.append_variant(VariantDecimal8::try_new(1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal8::try_new(1245, 0).unwrap().into()); + builder.append_variant(VariantDecimal8::try_new(-1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal8::try_new(-1245, 0).unwrap().into()); + builder.append_variant(VariantDecimal8::try_new(1235, 2).unwrap().into()); // 12.35 rounded down to 10 for scale -1 + builder.append_variant(VariantDecimal8::try_new(1235, 3).unwrap().into()); // 1.235 rounded down to 0 for scale -1 +
builder.append_variant(VariantDecimal8::try_new(5235, 3).unwrap().into()); // 5.235 rounded up to 10 for scale -1 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal64(18, -1), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 18); + assert_eq!(result.scale(), -1); + assert_eq!(result.value(0), 124); // 1235 -> 123.5 tens, the tie rounds away from zero + assert_eq!(result.value(1), 125); + assert_eq!(result.value(2), -124); // -1235 -> -123.5 tens, the tie rounds away from zero + assert_eq!(result.value(3), -125); + assert_eq!(result.value(4), 1); + assert!(result.is_valid(5)); + assert_eq!(result.value(5), 0); + assert_eq!(result.value(6), 1); + } + + #[test] + fn get_decimal64_large_scale_reduction() { + let mut builder = crate::VariantArrayBuilder::new(2); + builder.append_variant( + VariantDecimal8::try_new(-VariantDecimal8::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + builder.append_variant( + VariantDecimal8::try_new(VariantDecimal8::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal64(18, -18), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 18); + assert_eq!(result.scale(), -18); + assert_eq!(result.value(0), -1); + assert_eq!(result.value(1), 1); + + let field = Field::new("result", DataType::Decimal64(18, -19), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 18); + assert_eq!(result.scale(), -19); +
assert!(result.is_valid(0)); + assert_eq!(result.value(0), 0); + assert!(result.is_valid(1)); + assert_eq!(result.value(1), 0); + } + + #[test] + fn get_decimal64_precision_overflow_safe() { + // Exceed Decimal64 after scaling and rounding + let mut builder = crate::VariantArrayBuilder::new(2); + builder.append_variant( + VariantDecimal8::try_new(VariantDecimal8::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + builder.append_variant( + VariantDecimal8::try_new(VariantDecimal8::MAX_UNSCALED_VALUE, 18) + .unwrap() + .into(), + ); // rounds up to 1.00 at scale 2, which overflows precision 2 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal64(2, 2), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert!(result.is_null(0)); // scaling up by 10^2 cannot fit precision 2 + assert!(result.is_null(1)); // should overflow because 1.00 does not fit into precision (2) + } + + #[test] + fn get_decimal64_precision_overflow_unsafe_errors() { + let mut builder = crate::VariantArrayBuilder::new(1); + builder.append_variant( + VariantDecimal8::try_new(VariantDecimal8::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal64(18, 2), true); + let cast_options = CastOptions { + safe: false, // a failed cast must surface as an error, not a null + ..Default::default() + }; + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(field))) + .with_cast_options(cast_options); + let err = variant_get(&variant_array, options).unwrap_err(); + + assert!( + err.to_string().contains( + "Failed to cast to Decimal64(precision=18, scale=2) from variant Decimal8" + ) + ); + } + + #[test] + fn get_decimal128_rescaled_to_scale2() { + let mut builder = crate::VariantArrayBuilder::new(4); + builder.append_variant(VariantDecimal16::try_new(1234, 2).unwrap().into()); // 12.34 + builder.append_variant(VariantDecimal16::try_new(1234,
3).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(1234, 0).unwrap().into()); + builder.append_null(); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal128(38, 2), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 38); + assert_eq!(result.scale(), 2); + assert_eq!(result.value(0), 1234); + assert_eq!(result.value(1), 123); + assert_eq!(result.value(2), 123400); + assert!(result.is_null(3)); + } + + #[test] + fn get_decimal128_scale_down_rounding() { + let mut builder = crate::VariantArrayBuilder::new(7); + builder.append_variant(VariantDecimal16::try_new(1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(1245, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(-1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(-1245, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(1235, 2).unwrap().into()); // 12.35 rounded down to 10 for scale -1 + builder.append_variant(VariantDecimal16::try_new(1235, 3).unwrap().into()); // 1.235 rounded down to 0 for scale -1 + builder.append_variant(VariantDecimal16::try_new(5235, 3).unwrap().into()); // 5.235 rounded up to 10 for scale -1 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal128(38, -1), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 38); + assert_eq!(result.scale(), -1); + assert_eq!(result.value(0), 124); // 1235 -> 123.5 tens, the tie rounds away from zero + assert_eq!(result.value(1), 125); + assert_eq!(result.value(2), -124); // -1235 -> -123.5 tens, the tie rounds away from zero +
assert_eq!(result.value(3), -125); + assert_eq!(result.value(4), 1); + assert!(result.is_valid(5)); + assert_eq!(result.value(5), 0); + assert_eq!(result.value(6), 1); + } + + #[test] + fn get_decimal128_precision_overflow_safe() { + // Exceed Decimal128 after scaling and rounding + let mut builder = crate::VariantArrayBuilder::new(2); + builder.append_variant( + VariantDecimal16::try_new(VariantDecimal16::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + builder.append_variant( + VariantDecimal16::try_new(VariantDecimal16::MAX_UNSCALED_VALUE, 38) + .unwrap() + .into(), + ); // rounds up to 1.00 at scale 2, which overflows precision 2 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal128(2, 2), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert!(result.is_null(0)); // scaling up by 10^2 cannot fit precision 2 + assert!(result.is_null(1)); // should overflow because 1.00 does not fit into precision (2) + } + + #[test] + fn get_decimal128_precision_overflow_unsafe_errors() { + let mut builder = crate::VariantArrayBuilder::new(1); + builder.append_variant( + VariantDecimal16::try_new(VariantDecimal16::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal128(38, 2), true); + let cast_options = CastOptions { + safe: false, // a failed cast must surface as an error, not a null + ..Default::default() + }; + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(field))) + .with_cast_options(cast_options); + let err = variant_get(&variant_array, options).unwrap_err(); + + assert!(err.to_string().contains( + "Failed to cast to Decimal128(precision=38, scale=2) from variant Decimal16" + )); + } + + #[test] + fn get_decimal256_rescaled_to_scale2() { + // Build unshredded variant values with different scales using Decimal16 source + let
mut builder = crate::VariantArrayBuilder::new(4); + builder.append_variant(VariantDecimal16::try_new(1234, 2).unwrap().into()); // 12.34 + builder.append_variant(VariantDecimal16::try_new(1234, 3).unwrap().into()); // 1.234 + builder.append_variant(VariantDecimal16::try_new(1234, 0).unwrap().into()); // 1234 + builder.append_null(); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal256(76, 2), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 76); + assert_eq!(result.scale(), 2); + assert_eq!(result.value(0), i256::from_i128(1234)); + assert_eq!(result.value(1), i256::from_i128(123)); // 1.234 -> 1.23 + assert_eq!(result.value(2), i256::from_i128(123400)); // 1234 -> 1234.00 + assert!(result.is_null(3)); + } + + #[test] + fn get_decimal256_scale_down_rounding() { + let mut builder = crate::VariantArrayBuilder::new(7); + builder.append_variant(VariantDecimal16::try_new(1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(1245, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(-1235, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(-1245, 0).unwrap().into()); + builder.append_variant(VariantDecimal16::try_new(1235, 2).unwrap().into()); // 12.35 rounded down to 10 for scale -1 + builder.append_variant(VariantDecimal16::try_new(1235, 3).unwrap().into()); // 1.235 rounded down to 0 for scale -1 + builder.append_variant(VariantDecimal16::try_new(5235, 3).unwrap().into()); // 5.235 rounded up to 10 for scale -1 + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal256(76, -1), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); +
let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.precision(), 76); + assert_eq!(result.scale(), -1); + assert_eq!(result.value(0), i256::from_i128(124)); + assert_eq!(result.value(1), i256::from_i128(125)); + assert_eq!(result.value(2), i256::from_i128(-124)); + assert_eq!(result.value(3), i256::from_i128(-125)); + assert_eq!(result.value(4), i256::from_i128(1)); + assert!(result.is_valid(5)); + assert_eq!(result.value(5), i256::from_i128(0)); + assert_eq!(result.value(6), i256::from_i128(1)); + } + + #[test] + fn get_decimal256_precision_overflow_safe() { + // Exceed Decimal128 max precision (38) after scaling + let mut builder = crate::VariantArrayBuilder::new(2); + builder.append_variant( + VariantDecimal16::try_new(VariantDecimal16::MAX_UNSCALED_VALUE, 1) + .unwrap() + .into(), + ); + builder.append_variant( + VariantDecimal16::try_new(VariantDecimal16::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal256(76, 39), true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&variant_array, options).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + // Input is Decimal16 with integer = 10^38-1 and scale = 1, target scale = 39 + // So expected integer is (10^38-1) * 10^(39-1) = (10^38-1) * 10^38 + let base = i256::from_i128(10); + let factor = base.checked_pow(38).unwrap(); + let expected = i256::from_i128(VariantDecimal16::MAX_UNSCALED_VALUE) + .checked_mul(factor) + .unwrap(); + assert_eq!(result.value(0), expected); + assert!(result.is_null(1)); + } + + #[test] + fn get_decimal256_precision_overflow_unsafe_errors() { + // Exceed Decimal128 max precision (38) after scaling + let mut builder = crate::VariantArrayBuilder::new(2); + builder.append_variant( + VariantDecimal16::try_new(VariantDecimal16::MAX_UNSCALED_VALUE, 1) + .unwrap() + .into(), 
+ ); + builder.append_variant( + VariantDecimal16::try_new(VariantDecimal16::MAX_UNSCALED_VALUE, 0) + .unwrap() + .into(), + ); + let variant_array: ArrayRef = ArrayRef::from(builder.build()); + + let field = Field::new("result", DataType::Decimal256(76, 39), true); + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(field))) + .with_cast_options(cast_options); + let err = variant_get(&variant_array, options).unwrap_err(); + + assert!(err.to_string().contains( + "Failed to cast to Decimal256(precision=76, scale=39) from variant Decimal16" + )); + } } diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs index d60a4eea05c0..f2ac4a722fff 100644 --- a/parquet-variant-compute/src/variant_to_arrow.rs +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -16,14 +16,16 @@ // under the License. use arrow::array::{ - ArrayRef, BinaryViewArray, NullBufferBuilder, PrimitiveBuilder, builder::BooleanBuilder, + ArrayRef, BinaryViewArray, BooleanBuilder, NullBufferBuilder, PrimitiveBuilder, }; -use arrow::compute::CastOptions; -use arrow::datatypes::{self, ArrowPrimitiveType, DataType}; +use arrow::compute::{CastOptions, DecimalCast}; +use arrow::datatypes::{self, ArrowPrimitiveType, DataType, DecimalType}; use arrow::error::{ArrowError, Result}; use parquet_variant::{Variant, VariantPath}; -use crate::type_conversion::{PrimitiveFromVariant, TimestampFromVariant}; +use crate::type_conversion::{ + PrimitiveFromVariant, TimestampFromVariant, variant_to_unscaled_decimal, +}; use crate::{VariantArray, VariantValueArrayBuilder}; use arrow_schema::TimeUnit; @@ -45,6 +47,10 @@ pub(crate) enum PrimitiveVariantToArrowRowBuilder<'a> { Float16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float16Type>), Float32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float32Type>), Float64(VariantToPrimitiveArrowRowBuilder<'a, 
datatypes::Float64Type>), + Decimal32(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal32Type>), + Decimal64(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal64Type>), + Decimal128(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal128Type>), + Decimal256(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal256Type>), TimestampMicro(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampMicrosecondType>), TimestampMicroNtz( VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampMicrosecondType>, @@ -82,6 +88,10 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { Float16(b) => b.append_null(), Float32(b) => b.append_null(), Float64(b) => b.append_null(), + Decimal32(b) => b.append_null(), + Decimal64(b) => b.append_null(), + Decimal128(b) => b.append_null(), + Decimal256(b) => b.append_null(), TimestampMicro(b) => b.append_null(), TimestampMicroNtz(b) => b.append_null(), TimestampNano(b) => b.append_null(), @@ -105,6 +115,10 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { Float16(b) => b.append_value(value), Float32(b) => b.append_value(value), Float64(b) => b.append_value(value), + Decimal32(b) => b.append_value(value), + Decimal64(b) => b.append_value(value), + Decimal128(b) => b.append_value(value), + Decimal256(b) => b.append_value(value), TimestampMicro(b) => b.append_value(value), TimestampMicroNtz(b) => b.append_value(value), TimestampNano(b) => b.append_value(value), @@ -128,6 +142,10 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { Float16(b) => b.finish(), Float32(b) => b.finish(), Float64(b) => b.finish(), + Decimal32(b) => b.finish(), + Decimal64(b) => b.finish(), + Decimal128(b) => b.finish(), + Decimal256(b) => b.finish(), TimestampMicro(b) => b.finish(), TimestampMicroNtz(b) => b.finish(), TimestampNano(b) => b.finish(), @@ -174,79 +192,94 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>( ) -> Result> { use PrimitiveVariantToArrowRowBuilder::*; - let builder = match data_type { - DataType::Boolean => 
Boolean(VariantToBooleanArrowRowBuilder::new(cast_options, capacity)), - DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Timestamp(TimeUnit::Microsecond, None) => TimestampMicroNtz( - VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), - ), - DataType::Timestamp(TimeUnit::Microsecond, tz) => TimestampMicro( - VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), - ), - DataType::Timestamp(TimeUnit::Nanosecond, None) => TimestampNanoNtz( - VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), - ), - DataType::Timestamp(TimeUnit::Nanosecond, tz) => TimestampNano( - VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), - ), - DataType::Date32 => Date(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - _ if data_type.is_primitive() => { - return Err(ArrowError::NotYetImplemented(format!( - "Primitive 
data_type {data_type:?} not yet implemented" - ))); - } - _ => { - return Err(ArrowError::InvalidArgumentError(format!( - "Not a primitive type: {data_type:?}" - ))); - } - }; + let builder = + match data_type { + DataType::Boolean => { + Boolean(VariantToBooleanArrowRowBuilder::new(cast_options, capacity)) + } + DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Decimal32(precision, scale) => Decimal32( + VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?, + ), + DataType::Decimal64(precision, scale) => Decimal64( + VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?, + ), + DataType::Decimal128(precision, scale) => Decimal128( + VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?, + ), + DataType::Decimal256(precision, scale) => Decimal256( + VariantToDecimalArrowRowBuilder::new(cast_options, capacity, 
*precision, *scale)?, + ), + DataType::Timestamp(TimeUnit::Microsecond, None) => TimestampMicroNtz( + VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => TimestampMicro( + VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), + ), + DataType::Timestamp(TimeUnit::Nanosecond, None) => TimestampNanoNtz( + VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz) => TimestampNano( + VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), + ), + DataType::Date32 => Date(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + _ if data_type.is_primitive() => { + return Err(ArrowError::NotYetImplemented(format!( + "Primitive data_type {data_type:?} not yet implemented" + ))); + } + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Not a primitive type: {data_type:?}" + ))); + } + }; Ok(builder) } @@ -431,6 +464,67 @@ define_variant_to_primitive_builder!( type_name: get_type_name::() ); +/// Builder for converting variant values to arrow Decimal values +pub(crate) struct VariantToDecimalArrowRowBuilder<'a, T> +where + T: DecimalType, + T::Native: DecimalCast, +{ + builder: PrimitiveBuilder, + cast_options: &'a CastOptions<'a>, + precision: u8, + scale: i8, +} + +impl<'a, T> VariantToDecimalArrowRowBuilder<'a, T> +where + T: DecimalType, + T::Native: DecimalCast, +{ + fn new( + cast_options: &'a CastOptions<'a>, + capacity: usize, + precision: u8, + scale: i8, + ) -> Result { + let builder = PrimitiveBuilder::::with_capacity(capacity) + .with_precision_and_scale(precision, scale)?; + Ok(Self { + builder, + cast_options, + precision, + scale, + }) + } + + fn append_null(&mut self) -> Result<()> { + self.builder.append_null(); + Ok(()) + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + if let Some(scaled) = variant_to_unscaled_decimal::(value, 
self.precision, self.scale) { + self.builder.append_value(scaled); + Ok(true) + } else if self.cast_options.safe { + self.builder.append_null(); + Ok(false) + } else { + Err(ArrowError::CastError(format!( + "Failed to cast to {}(precision={}, scale={}) from variant {:?}", + T::PREFIX, + self.precision, + self.scale, + value + ))) + } + } + + fn finish(mut self) -> Result { + Ok(Arc::new(self.builder.finish())) + } +} + /// Builder for creating VariantArray output (for path extraction without type conversion) pub(crate) struct VariantToBinaryVariantArrowRowBuilder { metadata: BinaryViewArray,