Skip to content

Commit 61950a8

Browse files
authored
feat: Support vector data type (#630)
* feat: Support vector data type * fix
1 parent 5af93e0 commit 61950a8

File tree

4 files changed

+134
-2
lines changed

4 files changed

+134
-2
lines changed

bindings/nodejs/src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,21 @@ impl ToNapiValue for Value<'_> {
320320
databend_driver::Value::Geometry(s) => String::to_napi_value(env, s.to_string()),
321321
databend_driver::Value::Interval(s) => String::to_napi_value(env, s.to_string()),
322322
databend_driver::Value::Geography(s) => String::to_napi_value(env, s.to_string()),
323+
databend_driver::Value::Vector(inner) => {
324+
let mut arr = ctx.create_array(inner.len() as u32)?;
325+
for (i, v) in inner.iter().enumerate() {
326+
arr.set(
327+
i as u32,
328+
Value::new(
329+
&databend_driver::Value::Number(databend_driver::NumberValue::Float32(
330+
*v,
331+
)),
332+
val.opts,
333+
),
334+
)?;
335+
}
336+
Array::to_napi_value(env, arr)
337+
}
323338
}
324339
}
325340
}

bindings/python/src/types.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ impl<'py> IntoPyObject<'py> for Value {
110110
let s = Duration::microseconds(total_micros);
111111
s.into_bound_py_any(py)?
112112
}
113+
databend_driver::Value::Vector(inner) => {
114+
let list = PyList::new(
115+
py,
116+
inner.into_iter().map(|v| {
117+
Value(databend_driver::Value::Number(
118+
databend_driver::NumberValue::Float32(v),
119+
))
120+
}),
121+
)?;
122+
list.into_bound_py_any(py)?
123+
}
113124
};
114125
Ok(val)
115126
}

sql/src/schema.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ pub(crate) const ARROW_EXT_TYPE_GEOMETRY: &str = "Geometry";
3838
pub(crate) const ARROW_EXT_TYPE_GEOGRAPHY: &str = "Geography";
3939
#[cfg(feature = "flight-sql")]
4040
pub(crate) const ARROW_EXT_TYPE_INTERVAL: &str = "Interval";
41+
#[cfg(feature = "flight-sql")]
42+
/// Arrow extension type name that marks a field as a `Vector` column.
pub(crate) const ARROW_EXT_TYPE_VECTOR: &str = "Vector";
4143

