@@ -87,7 +87,33 @@ impl ArrayFunctions {
87
87
}
88
88
Ok ( b. finish ( ) )
89
89
}
90
- fn array_position ( ) { }
90
+ fn array_position < T > ( array : & ListArray , val : T :: Native ) -> Result < Int32Array , ArrowError >
91
+ where
92
+ T : ArrowPrimitiveType + ArrowNumericType ,
93
+ T :: Native : std:: cmp:: PartialEq < T :: Native > ,
94
+ {
95
+ let mut b = Int32Builder :: new ( array. len ( ) ) ;
96
+ // get array datatype so we can downcast appropriately
97
+ let data_type = array. value_type ( ) ;
98
+ for i in 0 ..array. len ( ) {
99
+ if array. is_null ( i) {
100
+ b. append_value ( 0 ) ?
101
+ } else {
102
+ let values = array. values ( ) ;
103
+ let values = values. as_any ( ) . downcast_ref :: < PrimitiveArray < T > > ( ) . unwrap ( ) ;
104
+ let values = values. value_slice (
105
+ array. value_offset ( i) as usize ,
106
+ array. value_length ( i) as usize ,
107
+ ) ;
108
+ let pos = values. iter ( ) . position ( |x| x == & val) ;
109
+ match pos {
110
+ Some ( pos) => b. append_value ( ( pos + 1 ) as i32 ) ?,
111
+ None => b. append_value ( 0 ) ?,
112
+ } ;
113
+ }
114
+ }
115
+ Ok ( b. finish ( ) )
116
+ }
91
117
fn array_remove ( ) { }
92
118
fn array_repeat ( ) { }
93
119
fn array_sort ( ) { }
@@ -198,4 +224,30 @@ mod tests {
198
224
assert_eq ! ( true , bools. value( 4 ) ) ;
199
225
assert_eq ! ( false , bools. value( 5 ) ) ;
200
226
}
227
+
228
+ #[ test]
229
+ fn test_array_position ( ) {
230
+ // Construct a value array
231
+ let value_data =
232
+ Int64Array :: from ( vec ! [ 0 , 0 , 0 , 1 , 2 , 1 , 3 , 4 , 5 , 1 , 3 , 2 , 3 , 2 , 8 , 3 ] ) . data ( ) ;
233
+
234
+ let value_offsets = Buffer :: from ( & [ 0 , 3 , 6 , 8 , 12 , 14 , 16 ] . to_byte_slice ( ) ) ;
235
+
236
+ // Construct a list array from the above two
237
+ let list_data_type = DataType :: List ( Box :: new ( DataType :: Int64 ) ) ;
238
+ let list_data = ArrayData :: builder ( list_data_type. clone ( ) )
239
+ . len ( 6 )
240
+ . add_buffer ( value_offsets. clone ( ) )
241
+ . add_child_data ( value_data. clone ( ) )
242
+ . build ( ) ;
243
+ let list_array = ListArray :: from ( list_data) ;
244
+
245
+ let bools = ArrayFunctions :: array_position :: < Int64Type > ( & list_array, 2 ) . unwrap ( ) ;
246
+ assert_eq ! ( 0 , bools. value( 0 ) ) ;
247
+ assert_eq ! ( 2 , bools. value( 1 ) ) ;
248
+ assert_eq ! ( 0 , bools. value( 2 ) ) ;
249
+ assert_eq ! ( 4 , bools. value( 3 ) ) ;
250
+ assert_eq ! ( 2 , bools. value( 4 ) ) ;
251
+ assert_eq ! ( 0 , bools. value( 5 ) ) ;
252
+ }
201
253
}
0 commit comments