
Commit 7ece631

DEVEXP-484 Now supporting all aggregates
1 parent f6d76ac commit 7ece631

16 files changed: +722 additions, −85 deletions

docs/reading.md

Lines changed: 6 additions & 10 deletions

@@ -110,24 +110,20 @@ fixed via changes to the options passed to the connector should be reported as n
 
 The Spark connector framework supports pushing down multiple operations to the connector data source. This can
 often provide a significant performance boost by allowing the data source to perform the operation, which can result in
-both fewer rows returned to Spark and less work for Spark to perform. The connector supports pushing
+both fewer rows returned to Spark and less work for Spark to perform. The MarkLogic Spark connector supports pushing
 down the following operations to MarkLogic:
 
 - `count`
 - `drop` and `select`
 - `filter` and `where`
-- `groupBy` when followed by `count`
+- `groupBy` plus any of `avg`, `count`, `max`, `mean`, `min`, or `sum`
 - `limit`
 - `orderBy` and `sort`
 
-For each of the above operations, the user's Optic query is enhanced to include the associated Optic function.
-Note that if multiple partitions are used to perform the `read` operation, each
-partition will apply the above functions on the rows that it retrieves from MarkLogic. Spark will then merge the results
-from each partition and re-apply the function calls as necessary to ensure that the correct response is returned.
-
-If either `count` or `groupBy` and `count` are pushed down, the connector will make a single request to MarkLogic to
-resolve the query (thus ignoring the number of partitions and batch size that may have been configured; see below
-for more information on these options), ensuring that a single count or set of counts is returned to Spark.
+For each of the above operations, the user's Optic query is enhanced to include the associated Optic function. Note
+that if multiple partitions are used to perform the `read` operation, each partition will apply the above
+functions on the rows that it retrieves from MarkLogic. Spark will then merge the results from each partition and
+apply the aggregation to ensure that the correct response is returned.
 
 In the following example, every operation after `load()` is pushed down to MarkLogic, thereby resulting in far fewer
 rows being returned to Spark and far less work having to be done by Spark:
src/main/java/com/marklogic/spark/reader/MarkLogicScanBuilder.java

Lines changed: 59 additions & 24 deletions

@@ -17,11 +17,14 @@
 
 import com.marklogic.spark.reader.filter.FilterFactory;
 import com.marklogic.spark.reader.filter.OpticFilter;
-import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.SortOrder;
-import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
+import org.apache.spark.sql.connector.expressions.aggregate.Avg;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
 import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+import org.apache.spark.sql.connector.expressions.aggregate.Sum;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
 import org.apache.spark.sql.connector.read.SupportsPushDownAggregates;
@@ -36,7 +39,10 @@
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
+import java.util.stream.Stream;
 
 public class MarkLogicScanBuilder implements ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit,
     SupportsPushDownTopN, SupportsPushDownAggregates, SupportsPushDownRequiredColumns {
@@ -46,6 +52,15 @@ public class MarkLogicScanBuilder implements ScanBuilder, SupportsPushDownFilter
     private ReadContext readContext;
     private List<Filter> pushedFilters;
 
+    private final static Set<Class> SUPPORTED_AGGREGATE_FUNCTIONS = new HashSet() {{
+        add(Avg.class);
+        add(Count.class);
+        add(CountStar.class);
+        add(Max.class);
+        add(Min.class);
+        add(Sum.class);
+    }};
+
     public MarkLogicScanBuilder(ReadContext readContext) {
         this.readContext = readContext;
     }
@@ -138,38 +153,46 @@ public boolean isPartiallyPushed() {
         return readContext.getBucketCount() > 1;
     }
 
+    /**
+     * Per the Spark javadocs, this should return true if we can push down the entire aggregation. This is only
+     * possible if every aggregation function is supported and if only one request will be made to MarkLogic. If
+     * multiple requests are made to MarkLogic (based on the user-defined partition count and batch size), then
+     * Spark has to apply the aggregation against the combined set of rows returned from all requests to MarkLogic.
+     *
+     * @param aggregation
+     * @return
+     */
     @Override
-    public boolean pushAggregation(Aggregation aggregation) {
+    public boolean supportCompletePushDown(Aggregation aggregation) {
         if (readContext.planAnalysisFoundNoRows()) {
             return false;
         }
-        if (supportCompletePushDown(aggregation)) {
-            if (aggregation.groupByExpressions().length > 0) {
-                if (logger.isInfoEnabled()) {
-                    logger.info("Pushing down groupBy + count on: {}", Arrays.asList(aggregation.groupByExpressions()));
-                }
-                readContext.pushDownGroupByCount(aggregation.groupByExpressions());
-            } else {
-                if (logger.isInfoEnabled()) {
-                    logger.info("Pushing down count()");
-                }
-                readContext.pushDownCount();
-            }
-            return true;
+
+        if (hasUnsupportedAggregateFunction(aggregation)) {
+            logger.info("Aggregation contains one or more unsupported functions, " +
                "so not pushing aggregation to MarkLogic: {}", describeAggregation(aggregation));
+            return false;
         }
-        return false;
+
+        if (readContext.getBucketCount() > 1) {
+            logger.info("Multiple requests will be made to MarkLogic; aggregation will be applied by Spark as well: {}",
                describeAggregation(aggregation));
+            return false;
+        }
+        return true;
     }
 
     @Override
-    public boolean supportCompletePushDown(Aggregation aggregation) {
-        if (readContext.planAnalysisFoundNoRows()) {
+    public boolean pushAggregation(Aggregation aggregation) {
+        // For the initial 2.0 release, there aren't any known unsupported aggregate functions that can be called
+        // after a "groupBy". If one is detected though, the aggregation won't be pushed down as it's uncertain if
+        // pushing it down would produce the correct results.
+        if (readContext.planAnalysisFoundNoRows() || hasUnsupportedAggregateFunction(aggregation)) {
            return false;
        }
-        AggregateFunc[] expressions = aggregation.aggregateExpressions();
-        // If a count() is used, it's supported if there's no groupBy - i.e. just doing a count() by itself -
-        // and supported with 1 to many groupBy's - e.g. groupBy("column", "someOtherColumn").count().
-        // Other aggregate functions will be supported in the near future.
-        return expressions.length == 1 && expressions[0] instanceof CountStar;
+        logger.info("Pushing down aggregation: {}", describeAggregation(aggregation));
+        readContext.pushDownAggregation(aggregation);
+        return true;
    }
 
    @Override
@@ -189,4 +212,16 @@ public void pruneColumns(StructType requiredSchema) {
            readContext.pushDownRequiredSchema(requiredSchema);
        }
    }
+
+    private boolean hasUnsupportedAggregateFunction(Aggregation aggregation) {
+        return Stream
+            .of(aggregation.aggregateExpressions())
+            .anyMatch(func -> !SUPPORTED_AGGREGATE_FUNCTIONS.contains(func.getClass()));
+    }
+
+    private String describeAggregation(Aggregation aggregation) {
+        return String.format("groupBy: %s; aggregates: %s",
+            Arrays.asList(aggregation.groupByExpressions()),
+            Arrays.asList(aggregation.aggregateExpressions()));
+    }
 }
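
For context, the sketch below shows how an `Aggregation` from Spark's DataSource V2 API can be checked against a supported-function set, mirroring `hasUnsupportedAggregateFunction` above. It is a standalone illustration, not the connector's class: the `SUPPORTED` set is trimmed for brevity, `decide` only approximates the `supportCompletePushDown`/`pushAggregation` handshake, and Spark 3.3+ expression classes are assumed.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Stream;

import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
import org.apache.spark.sql.connector.expressions.aggregate.Sum;

public class AggregationPushDownSketch {

    // Trimmed version of the commit's supported-function set; Avg, Count, Max, and Min are omitted for brevity.
    private static final Set<Class<? extends AggregateFunc>> SUPPORTED =
        new HashSet<>(Arrays.asList(CountStar.class, Sum.class));

    // Same shape as hasUnsupportedAggregateFunction() in the diff above.
    static boolean hasUnsupportedAggregateFunction(Aggregation aggregation) {
        return Stream.of(aggregation.aggregateExpressions())
            .anyMatch(func -> !SUPPORTED.contains(func.getClass()));
    }

    // Rough approximation of the decision the connector makes: reject unsupported functions,
    // and only claim a complete pushdown when a single request (bucket) will be made.
    static String decide(Aggregation aggregation, int bucketCount) {
        if (hasUnsupportedAggregateFunction(aggregation)) {
            return "not pushed down";
        }
        return bucketCount > 1 ? "pushed down, Spark re-aggregates" : "completely pushed down";
    }

    public static void main(String[] args) {
        NamedReference salary = Expressions.column("Salary");
        Aggregation aggregation = new Aggregation(
            new AggregateFunc[]{new CountStar(), new Sum(salary, false)},
            new NamedReference[]{Expressions.column("City")});
        System.out.println(decide(aggregation, 1)); // completely pushed down
        System.out.println(decide(aggregation, 4)); // pushed down, Spark re-aggregates
    }
}
```
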

src/main/java/com/marklogic/spark/reader/PlanUtil.java

Lines changed: 80 additions & 17 deletions

@@ -23,13 +23,24 @@
 import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.expressions.SortDirection;
 import org.apache.spark.sql.connector.expressions.SortOrder;
+import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
+import org.apache.spark.sql.connector.expressions.aggregate.Avg;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
+import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+import org.apache.spark.sql.connector.expressions.aggregate.Sum;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Consumer;
+import java.util.function.Function;
 
 /**
  * Methods for modifying a serialized Optic plan. These were moved here both to facilitate unit testing for some of them
@@ -41,27 +52,59 @@ public abstract class PlanUtil {
 
     private final static ObjectMapper objectMapper = new ObjectMapper();
 
-    static ObjectNode buildGroupByCount() {
-        return newOperation("group-by", args -> {
-            args.add(objectMapper.nullNode());
-            addCountArg(args);
+    private static Map<Class<? extends AggregateFunc>, Function<AggregateFunc, OpticFunction>> aggregateFunctionHandlers;
+
+    // Construct the mapping of Spark aggregate function instances to OpticFunction instances that are used to build
+    // the corresponding serialized Optic function reference.
+    static {
+        aggregateFunctionHandlers = new HashMap<>();
+        aggregateFunctionHandlers.put(Avg.class, func -> {
+            Avg avg = (Avg) func;
+            return new OpticFunction("avg", avg.column(), avg.isDistinct());
+        });
+        aggregateFunctionHandlers.put(Count.class, func -> {
+            Count count = (Count)func;
+            return new OpticFunction("count", count.column(), count.isDistinct());
+        });
+        aggregateFunctionHandlers.put(Max.class, func -> new OpticFunction("max", ((Max) func).column()));
+        aggregateFunctionHandlers.put(Min.class, func -> new OpticFunction("min", ((Min) func).column()));
+        aggregateFunctionHandlers.put(Sum.class, func -> {
+            Sum sum = (Sum) func;
+            return new OpticFunction("sum", sum.column(), sum.isDistinct());
         });
     }
 
-    static ObjectNode buildGroupByCount(List<String> columnNames) {
-        return newOperation("group-by", args -> {
-            ArrayNode columns = args.addArray();
+    static ObjectNode buildGroupByAggregation(List<String> columnNames, Aggregation aggregation) {
+        return newOperation("group-by", groupByArgs -> {
+            ArrayNode columns = groupByArgs.addArray();
             columnNames.forEach(columnName -> populateSchemaCol(columns.addObject(), columnName));
-            addCountArg(args);
-        });
-    }
 
-    private static void addCountArg(ArrayNode args) {
-        args.addObject().put("ns", "op").put("fn", "count").putArray("args")
-            // "count" is used as the column name as that's what Spark uses when the operation is not pushed down.
-            .add("count")
-            // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
-            .add(objectMapper.nullNode());
+            ArrayNode aggregates = groupByArgs.addArray();
+            for (AggregateFunc func : aggregation.aggregateExpressions()) {
+                // Need special handling for CountStar, as it does not have a column name with it.
+                if (func instanceof CountStar) {
+                    aggregates.addObject().put("ns", "op").put("fn", "count").putArray("args")
+                        // "count" is used as the column name as that's what Spark uses when the operation is not pushed down.
+                        .add("count")
+                        // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
+                        .add(objectMapper.nullNode());
+                } else if (aggregateFunctionHandlers.containsKey(func.getClass())) {
+                    OpticFunction opticFunction = aggregateFunctionHandlers.get(func.getClass()).apply(func);
+                    ArrayNode aggregateArgs = aggregates
+                        .addObject().put("ns", "op").put("fn", opticFunction.functionName)
+                        .putArray("args");
+                    aggregateArgs.add(func.toString());
+                    populateSchemaCol(aggregateArgs.addObject(), opticFunction.columnName);
+                    // TODO This is the correct JSON to add, but have not found a way to create an AggregateFunc that
+                    // returns "true" for isDistinct().
+                    if (opticFunction.distinct) {
+                        aggregateArgs.addObject().put("values", "distinct");
+                    }
+                } else {
+                    logger.info("Unsupported aggregate function, will not be pushed to Optic: {}", func);
+                }
+            }
+        });
     }
 
     static ObjectNode buildLimit(int limit) {
@@ -71,7 +114,7 @@ static ObjectNode buildLimit(int limit) {
     static ObjectNode buildOrderBy(SortOrder[] sortOrders) {
         return newOperation("order-by", args -> {
             ArrayNode innerArgs = args.addArray();
-            for (SortOrder sortOrder: sortOrders) {
+            for (SortOrder sortOrder : sortOrders) {
                 final String direction = SortDirection.ASCENDING.equals(sortOrder.direction()) ? "asc" : "desc";
                 ArrayNode orderByArgs = innerArgs.addObject().put("ns", "op").put("fn", direction).putArray("args");
                 String columnName = expressionToColumnName(sortOrder.expression());
@@ -170,4 +213,24 @@ static String expressionToColumnName(Expression expression) {
         }
         return fieldNames[0];
     }
+
+    /**
+     * Captures the name of an Optic function and the column name based on a Spark AggregateFunc's Expression. Used
+     * to simplify building a serialized Optic function reference.
+     */
+    private static class OpticFunction {
+        final String functionName;
+        final String columnName;
+        final boolean distinct;
+
+        OpticFunction(String functionName, Expression column) {
+            this(functionName, column, false);
+        }
+
+        OpticFunction(String functionName, Expression column, boolean distinct) {
+            this.functionName = functionName;
+            this.columnName = expressionToColumnName(column);
+            this.distinct = distinct;
+        }
+    }
 }
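
To make the serialized plan easier to picture, the following standalone Jackson sketch approximates the `group-by` operation JSON that `buildGroupByAggregation` could emit for a `groupBy("City").avg("Salary")` pushdown. The plain `op`/`col` column references are an assumption for illustration only; the actual structure comes from `newOperation` and `populateSchemaCol`, which may emit schema/view-qualified references.

```java
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;

public class GroupByPlanSketch {

    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();

        // The group-by operation itself: {"ns": "op", "fn": "group-by", "args": [...]}.
        ObjectNode groupBy = mapper.createObjectNode();
        groupBy.put("ns", "op").put("fn", "group-by");
        ArrayNode args = groupBy.putArray("args");

        // First argument: the grouping columns. A simple op.col reference is assumed here;
        // the connector's populateSchemaCol may produce a schema/view-qualified reference instead.
        ArrayNode groupingColumns = args.addArray();
        groupingColumns.addObject().put("ns", "op").put("fn", "col").putArray("args").add("City");

        // Second argument: the aggregate expressions. Spark's name for the result column,
        // e.g. "avg(Salary)", is used as the output column name, mirroring aggregateArgs.add(func.toString()).
        ArrayNode aggregates = args.addArray();
        ArrayNode avgArgs = aggregates.addObject().put("ns", "op").put("fn", "avg").putArray("args");
        avgArgs.add("avg(Salary)");
        avgArgs.addObject().put("ns", "op").put("fn", "col").putArray("args").add("Salary");

        System.out.println(mapper.writerWithDefaultPrettyPrinter().writeValueAsString(groupBy));
    }
}
```
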