4244
#[derive(Debug, Clone, PartialEq, Eq)]
4345
pub enum NumberDataType {
@@ -95,6 +97,7 @@ pub enum DataType {
9597
Geometry,
9698
Geography,
9799
Interval,
100+
Vector(u64),
98101
// Generic(usize),
99102
}
100103

@@ -156,6 +159,7 @@ impl std::fmt::Display for DataType {
156159
DataType::Geometry => write!(f, "Geometry"),
157160
DataType::Geography => write!(f, "Geography"),
158161
DataType::Interval => write!(f, "Interval"),
162+
DataType::Vector(d) => write!(f, "Vector({d})"),
159163
}
160164
}
161165
}
@@ -275,6 +279,10 @@ impl TryFrom<&TypeDesc<'_>> for DataType {
275279
"Geometry" => DataType::Geometry,
276280
"Geography" => DataType::Geography,
277281
"Interval" => DataType::Interval,
282+
"Vector" => {
283+
let dimension = desc.args[0].name.parse::<u64>()?;
284+
DataType::Vector(dimension)
285+
}
278286
_ => return Err(Error::Parsing(format!("Unknown type: {desc:?}"))),
279287
};
280288
Ok(dt)
@@ -320,6 +328,26 @@ impl TryFrom<&Arc<ArrowField>> for Field {
320328
ARROW_EXT_TYPE_BITMAP => DataType::Bitmap,
321329
ARROW_EXT_TYPE_GEOMETRY => DataType::Geometry,
322330
ARROW_EXT_TYPE_GEOGRAPHY => DataType::Geography,
331+
ARROW_EXT_TYPE_INTERVAL => DataType::Interval,
332+
ARROW_EXT_TYPE_VECTOR => match f.data_type() {
333+
ArrowDataType::FixedSizeList(field, dimension) => {
334+
let dimension = match field.data_type() {
335+
ArrowDataType::Float32 => *dimension as u64,
336+
_ => {
337+
return Err(Error::Parsing(format!(
338+
"Unsupported FixedSizeList Arrow type: {:?}",
339+
field.data_type()
340+
)));
341+
}
342+
};
343+
DataType::Vector(dimension)
344+
}
345+
arrow_type => {
346+
return Err(Error::Parsing(format!(
347+
"Unsupported Arrow type: {arrow_type:?}",
348+
)));
349+
}
350+
},
323351
_ => {
324352
return Err(Error::Parsing(format!(
325353
"Unsupported extension datatype for arrow field: {f:?}"

sql/src/value.rs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use {
3636
crate::schema::{
3737
ARROW_EXT_TYPE_BITMAP, ARROW_EXT_TYPE_EMPTY_ARRAY, ARROW_EXT_TYPE_EMPTY_MAP,
3838
ARROW_EXT_TYPE_GEOGRAPHY, ARROW_EXT_TYPE_GEOMETRY, ARROW_EXT_TYPE_INTERVAL,
39-
ARROW_EXT_TYPE_VARIANT, EXTENSION_KEY,
39+
ARROW_EXT_TYPE_VARIANT, ARROW_EXT_TYPE_VECTOR, EXTENSION_KEY,
4040
},
4141
arrow_array::{
4242
Array as ArrowArray, BinaryArray, BooleanArray, Date32Array, Decimal128Array,
@@ -93,6 +93,7 @@ pub enum Value {
9393
Geometry(String),
9494
Geography(String),
9595
Interval(String),
96+
Vector(Vec<f32>),
9697
}
9798

9899
impl Value {
@@ -145,6 +146,7 @@ impl Value {
145146
Self::Variant(_) => DataType::Variant,
146147
Self::Geometry(_) => DataType::Geometry,
147148
Self::Geography(_) => DataType::Geography,
149+
Self::Vector(v) => DataType::Vector(v.len() as u64),
148150
}
149151
}
150152
}
@@ -229,7 +231,7 @@ impl TryFrom<(&DataType, String)> for Value {
229231
DataType::Geometry => Ok(Self::Geometry(v)),
230232
DataType::Geography => Ok(Self::Geography(v)),
231233
DataType::Interval => Ok(Self::Interval(v)),
232-
DataType::Array(_) | DataType::Map(_) | DataType::Tuple(_) => {
234+
DataType::Array(_) | DataType::Map(_) | DataType::Tuple(_) | DataType::Vector(_) => {
233235
let mut reader = Cursor::new(v.as_str());
234236
let decoder = ValueDecoder {};
235237
decoder.read_field(t, &mut reader)
@@ -329,6 +331,50 @@ impl TryFrom<(&ArrowField, &Arc<dyn ArrowArray>, usize)> for Value {
329331
None => Err(ConvertError::new("geography", format!("{array:?}")).into()),
330332
}
331333
}
334+
ARROW_EXT_TYPE_VECTOR => {
335+
if field.is_nullable() && array.is_null(seq) {
336+
return Ok(Value::Null);
337+
}
338+
match field.data_type() {
339+
ArrowDataType::FixedSizeList(_, dimension) => {
340+
match array
341+
.as_any()
342+
.downcast_ref::<arrow_array::FixedSizeListArray>()
343+
{
344+
Some(inner_array) => {
345+
match inner_array
346+
.value(seq)
347+
.as_any()
348+
.downcast_ref::<Float32Array>()
349+
{
350+
Some(inner_array) => {
351+
let dimension = *dimension as usize;
352+
let mut values = Vec::with_capacity(dimension);
353+
for i in 0..dimension {
354+
let value = inner_array.value(i);
355+
values.push(value);
356+
}
357+
Ok(Value::Vector(values))
358+
}
359+
None => Err(ConvertError::new(
360+
"vector float32",
361+
format!("{inner_array:?}"),
362+
)
363+
.into()),
364+
}
365+
}
366+
None => {
367+
Err(ConvertError::new("vector", format!("{array:?}")).into())
368+
}
369+
}
370+
}
371+
arrow_type => Err(ConvertError::new(
372+
"vector",
373+
format!("Unsupported Arrow type: {arrow_type:?}"),
374+
)
375+
.into()),
376+
}
377+
}
332378
_ => Err(ConvertError::new(
333379
"extension",
334380
format!("Unsupported extension datatype for arrow field: {field:?}"),
@@ -890,6 +936,17 @@ fn encode_value(f: &mut std::fmt::Formatter<'_>, val: &Value, raw: bool) -> std:
890936
write!(f, ")")?;
891937
Ok(())
892938
}
939+
Value::Vector(vals) => {
940+
write!(f, "[")?;
941+
for (i, val) in vals.iter().enumerate() {
942+
if i > 0 {
943+
write!(f, ",")?;
944+
}
945+
write!(f, "{val}")?;
946+
}
947+
write!(f, "]")?;
948+
Ok(())
949+
}
893950
}
894951
}
895952

@@ -1608,6 +1665,7 @@ impl ValueDecoder {
16081665
DataType::Array(inner_ty) => self.read_array(inner_ty.as_ref(), reader),
16091666
DataType::Map(inner_ty) => self.read_map(inner_ty.as_ref(), reader),
16101667
DataType::Tuple(inner_tys) => self.read_tuple(inner_tys.as_ref(), reader),
1668+
DataType::Vector(dimension) => self.read_vector(*dimension as usize, reader),
16111669
DataType::Nullable(inner_ty) => self.read_nullable(inner_ty.as_ref(), reader),
16121670
}
16131671
}
@@ -1812,6 +1870,26 @@ impl ValueDecoder {
18121870
Ok(Value::Array(vals))
18131871
}
18141872

1873+
fn read_vector<R: AsRef<[u8]>>(
1874+
&self,
1875+
dimension: usize,
1876+
reader: &mut Cursor<R>,
1877+
) -> Result<Value> {
1878+
let mut vals = Vec::with_capacity(dimension);
1879+
reader.must_ignore_byte(b'[')?;
1880+
for idx in 0..dimension {
1881+
let _ = reader.ignore_white_spaces();
1882+
if idx > 0 {
1883+
reader.must_ignore_byte(b',')?;
1884+
}
1885+
let _ = reader.ignore_white_spaces();
1886+
let val: f32 = reader.read_float_text()?;
1887+
vals.push(val);
1888+
}
1889+
reader.must_ignore_byte(b']')?;
1890+
Ok(Value::Vector(vals))
1891+
}
1892+
18151893
fn read_map<R: AsRef<[u8]>>(&self, ty: &DataType, reader: &mut Cursor<R>) -> Result<Value> {
18161894
const KEY: usize = 0;
18171895
const VALUE: usize = 1;

0 commit comments

Comments
 (0)