Skip to content
This repository was archived by the owner on Dec 29, 2021. It is now read-only.

Commit 4fb102d

Browse files
committed
some aggregate functions
1 parent 5000599 commit 4fb102d

File tree

3 files changed

+74
-14
lines changed

3 files changed

+74
-14
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ One can think of this library partly as a playground for features that could for
4242
- [X] Trig functions (sin, cos, tan, asin, asinh, ...) (using the `num` crate where possible)
4343
- [X] Basic arithmetic (add, mul, divide, subtract) **Implemented from Arrow**
4444
- [ ] Date/Time functions
45-
- [ ] String functions
45+
- [ ] String functions (in progress, subset implemented)
4646
- [ ] Crypto/hash functions (md5, crc32, sha{x}, ...)
4747
- [ ] Other functions (that we haven't classified)
4848

4949
- Aggregate Functions
50-
- [ ] Sum
51-
- [ ] Count
50+
- [X] Sum, max, min
51+
- [X] Count
5252
- [ ] Statistical aggregations (mean, mode, median, stddev, ...)
5353

5454
- Window Functions

src/functions/aggregate.rs

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,66 @@
1-
use arrow::array::PrimitiveArray;
1+
use arrow::array::Array;
2+
use arrow::array::{Int64Array, PrimitiveArray};
23
use arrow::array_ops;
34
use arrow::datatypes::ArrowNumericType;
5+
use arrow::datatypes::ArrowPrimitiveType;
6+
use arrow::datatypes::Int64Type;
47
use std::ops::Add;
58

69
struct AggregateFunctions;
710

811
impl AggregateFunctions {
9-
pub fn max<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
12+
pub fn max<T>(arrays: Vec<&PrimitiveArray<T>>) -> Option<T::Native>
1013
where
1114
T: ArrowNumericType,
15+
T::Native: std::cmp::Ord,
1216
{
13-
array_ops::max(array)
17+
arrays.iter().map(|array| array_ops::max(array).unwrap()).max()
1418
}
15-
pub fn min<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
19+
pub fn min<T>(arrays: Vec<&PrimitiveArray<T>>) -> Option<T::Native>
1620
where
1721
T: ArrowNumericType,
22+
T::Native: std::cmp::Ord,
1823
{
19-
array_ops::min(array)
24+
arrays.iter().map(|array| array_ops::max(array).unwrap()).max()
2025
}
21-
pub fn avg() {}
22-
pub fn count() {}
23-
pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
26+
// pub fn avg<T>(array: &PrimitiveArray<T>) -> Option<f64>
27+
// where
28+
// T: ArrowNumericType
29+
// {
30+
// let sum = array_ops::sum(array);
31+
// match sum {
32+
// None => None,
33+
// Some(sum) => {
34+
// let count = AggregateFunctions::count(array).unwrap();
35+
// let sum = sum as f64;
36+
// Some(sum / count as f64)
37+
// }
38+
// }
39+
// }
40+
41+
/// Count returns the number of non-null values in the array/column.
42+
///
43+
/// For the number of all values, use `len()`
44+
pub fn count<T>(arrays: Vec<&PrimitiveArray<T>>) -> Option<i64>
45+
where
46+
T: ArrowPrimitiveType,
47+
{
48+
let mut sum = 0;
49+
arrays.iter().for_each(|array| sum += (array.len() - array.null_count()) as i64);
50+
51+
Some(sum)
52+
}
53+
fn count_distinct() {}
54+
pub fn sum<T>(arrays: Vec<&PrimitiveArray<T>>) -> Option<T::Native>
2455
where
2556
T: ArrowNumericType,
2657
T::Native: Add<Output = T::Native>,
2758
{
28-
array_ops::sum(array)
59+
let mut sum = T::default_value();
60+
arrays.iter().for_each(|array| sum = sum + array_ops::sum(array).unwrap_or(T::default_value()));
61+
62+
Some(sum)
63+
2964
}
3065
pub fn first() {}
3166
pub fn kurtosis() {}
@@ -37,3 +72,28 @@ impl AggregateFunctions {
3772
pub fn variance() {}
3873
// TODO population and sample variances
3974
}
75+
76+
#[cfg(test)]
77+
mod tests {
78+
use super::*;
79+
use arrow::array::{Float64Array, Int32Array};
80+
use arrow::datatypes::Int32Type;
81+
82+
#[test]
83+
fn testit() {
84+
let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
85+
let b = Int32Array::from(vec![7, 6, 8, 9, 10]);
86+
let c = b.value(0);
87+
let d = Int32Array::from(vec![c]);
88+
// assert_eq!(7.0, c);
89+
assert_eq!(a.value(0), b.value(1));
90+
assert_eq!(b.value(0), d.value(0));
91+
}
92+
93+
#[test]
94+
fn test_aggregate_count() {
95+
let a = Int32Array::from(vec![5, 6, 7, 8, 9]);
96+
let c = AggregateFunctions::count(vec![&a]).unwrap();
97+
assert_eq!(5, c);
98+
}
99+
}

src/functions/scalar.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,8 +797,8 @@ mod tests {
797797
#[test]
798798
fn test_str_upper_and_lower() {
799799
let mut builder = BinaryBuilder::new(10);
800-
builder.append_string("Hello");
801-
builder.append_string("Arrow");
800+
builder.append_string("Hello").unwrap();
801+
builder.append_string("Arrow").unwrap();
802802
let array = builder.finish();
803803
let lower = ScalarFunctions::lower(&array).unwrap();
804804
assert_eq!("hello", lower.get_string(0));

0 commit comments

Comments
 (0)