Skip to content

Commit 4d4add5

Browse files
committed
add cast
1 parent a5874a8 commit 4d4add5

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

src/query/expression/src/evaluator.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ use crate::types::StringType;
5050
use crate::types::ValueType;
5151
use crate::types::VariantType;
5252
use crate::values::Column;
53+
use crate::types::VectorColumn;
54+
use crate::types::VectorDataType;
55+
use crate::types::VectorScalar;
56+
use crate::types::F32;
5357
use crate::values::ColumnBuilder;
5458
use crate::values::Scalar;
5559
use crate::values::Value;
@@ -924,6 +928,85 @@ impl<'a> Evaluator<'a> {
924928
other => unreachable!("source: {}", other),
925929
}
926930
}
931+
(DataType::Array(inner_src_ty), DataType::Vector(inner_dest_ty)) => {
932+
if !matches!(&**inner_src_ty, DataType::Number(_) | DataType::Decimal(_))
933+
|| matches!(inner_dest_ty, VectorDataType::Int8(_))
934+
{
935+
return Err(ErrorCode::BadArguments(format!(
936+
"unable to cast type `{src_type}` to type `{dest_type}`"
937+
))
938+
.set_span(span));
939+
}
940+
let dimension = inner_dest_ty.dimension() as usize;
941+
match value {
942+
Value::Scalar(Scalar::Array(col)) => {
943+
if col.len() != dimension {
944+
return Err(ErrorCode::BadArguments(
945+
"Array value cast to a vector has incorrect dimension".to_string(),
946+
)
947+
.set_span(span));
948+
}
949+
let mut vals = Vec::with_capacity(dimension);
950+
match col {
951+
Column::Number(num_col) => {
952+
for i in 0..dimension {
953+
let num = unsafe { num_col.index_unchecked(i) };
954+
vals.push(num.to_f32());
955+
}
956+
}
957+
Column::Decimal(dec_col) => {
958+
for i in 0..dimension {
959+
let dec = unsafe { dec_col.index_unchecked(i) };
960+
vals.push(F32::from(dec.to_float32()));
961+
}
962+
}
963+
_ => {
964+
return Err(ErrorCode::BadArguments(
965+
"Array value cast to a vector has invalid value".to_string(),
966+
)
967+
.set_span(span));
968+
}
969+
}
970+
Ok(Value::Scalar(Scalar::Vector(VectorScalar::Float32(vals))))
971+
}
972+
Value::Column(Column::Array(array_col)) => {
973+
let mut vals = Vec::with_capacity(dimension * array_col.len());
974+
for col in array_col.iter() {
975+
if col.len() != dimension {
976+
return Err(ErrorCode::BadArguments(
977+
"Array value cast to a vector has incorrect dimension"
978+
.to_string(),
979+
)
980+
.set_span(span));
981+
}
982+
match col {
983+
Column::Number(num_col) => {
984+
for i in 0..dimension {
985+
let num = unsafe { num_col.index_unchecked(i) };
986+
vals.push(num.to_f32());
987+
}
988+
}
989+
Column::Decimal(dec_col) => {
990+
for i in 0..dimension {
991+
let dec = unsafe { dec_col.index_unchecked(i) };
992+
vals.push(F32::from(dec.to_float32()));
993+
}
994+
}
995+
_ => {
996+
return Err(ErrorCode::BadArguments(
997+
"Array value cast to a vector has invalid value"
998+
.to_string(),
999+
)
1000+
.set_span(span));
1001+
}
1002+
}
1003+
}
1004+
let vector_col = VectorColumn::Float32((vals.into(), dimension));
1005+
Ok(Value::Column(Column::Vector(vector_col)))
1006+
}
1007+
other => unreachable!("source: {}", other),
1008+
}
1009+
}
9271010

9281011
_ => Err(ErrorCode::BadArguments(format!(
9291012
"unable to cast type `{src_type}` to type `{dest_type}`"

src/query/expression/src/type_check.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,11 @@ fn can_cast_to(src_ty: &DataType, dest_ty: &DataType) -> bool {
604604
{
605605
true
606606
}
607-
607+
(DataType::Array(fields_src_ty), DataType::Vector(_))
608+
if matches!(&**fields_src_ty, DataType::Number(_) | DataType::Decimal(_)) =>
609+
{
610+
true
611+
}
608612
(DataType::Nullable(box inner_src_ty), DataType::Nullable(box inner_dest_ty))
609613
| (DataType::Nullable(box inner_src_ty), inner_dest_ty)
610614
| (inner_src_ty, DataType::Nullable(box inner_dest_ty))

src/query/expression/src/types/decimal.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ pub enum DecimalScalar {
176176
}
177177

178178
impl DecimalScalar {
179+
pub fn to_float32(&self) -> f32 {
180+
with_decimal_type!(|DECIMAL| match self {
181+
DecimalScalar::DECIMAL(v, size) => v.to_float32(size.scale),
182+
})
183+
}
184+
179185
pub fn to_float64(&self) -> f64 {
180186
with_decimal_type!(|DECIMAL| match self {
181187
DecimalScalar::DECIMAL(v, size) => v.to_float64(size.scale),

src/query/expression/src/types/number.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,14 @@ impl NumberScalar {
489489
}
490490
}
491491

492+
pub fn to_f32(&self) -> F32 {
493+
crate::with_integer_mapped_type!(|NUM_TYPE| match self {
494+
NumberScalar::NUM_TYPE(num) => (*num as f32).into(),
495+
NumberScalar::Float32(num) => *num,
496+
NumberScalar::Float64(num) => (num.into_inner() as f32).into(),
497+
})
498+
}
499+
492500
pub fn to_f64(&self) -> F64 {
493501
crate::with_integer_mapped_type!(|NUM_TYPE| match self {
494502
NumberScalar::NUM_TYPE(num) => (*num as f64).into(),

0 commit comments

Comments
 (0)