Skip to content
This repository was archived by the owner on Dec 29, 2021. It is now read-only.

Commit ed6823f

Browse files
committed
more array functions
1 parent 8d577ff commit ed6823f

File tree

2 files changed

+174
-9
lines changed

2 files changed

+174
-9
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ edition = "2018"
66

77
[dependencies]
88
arrow = { git = "https://github.com/apache/arrow"}
9-
# arrow = { git = "https://github.com/nevi-me/arrow", rev="a52fa39dbd8e35278604e8ad16cb069bf30b0c9a"}
109
# arrow = { path = "../../arrow/rust/arrow"}
1110
num = "0.2"
1211
num-traits = "0.2"
1312
csv = "1"
1413
byteorder = "1"
15-
flatbuffers = "0.5"
14+
flatbuffers = "0.5"
15+
array_tool = "1"

src/functions/array.rs

Lines changed: 172 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use array_tool::vec::*;
12
use arrow::array::*;
23
use arrow::builder::*;
34
use arrow::datatypes::*;
@@ -35,10 +36,108 @@ impl ArrayFunctions {
3536
}
3637
Ok(b.finish())
3738
}
38-
fn array_distinct() {}
39-
fn array_except() {}
40-
fn array_intersect() {}
4139
fn array_join() {}
40+
fn array_distinct<T>(array: &ListArray) -> Result<ListArray, ArrowError>
41+
where
42+
T: ArrowPrimitiveType + ArrowNumericType,
43+
{
44+
let values_builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(array.values().len());
45+
let mut b = ListBuilder::new(values_builder);
46+
// get array datatype so we can downcast appropriately
47+
let data_type = array.value_type();
48+
for i in 0..array.len() {
49+
if array.is_null(i) {
50+
b.append(true)?
51+
} else {
52+
let values = array.values();
53+
let values = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
54+
let values = values.value_slice(
55+
array.value_offset(i) as usize,
56+
array.value_length(i) as usize,
57+
).to_vec();
58+
let u = values.unique();
59+
// TODO check how nulls are treated here
60+
u.iter().for_each(|x| b.values().append_value(*x).unwrap());
61+
}
62+
}
63+
Ok(b.finish())
64+
}
65+
pub fn array_except<T>(a: &ListArray, b: &ListArray) -> Result<ListArray, ArrowError>
66+
where
67+
T: ArrowPrimitiveType + ArrowNumericType,
68+
T::Native: std::cmp::PartialEq<T::Native> + std::cmp::Ord,
69+
{
70+
// check that lengths of both arrays are equal
71+
if a.len() != b.len() {
72+
return Err(ArrowError::ComputeError("Expected array a and b to have the same length".to_string()))
73+
}
74+
let values_builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(a.values().len());
75+
let mut c = ListBuilder::new(values_builder);
76+
// get array datatype so we can downcast appropriately
77+
let data_type = a.value_type();
78+
for i in 0..a.len() {
79+
if a.is_null(i) {
80+
c.append(true)?
81+
} else {
82+
let a_values = a.values();
83+
let a_values = a_values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
84+
let a_values = a_values.value_slice(
85+
a.value_offset(i) as usize,
86+
a.value_length(i) as usize,
87+
).to_vec();
88+
let b_values = b.values();
89+
let b_values = b_values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
90+
let b_values = b_values.value_slice(
91+
b.value_offset(i) as usize,
92+
b.value_length(i) as usize,
93+
).to_vec();
94+
95+
let u = a_values.uniq(b_values);
96+
// TODO check how nulls are treated here
97+
u.iter().for_each(|x| c.values().append_value(*x).unwrap());
98+
c.append(true)?;
99+
}
100+
}
101+
Ok(c.finish())
102+
}
103+
pub fn array_intersect<T>(a: &ListArray, b: &ListArray) -> Result<ListArray, ArrowError>
104+
where
105+
T: ArrowPrimitiveType + ArrowNumericType,
106+
T::Native: std::cmp::PartialEq<T::Native> + std::cmp::Ord,
107+
{
108+
// check that lengths of both arrays are equal
109+
if a.len() != b.len() {
110+
return Err(ArrowError::ComputeError("Expected array a and b to have the same length".to_string()))
111+
}
112+
let values_builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(a.values().len());
113+
let mut c = ListBuilder::new(values_builder);
114+
// get array datatype so we can downcast appropriately
115+
let data_type = a.value_type();
116+
for i in 0..a.len() {
117+
if a.is_null(i) {
118+
c.append(true)?
119+
} else {
120+
let a_values = a.values();
121+
let a_values = a_values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
122+
let a_values = a_values.value_slice(
123+
a.value_offset(i) as usize,
124+
a.value_length(i) as usize,
125+
).to_vec();
126+
let b_values = b.values();
127+
let b_values = b_values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
128+
let b_values = b_values.value_slice(
129+
b.value_offset(i) as usize,
130+
b.value_length(i) as usize,
131+
).to_vec();
132+
133+
let u = a_values.intersect(b_values);
134+
// TODO check how nulls are treated here
135+
u.iter().for_each(|x| c.values().append_value(*x).unwrap());
136+
c.append(true)?;
137+
}
138+
}
139+
Ok(c.finish())
140+
}
42141
pub fn array_max<T>(array: &ListArray) -> Result<PrimitiveArray<T>, ArrowError>
43142
where
44143
T: ArrowPrimitiveType + ArrowNumericType,
@@ -90,7 +189,7 @@ impl ArrayFunctions {
90189

91190
/// Locates the position of the first occurrence of the given value in the given array.
92191
/// Returns 0 if element is not found, otherwise a 1-based index with the position in the array.
93-
fn array_position<T>(array: &ListArray, val: T::Native) -> Result<Int32Array, ArrowError>
192+
pub fn array_position<T>(array: &ListArray, val: T::Native) -> Result<Int32Array, ArrowError>
94193
where
95194
T: ArrowPrimitiveType + ArrowNumericType,
96195
T::Native: std::cmp::PartialEq<T::Native>,
@@ -119,7 +218,7 @@ impl ArrayFunctions {
119218
}
120219

121220
/// Remove all elements that equal the given element in the array
122-
fn array_remove<T>(array: &ListArray, val: T::Native) -> Result<ListArray, ArrowError>
221+
pub fn array_remove<T>(array: &ListArray, val: T::Native) -> Result<ListArray, ArrowError>
123222
where
124223
T: ArrowPrimitiveType + ArrowNumericType,
125224
T::Native: std::cmp::PartialEq<T::Native>,
@@ -149,7 +248,36 @@ impl ArrayFunctions {
149248
}
150249
Ok(b.finish())
151250
}
152-
fn array_repeat() {}
251+
252+
/// TODO: extract repetitive code and share with other array fns that use `array_tool` crate
253+
pub fn array_repeat<T>(array: &ListArray, count: i32) -> Result<ListArray, ArrowError>
254+
where
255+
T: ArrowPrimitiveType + ArrowNumericType,
256+
T::Native: std::cmp::PartialEq<T::Native> + std::cmp::Ord,
257+
{
258+
let values_builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(array.values().len());
259+
let mut c = ListBuilder::new(values_builder);
260+
// get array datatype so we can downcast appropriately
261+
let data_type = array.value_type();
262+
for i in 0..array.len() {
263+
if array.is_null(i) {
264+
c.append(true)?
265+
} else {
266+
let values = array.values();
267+
let values = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
268+
let values = values.value_slice(
269+
array.value_offset(i) as usize,
270+
array.value_length(i) as usize,
271+
).to_vec();
272+
273+
let u = values.times(count);
274+
// TODO check how nulls are treated here
275+
u.iter().for_each(|x| c.values().append_value(*x).unwrap());
276+
c.append(true)?;
277+
}
278+
}
279+
Ok(c.finish())
280+
}
153281

