Skip to content
This repository was archived by the owner on Dec 29, 2021. It is now read-only.

Commit 24a0d63

Browse files
committed
functions::array::array_remove
1 parent 1a103dd commit 24a0d63

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

src/functions/array.rs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ impl ArrayFunctions {
8787
}
8888
Ok(b.finish())
8989
}
90+
91+
/// Locates the position of the first occurrence of the given value in the given array.
92+
/// Returns 0 if element is not found, otherwise a 1-based index with the position in the array.
9093
fn array_position<T>(array: &ListArray, val: T::Native) -> Result<Int32Array, ArrowError>
9194
where
9295
T: ArrowPrimitiveType + ArrowNumericType,
@@ -114,7 +117,38 @@ impl ArrayFunctions {
114117
}
115118
Ok(b.finish())
116119
}
117-
fn array_remove() {}
120+
121+
/// Remove all elements that equal the given element in the array
122+
fn array_remove<T>(array: &ListArray, val: T::Native) -> Result<ListArray, ArrowError>
123+
where
124+
T: ArrowPrimitiveType + ArrowNumericType,
125+
T::Native: std::cmp::PartialEq<T::Native>,
126+
{
127+
let values_builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(array.values().len());
128+
let mut b = ListBuilder::new(values_builder);
129+
// get array datatype so we can downcast appropriately
130+
let data_type = array.value_type();
131+
for i in 0..array.len() {
132+
if array.is_null(i) {
133+
b.append(true)?
134+
} else {
135+
let values = array.values();
136+
let values = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
137+
let values = values.value_slice(
138+
array.value_offset(i) as usize,
139+
array.value_length(i) as usize,
140+
);
141+
values.iter().for_each(|x| {
142+
// append value if it should not be removed
143+
if x != &val {
144+
b.values().append_value(*x).unwrap();
145+
}
146+
});
147+
b.append(true)?;
148+
}
149+
}
150+
Ok(b.finish())
151+
}
118152
fn array_repeat() {}
119153
fn array_sort() {}
120154
fn array_union() {}
@@ -250,4 +284,34 @@ mod tests {
250284
assert_eq!(2, bools.value(4));
251285
assert_eq!(0, bools.value(5));
252286
}
287+
288+
#[test]
289+
fn test_array_remove() {
290+
// Construct a value array
291+
let value_data =
292+
Int64Array::from(vec![0, 0, 0, 1, 2, 1, 3, 4, 5, 1, 3, 2, 3, 2, 8, 3]).data();
293+
294+
let value_offsets = Buffer::from(&[0, 3, 6, 8, 12, 14, 16].to_byte_slice());
295+
296+
// Construct a list array from the above two
297+
let list_data_type = DataType::List(Box::new(DataType::Int64));
298+
let list_data = ArrayData::builder(list_data_type.clone())
299+
.len(6)
300+
.add_buffer(value_offsets.clone())
301+
.add_child_data(value_data.clone())
302+
.build();
303+
let list_array = ListArray::from(list_data);
304+
305+
let b = ArrayFunctions::array_remove::<Int64Type>(&list_array, 2).unwrap();
306+
let values = b.values();
307+
let values = values.as_any().downcast_ref::<PrimitiveArray<Int64Type>>().unwrap();
308+
309+
assert_eq!(13, values.len());
310+
assert_eq!(0, b.value_offset(0));
311+
assert_eq!(3, b.value_offset(1));
312+
assert_eq!(5, b.value_offset(2));
313+
assert_eq!(7, b.value_offset(3));
314+
assert_eq!(10, b.value_offset(4));
315+
assert_eq!(11, b.value_offset(5));
316+
}
253317
}

0 commit comments

Comments
 (0)