23
23
import org .apache .spark .sql .connector .expressions .NamedReference ;
24
24
import org .apache .spark .sql .connector .expressions .SortDirection ;
25
25
import org .apache .spark .sql .connector .expressions .SortOrder ;
26
+ import org .apache .spark .sql .connector .expressions .aggregate .AggregateFunc ;
27
+ import org .apache .spark .sql .connector .expressions .aggregate .Aggregation ;
28
+ import org .apache .spark .sql .connector .expressions .aggregate .Avg ;
29
+ import org .apache .spark .sql .connector .expressions .aggregate .Count ;
30
+ import org .apache .spark .sql .connector .expressions .aggregate .CountStar ;
31
+ import org .apache .spark .sql .connector .expressions .aggregate .Max ;
32
+ import org .apache .spark .sql .connector .expressions .aggregate .Min ;
33
+ import org .apache .spark .sql .connector .expressions .aggregate .Sum ;
26
34
import org .apache .spark .sql .types .StructField ;
27
35
import org .apache .spark .sql .types .StructType ;
28
36
import org .slf4j .Logger ;
29
37
import org .slf4j .LoggerFactory ;
30
38
39
+ import java .util .HashMap ;
31
40
import java .util .List ;
41
+ import java .util .Map ;
32
42
import java .util .function .Consumer ;
43
+ import java .util .function .Function ;
33
44
34
45
/**
35
46
* Methods for modifying a serialized Optic plan. These were moved here both to facilitate unit testing for some of them
@@ -41,27 +52,59 @@ public abstract class PlanUtil {
41
52
42
53
// Shared, reusable JSON builder; Jackson's ObjectMapper is thread-safe for read/create usage like this.
private static final ObjectMapper objectMapper = new ObjectMapper();

// Maps each supported Spark aggregate function class to a handler that converts an instance of that class
// into an OpticFunction, which is then used to build the corresponding serialized Optic function reference.
// Declared final: it is populated exactly once in the static initializer below and never reassigned, so a
// mutable non-final static field would only invite accidental replacement.
private static final Map<Class<? extends AggregateFunc>, Function<AggregateFunc, OpticFunction>> aggregateFunctionHandlers;

static {
    aggregateFunctionHandlers = new HashMap<>();
    aggregateFunctionHandlers.put(Avg.class, func -> {
        Avg avg = (Avg) func;
        return new OpticFunction("avg", avg.column(), avg.isDistinct());
    });
    aggregateFunctionHandlers.put(Count.class, func -> {
        Count count = (Count) func;
        return new OpticFunction("count", count.column(), count.isDistinct());
    });
    // Max/Min have no "distinct" notion, so the two-arg OpticFunction constructor is used.
    aggregateFunctionHandlers.put(Max.class, func -> new OpticFunction("max", ((Max) func).column()));
    aggregateFunctionHandlers.put(Min.class, func -> new OpticFunction("min", ((Min) func).column()));
    aggregateFunctionHandlers.put(Sum.class, func -> {
        Sum sum = (Sum) func;
        return new OpticFunction("sum", sum.column(), sum.isDistinct());
    });
}
50
76
51
- static ObjectNode buildGroupByCount (List <String > columnNames ) {
52
- return newOperation ("group-by" , args -> {
53
- ArrayNode columns = args .addArray ();
77
+ static ObjectNode buildGroupByAggregation (List <String > columnNames , Aggregation aggregation ) {
78
+ return newOperation ("group-by" , groupByArgs -> {
79
+ ArrayNode columns = groupByArgs .addArray ();
54
80
columnNames .forEach (columnName -> populateSchemaCol (columns .addObject (), columnName ));
55
- addCountArg (args );
56
- });
57
- }
58
81
59
- private static void addCountArg (ArrayNode args ) {
60
- args .addObject ().put ("ns" , "op" ).put ("fn" , "count" ).putArray ("args" )
61
- // "count" is used as the column name as that's what Spark uses when the operation is not pushed down.
62
- .add ("count" )
63
- // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
64
- .add (objectMapper .nullNode ());
82
+ ArrayNode aggregates = groupByArgs .addArray ();
83
+ for (AggregateFunc func : aggregation .aggregateExpressions ()) {
84
+ // Need special handling for CountStar, as it does not have a column name with it.
85
+ if (func instanceof CountStar ) {
86
+ aggregates .addObject ().put ("ns" , "op" ).put ("fn" , "count" ).putArray ("args" )
87
+ // "count" is used as the column name as that's what Spark uses when the operation is not pushed down.
88
+ .add ("count" )
89
+ // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
90
+ .add (objectMapper .nullNode ());
91
+ } else if (aggregateFunctionHandlers .containsKey (func .getClass ())) {
92
+ OpticFunction opticFunction = aggregateFunctionHandlers .get (func .getClass ()).apply (func );
93
+ ArrayNode aggregateArgs = aggregates
94
+ .addObject ().put ("ns" , "op" ).put ("fn" , opticFunction .functionName )
95
+ .putArray ("args" );
96
+ aggregateArgs .add (func .toString ());
97
+ populateSchemaCol (aggregateArgs .addObject (), opticFunction .columnName );
98
+ // TODO This is the correct JSON to add, but have not found a way to create an AggregateFunc that
99
+ // returns "true" for isDistinct().
100
+ if (opticFunction .distinct ) {
101
+ aggregateArgs .addObject ().put ("values" , "distinct" );
102
+ }
103
+ } else {
104
+ logger .info ("Unsupported aggregate function, will not be pushed to Optic: {}" , func );
105
+ }
106
+ }
107
+ });
65
108
}
66
109
67
110
static ObjectNode buildLimit (int limit ) {
@@ -71,7 +114,7 @@ static ObjectNode buildLimit(int limit) {
71
114
static ObjectNode buildOrderBy (SortOrder [] sortOrders ) {
72
115
return newOperation ("order-by" , args -> {
73
116
ArrayNode innerArgs = args .addArray ();
74
- for (SortOrder sortOrder : sortOrders ) {
117
+ for (SortOrder sortOrder : sortOrders ) {
75
118
final String direction = SortDirection .ASCENDING .equals (sortOrder .direction ()) ? "asc" : "desc" ;
76
119
ArrayNode orderByArgs = innerArgs .addObject ().put ("ns" , "op" ).put ("fn" , direction ).putArray ("args" );
77
120
String columnName = expressionToColumnName (sortOrder .expression ());
@@ -170,4 +213,24 @@ static String expressionToColumnName(Expression expression) {
170
213
}
171
214
return fieldNames [0 ];
172
215
}
216
+
217
+ /**
218
+ * Captures the name of an Optic function and the column name based on a Spark AggregateFunc's Expression. Used
219
+ * to simplify building a serialized Optic function reference.
220
+ */
221
+ private static class OpticFunction {
222
+ final String functionName ;
223
+ final String columnName ;
224
+ final boolean distinct ;
225
+
226
+ OpticFunction (String functionName , Expression column ) {
227
+ this (functionName , column , false );
228
+ }
229
+
230
+ OpticFunction (String functionName , Expression column , boolean distinct ) {
231
+ this .functionName = functionName ;
232
+ this .columnName = expressionToColumnName (column );
233
+ this .distinct = distinct ;
234
+ }
235
+ }
173
236
}
0 commit comments