154282
/// Sorts the input array in ascending order.
155283
///
@@ -182,7 +310,44 @@ impl ArrayFunctions {
182310
}
183311
Ok(b.finish())
184312
}
185-
fn array_union() {}
313+
pub fn array_union<T>(a: &ListArray, b: &ListArray) -> Result<ListArray, ArrowError>
314+
where
315+
T: ArrowPrimitiveType + ArrowNumericType,
316+
T::Native: std::cmp::PartialEq<T::Native> + std::cmp::Ord,
317+
{
318+
// check that lengths of both arrays are equal
319+
if a.len() != b.len() {
320+
return Err(ArrowError::ComputeError("Expected array a and b to have the same length".to_string()))
321+
}
322+
let values_builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(a.values().len());
323+
let mut c = ListBuilder::new(values_builder);
324+
// get array datatype so we can downcast appropriately
325+
let data_type = a.value_type();
326+
for i in 0..a.len() {
327+
if a.is_null(i) {
328+
c.append(true)?
329+
} else {
330+
let a_values = a.values();
331+
let a_values = a_values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
332+
let a_values = a_values.value_slice(
333+
a.value_offset(i) as usize,
334+
a.value_length(i) as usize,
335+
).to_vec();
336+
let b_values = b.values();
337+
let b_values = b_values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
338+
let b_values = b_values.value_slice(
339+
b.value_offset(i) as usize,
340+
b.value_length(i) as usize,
341+
).to_vec();
342+
343+
let u = a_values.union(b_values);
344+
// TODO check how nulls are treated here
345+
u.iter().for_each(|x| c.values().append_value(*x).unwrap());
346+
c.append(true)?;
347+
}
348+
}
349+
Ok(c.finish())
350+
}
186351
fn arrays_overlap() {}
187352
fn arrays_zip() {}
188353

0 commit comments

Comments
 (0)