Skip to content

Commit 879eab9

Browse files
committed
[jOOQ#360] Add Agg.stddev() and variance() and similar aggregate functions
- Renamed methods to stddevDouble() and varianceDouble() (no usage of By()) - Simplified collectors by reusing JDK Collectors.collectingAndThen() - Removed comments and Javadoc
1 parent 3e0e0d2 commit 879eab9

File tree

4 files changed

+149
-226
lines changed

4 files changed

+149
-226
lines changed

jOOL-java-8/src/main/java/org/jooq/lambda/Agg.java

Lines changed: 46 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import java.util.stream.Collector;
3131
import java.util.stream.Stream;
3232

33-
import org.jooq.lambda.tuple.Tuple;
3433
import org.jooq.lambda.tuple.Tuple2;
3534

3635
/**
@@ -39,6 +38,7 @@
3938
* The class name isn't set in stone and will change.
4039
*
4140
* @author Lukas Eder
41+
* @author Jichen Lu
4242
*/
4343
public class Agg {
4444

@@ -1013,14 +1013,14 @@ private static int percentileIndex(double percentile, int size) {
10131013
}
10141014

10151015
/**
1016-
* Get a {@link Collector} that calculates the derived <code>PERCENTILE_DISC(percentile)</code> function given a specific ordering, producing multiple results.
1016+
* Get a {@link Collector} that calculates the <code>PERCENTILE_DISC(percentile)</code> function given a specific ordering, producing multiple results.
10171017
*/
10181018
public static <T extends Comparable<? super T>> Collector<T, ?, Seq<T>> medianAll() {
10191019
return medianAllBy(t -> t, naturalOrder());
10201020
}
10211021

10221022
/**
1023-
* Get a {@link Collector} that calculates the derived <code>PERCENTILE_DISC(percentile)</code> function given a specific ordering, producing multiple results.
1023+
* Get a {@link Collector} that calculates the <code>PERCENTILE_DISC(percentile)</code> function given a specific ordering, producing multiple results.
10241024
*/
10251025
public static <T> Collector<T, ?, Seq<T>> medianAll(Comparator<? super T> comparator) {
10261026
return medianAllBy(t -> t, comparator);
@@ -1040,6 +1040,49 @@ private static int percentileIndex(double percentile, int size) {
10401040
return percentileAllBy(0.5, function, comparator);
10411041
}
10421042

1043+
/**
1044+
* Get a {@link Collector} that calculates the <code>STDDEV_POP()</code> function.
1045+
*/
1046+
public static <T> Collector<Double, ?, Optional<Double>> stddevDouble() {
1047+
return stddevDouble(t -> t);
1048+
}
1049+
1050+
/**
1051+
* Get a {@link Collector} that calculates the <code>STDDEV_POP()</code> function.
1052+
*/
1053+
public static <T, U> Collector<T, ?, Optional<Double>> stddevDouble(ToDoubleFunction<? super T> function) {
1054+
return collectingAndThen(toList(), l -> l.isEmpty() ? Optional.empty() : Optional.of(Math.sqrt(variance0(l, function))));
1055+
}
1056+
1057+
/**
1058+
* Get a {@link Collector} that calculates the <code>VAR_POP()</code> function.
1059+
*/
1060+
public static Collector<Double, ?, Optional<Double>> varianceDouble() {
1061+
return varianceDouble(t -> t);
1062+
}
1063+
1064+
/**
1065+
* Get a {@link Collector} that calculates the <code>VAR_POP()</code> function.
1066+
*/
1067+
public static <T> Collector<T, ?, Optional<Double>> varianceDouble(ToDoubleFunction<? super T> function) {
1068+
return collectingAndThen(toList(), l -> l.isEmpty() ? Optional.empty() : Optional.of(variance0(l, function)));
1069+
}
1070+
1071+
private static <T> double variance0(List<T> l, ToDoubleFunction<? super T> function) {
1072+
double sum = 0.0;
1073+
double sumVariance = 0.0;
1074+
1075+
for (T o : l)
1076+
sum += function.applyAsDouble(o);
1077+
1078+
double average = sum / l.size();
1079+
1080+
for (T o : l)
1081+
sumVariance += Math.pow(function.applyAsDouble(o) - average, 2);
1082+
1083+
return sumVariance / l.size();
1084+
}
1085+
10431086
/**
10441087
* Get a {@link Collector} that calculates the common prefix of a set of strings.
10451088
*/
@@ -1079,149 +1122,4 @@ private static int percentileIndex(double percentile, int size) {
10791122
s -> s.map(Objects::toString).orElse("")
10801123
);
10811124
}
1082-
1083-
1084-
//CS304 Issue link: https://github.com/jOOQ/jOOL/issues/360
1085-
1086-
/**
1087-
* Calculate the variance of objects with mapping function from object to double value,
1088-
* with given list, size and mapping function.
1089-
*
1090-
* <p><pre>
1091-
* This function is based on the definition of standard deviation:
1092-
* First, calculate the sum of value of all objects and the average value;
1093-
* Second, calculate the sum of square of difference between each objects and average;
1094-
* Third, use sum of variance divide size to obtain variance of all objects.
1095-
* </pre>
1096-
*
1097-
* @param function mapping function from object to double value
1098-
* @param l a list containing all the objects
1099-
* @param size the size of the list in collector
1100-
* @return the variance value
1101-
* @version 1.0
1102-
* @author Jichen Lu
1103-
* @date 2021-04-25
1104-
*/
1105-
private static <T> double getVariance(
1106-
Function<? super T, Double> function,
1107-
ArrayList<T> l, int size) {
1108-
double sum = 0.0;
1109-
double average;
1110-
double sumVariance = 0.0;
1111-
double variance;
1112-
for (T o : l) {
1113-
double temp = function.apply(o);
1114-
sum += temp;
1115-
}
1116-
average = sum / size;
1117-
for (T o : l) {
1118-
double temp = Math.pow(function.apply(o) - average, 2);
1119-
sumVariance += temp;
1120-
}
1121-
variance = sumVariance / size;
1122-
return variance;
1123-
}
1124-
1125-
1126-
//CS304 Issue link: https://github.com/jOOQ/jOOL/issues/360
1127-
1128-
/**
1129-
* Calculate the variance of the given object collectors,
1130-
* based on the mapping function from object to double number.
1131-
*
1132-
* <p><pre> Usage of aggregation function stddevBy():
1133-
* The mapping function is a function mapping the objects to a double value, for instance:
1134-
* {@code
1135-
* Function<Integer, Double> mapping = e -> Double.valueOf(e);
1136-
* }
1137-
* The specific usage of stddevBy is like:
1138-
* {@code
1139-
* Seq.of(1, 1, 1, 1).collect(Agg.stddevBy(mapping));
1140-
* }
1141-
* Besides, self defined class is also allowed with mapping function:
1142-
* {@code
1143-
* Seq.of(new Node(1), new Node(1), new Node(1), new Node(1)).collect(Agg.stddevBy(mapping1));
1144-
* }
1145-
* </pre>
1146-
*
1147-
* @param function mapping function from object to double value
1148-
* @return the stddev value
1149-
* @version 1.0
1150-
* @author Jichen Lu
1151-
* @date 2021-04-25
1152-
*/
1153-
public static <T> Collector<T, ?, Optional<Double>> stddevBy(
1154-
Function<? super T,
1155-
Double> function) {
1156-
return Collector.of(
1157-
(Supplier<ArrayList<T>>) ArrayList::new,
1158-
ArrayList::add,
1159-
(l1, l2) -> {
1160-
l1.addAll(l2);
1161-
return l1;
1162-
},
1163-
l -> {
1164-
int size = l.size();
1165-
if (size == 0) {
1166-
return Optional.empty();
1167-
}
1168-
double variance = getVariance(function, l, size);
1169-
double stddev;
1170-
1171-
1172-
stddev = Math.sqrt(variance);
1173-
1174-
return Optional.of(stddev);
1175-
}
1176-
);
1177-
}
1178-
1179-
1180-
//CS304 Issue link: https://github.com/jOOQ/jOOL/issues/360
1181-
1182-
/**
1183-
* Calculate the variance of the given object collectors,
1184-
* based on the mapping function from object to double number.
1185-
*
1186-
* <p><pre> Usage of aggregation function varianceBy():
1187-
* The mapping function is a function mapping the objects to a double value, for instance:
1188-
* {@code
1189-
* Function<Integer, Double> mapping = e -> Double.valueOf(e);
1190-
* }
1191-
* The specific usage of varianceBy is like:
1192-
* {@code
1193-
* Seq.of(1, 1, 1, 1).collect(Agg.varianceBy(mapping));
1194-
* }
1195-
* Besides, self defined class is also allowed with mapping function:
1196-
* {@code
1197-
* Seq.of(new Node(1), new Node(1), new Node(1), new Node(1)).collect(Agg.varianceBy(mapping1));
1198-
* }
1199-
* </pre>
1200-
*
1201-
* @param function mapping function from object to double value
1202-
* @return the stddev value
1203-
* @version 1.0
1204-
* @author Jichen Lu
1205-
* @date 2021-04-25
1206-
*/
1207-
public static <T> Collector<T, ?, Optional<Double>> varianceBy(
1208-
Function<? super T,
1209-
Double> function) {
1210-
return Collector.of(
1211-
(Supplier<ArrayList<T>>) ArrayList::new,
1212-
ArrayList::add,
1213-
(l1, l2) -> {
1214-
l1.addAll(l2);
1215-
return l1;
1216-
},
1217-
l -> {
1218-
int size = l.size();
1219-
if (size == 0) {
1220-
return Optional.empty();
1221-
}
1222-
double variance = getVariance(function, l, size);
1223-
return Optional.of(variance);
1224-
}
1225-
);
1226-
}
12271125
}

jOOL-java-8/src/test/java/org/jooq/lambda/CollectorTests.java

Lines changed: 22 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,12 @@
2222
import static org.jooq.lambda.tuple.Tuple.tuple;
2323
import static org.junit.Assert.assertEquals;
2424

25-
import java.text.DecimalFormat;
2625
import java.util.Optional;
2726
import java.util.function.Function;
2827
import java.util.Comparator;
29-
import java.util.Optional;
3028
import java.util.function.Supplier;
3129
import java.util.stream.Collector;
3230
import java.util.stream.Stream;
33-
import java.util.function.Function;
3431

3532
import org.jooq.lambda.tuple.Tuple;
3633
import org.jooq.lambda.tuple.Tuple9;
@@ -571,7 +568,7 @@ public void testMedianAllByWithoutComparator() {
571568

572569
@Test
573570
public void testMedianAllByWithoutComparator2() {
574-
Item a = new Item(1), b = new Item(1), c = new Item(2), d = new Item(2),
571+
Item a = new Item(1), b = new Item(1), c = new Item(2), d = new Item(2),
575572
e = new Item(2), f = new Item(3), g = new Item(3), h = new Item(3),
576573
i = new Item(4), j = new Item(4), k = new Item(5), l = new Item(6),
577574
m = new Item(7), n = new Item(7), o = new Item(7), p = new Item(7);
@@ -586,12 +583,12 @@ public void testMedianAllWithComparator() {
586583

587584
@Test
588585
public void testMedianAllWithComparator2() {
589-
Item a = new Item(1), b = new Item(1), c = new Item(2), d = new Item(2),
586+
Item a = new Item(1), b = new Item(1), c = new Item(2), d = new Item(2),
590587
e = new Item(2), f = new Item(3), g = new Item(3), h = new Item(3),
591588
i = new Item(4), j = new Item(4), k = new Item(5), l = new Item(6),
592589
m = new Item(7), n = new Item(7), o = new Item(7), p = new Item(7);
593590

594-
assertEquals(asList(j, i), Seq.of(c, j, n, d, e, o, l, p, a, m, h, b, k, g, f, i).collect(medianAll(Comparator.comparing(Item::reverse))).toList());
591+
assertEquals(asList(j, i), Seq.of(c, j, n, d, e, o, l, p, a, m, h, b, k, g, f, i).collect(medianAll(Comparator.comparing(CollectorTests.Item::reverse))).toList());
595592
}
596593

597594
@Test
@@ -605,7 +602,7 @@ public void testMedianAllWithoutComparator() {
605602

606603
@Test
607604
public void testMedianAllWithoutComparator2() {
608-
Item a = new Item(1), b = new Item(1), c = new Item(2), d = new Item(2),
605+
Item a = new Item(1), b = new Item(1), c = new Item(2), d = new Item(2),
609606
e = new Item(2), f = new Item(3), g = new Item(3), h = new Item(3),
610607
i = new Item(4), j = new Item(4), k = new Item(5), l = new Item(6),
611608
m = new Item(7), n = new Item(7), o = new Item(7), p = new Item(7);
@@ -1313,73 +1310,29 @@ public void testDropping() {
13131310
assertEquals(Seq.of("c", "d").toList(), Seq.of("a", "b", "c", "d").collect(Agg.dropping(2)).toList());
13141311
}
13151312

1316-
//CS304 (manually written) Issue link: https://github.com/jOOQ/jOOL/issues/360
1317-
1318-
/**
1319-
* Test the Seq with numbers.
1320-
*
1321-
* @result The result will be the standard deviation and variance.
1322-
*/
13231313
@Test
13241314
public void testStddevAndVarianceWithNumber() {
1325-
DecimalFormat df = new java.text.DecimalFormat("#.000");
1326-
1327-
Function<Integer, Double> mapping = Double::valueOf;
1328-
1329-
assertEquals(Optional.empty(), Seq.<Integer>of().collect(Agg.varianceBy(mapping)));
1330-
assertEquals(Optional.empty(), Seq.<Integer>of().collect(Agg.stddevBy(mapping)));
1331-
assertEquals(Optional.of(0.0), Seq.of(1).collect(Agg.varianceBy(mapping)));
1332-
assertEquals(Optional.of(0.0), Seq.of(1).collect(Agg.stddevBy(mapping)));
1333-
assertEquals(Optional.of(0.0), Seq.of(1, 1, 1, 1).collect(Agg.varianceBy(mapping)));
1334-
assertEquals(Optional.of(0.0), Seq.of(1, 1, 1, 1).collect(Agg.stddevBy(mapping)));
1335-
assertEquals(Optional.of(1.0), Seq.of(1, 1, 3, 3).collect(Agg.varianceBy(mapping)));
1336-
assertEquals(Optional.of(1.0), Seq.of(1, 1, 3, 3).collect(Agg.stddevBy(mapping)));
1337-
assertEquals(Optional.of(1.250), Seq.of(1, 2, 3, 4).collect(Agg.varianceBy(mapping)));
1338-
assertEquals(Optional.of(1.118), Optional.of(Double.parseDouble(df.format(Seq.of(1, 2, 3, 4).collect(Agg.stddevBy(mapping)).get()))));
1339-
1340-
1315+
assertEquals(Optional.empty(), Seq.<Double>of().collect(Agg.varianceDouble()));
1316+
assertEquals(Optional.empty(), Seq.<Double>of().collect(Agg.stddevDouble()));
1317+
assertEquals(Optional.of(0.0), Seq.of(1.0).collect(Agg.varianceDouble()));
1318+
assertEquals(Optional.of(0.0), Seq.of(1.0).collect(Agg.stddevDouble()));
1319+
assertEquals(Optional.of(0.0), Seq.of(1.0, 1.0, 1.0, 1.0).collect(Agg.varianceDouble()));
1320+
assertEquals(Optional.of(0.0), Seq.of(1.0, 1.0, 1.0, 1.0).collect(Agg.stddevDouble()));
1321+
assertEquals(Optional.of(1.0), Seq.of(1.0, 1.0, 3.0, 3.0).collect(Agg.varianceDouble()));
1322+
assertEquals(Optional.of(1.0), Seq.of(1.0, 1.0, 3.0, 3.0).collect(Agg.stddevDouble()));
1323+
assertEquals(Optional.of(1.250), Seq.of(1.0, 2.0, 3.0, 4.0).collect(Agg.varianceDouble()));
13411324
}
13421325

1343-
//CS304 (manually written) Issue link: https://github.com/jOOQ/jOOL/issues/360
1344-
1345-
/**
1346-
* Test the Seq with numbers.
1347-
*
1348-
* @result The result will be the standard deviation and variance.
1349-
*/
13501326
@Test
13511327
public void testStddevAndVarianceWithObject() {
1352-
DecimalFormat df = new java.text.DecimalFormat("#.000");
1353-
1354-
class Node {
1355-
final int value;
1356-
1357-
Node(int value) {
1358-
this.value = value;
1359-
}
1360-
1361-
public Double function() {
1362-
return (double) value;
1363-
}
1364-
1365-
public int getValue() {
1366-
return this.value;
1367-
}
1368-
}
1369-
Function<Node, Double> mapping = e -> (double) e.getValue();
1370-
1371-
assertEquals(Optional.empty(), Seq.<Node>of().collect(Agg.varianceBy(mapping)));
1372-
assertEquals(Optional.empty(), Seq.<Node>of().collect(Agg.stddevBy(mapping)));
1373-
assertEquals(Optional.of(0.0), Seq.of(new Node(1)).collect(Agg.varianceBy(mapping)));
1374-
assertEquals(Optional.of(0.0), Seq.of(new Node(1)).collect(Agg.stddevBy(mapping)));
1375-
assertEquals(Optional.of(0.0), Seq.of(new Node(1), new Node(1), new Node(1), new Node(1)).collect(Agg.varianceBy(mapping)));
1376-
assertEquals(Optional.of(0.0), Seq.of(new Node(1), new Node(1), new Node(1), new Node(1)).collect(Agg.stddevBy(mapping)));
1377-
assertEquals(Optional.of(1.0), Seq.of(new Node(1), new Node(1), new Node(3), new Node(3)).collect(Agg.varianceBy(mapping)));
1378-
assertEquals(Optional.of(1.0), Seq.of(new Node(1), new Node(1), new Node(3), new Node(3)).collect(Agg.stddevBy(mapping)));
1379-
assertEquals(Optional.of(1.250), Seq.of(new Node(1), new Node(2), new Node(3), new Node(4)).collect(Agg.varianceBy(mapping)));
1380-
assertEquals(Optional.of(1.118), Optional.of(Double.parseDouble(df.format(Seq.of(new Node(1), new Node(2), new Node(3), new Node(4)).collect(Agg.stddevBy(mapping)).get()))));
1381-
1328+
assertEquals(Optional.empty(), Seq.<Item>of().collect(Agg.varianceDouble(e -> (double) e.val)));
1329+
assertEquals(Optional.empty(), Seq.<Item>of().collect(Agg.stddevDouble(e -> (double) e.val)));
1330+
assertEquals(Optional.of(0.0), Seq.of(new Item(1)).collect(Agg.varianceDouble(e -> (double) e.val)));
1331+
assertEquals(Optional.of(0.0), Seq.of(new Item(1)).collect(Agg.stddevDouble(e -> (double) e.val)));
1332+
assertEquals(Optional.of(0.0), Seq.of(new Item(1), new Item(1), new Item(1), new Item(1)).collect(Agg.varianceDouble(e -> (double) e.val)));
1333+
assertEquals(Optional.of(0.0), Seq.of(new Item(1), new Item(1), new Item(1), new Item(1)).collect(Agg.stddevDouble(e -> (double) e.val)));
1334+
assertEquals(Optional.of(1.0), Seq.of(new Item(1), new Item(1), new Item(3), new Item(3)).collect(Agg.varianceDouble(e -> (double) e.val)));
1335+
assertEquals(Optional.of(1.0), Seq.of(new Item(1), new Item(1), new Item(3), new Item(3)).collect(Agg.stddevDouble(e -> (double) e.val)));
1336+
assertEquals(Optional.of(1.250), Seq.of(new Item(1), new Item(2), new Item(3), new Item(4)).collect(Agg.varianceDouble(e -> (double) e.val)));
13821337
}
1383-
1384-
13851338
}

0 commit comments

Comments
 (0)