
Commit 7921c80

Merge pull request #47 from marklogic/feature/groupBy
Fixes for count/orderBy on qualified column names
2 parents d91a92f + 5cc8771 commit 7921c80
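For orientation, here is a minimal sketch (not part of this commit) of the Spark usage that the groupBy + count push-down targets, based on the new PushDownGroupByCountTest below. The connector format name, option keys, and connection URI are illustrative assumptions; the tests themselves use the project's Options constants and test helpers instead.

// Illustrative only - format name, option keys, and connection URI are assumptions,
// not taken from this commit.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class GroupByCountPushDownExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();
        Dataset<Row> counts = spark.read()
            .format("com.marklogic.spark")                                   // assumed connector name
            .option("spark.marklogic.client.uri", "user:password@host:8003") // hypothetical connection option
            .option("spark.marklogic.read.opticDsl", "op.fromView('Medical', 'Authors', '')") // hypothetical key
            .load()
            .groupBy("CitationID")
            .count()
            .orderBy("CitationID");
        // With this change, the groupBy/count (and the count-based orderBy, when it precedes limit)
        // are pushed down, so MarkLogic returns one row per CitationID rather than every matching row.
        counts.show();
    }
}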

7 files changed: +264 -52 lines changed


src/main/java/com/marklogic/spark/reader/MarkLogicPartitionReader.java

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ public boolean next() {
         if (rowIterator.hasNext()) {
             return true;
         } else {
-            if (logger.isTraceEnabled()) {
-                logger.trace("Count of rows for partition {} and bucket {}: {}", this.partition,
+            if (logger.isDebugEnabled()) {
+                logger.debug("Count of rows for partition {} and bucket {}: {}", this.partition,
                     this.partition.buckets.get(nextBucketIndex - 1), currentBucketRowCount);
             }
             currentBucketRowCount = 0;

src/main/java/com/marklogic/spark/reader/MarkLogicScanBuilder.java

Lines changed: 27 additions & 9 deletions
@@ -17,6 +17,7 @@
 
 import com.marklogic.spark.reader.filter.FilterFactory;
 import com.marklogic.spark.reader.filter.OpticFilter;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.SortOrder;
 import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
@@ -140,28 +141,45 @@ public boolean pushOffset(int offset) {
     @Override
     public boolean pushAggregation(Aggregation aggregation) {
         if (supportCompletePushDown(aggregation)) {
-            if (logger.isDebugEnabled()) {
-                logger.debug("Pushing down count()");
+            if (aggregation.groupByExpressions().length > 0) {
+                Expression expr = aggregation.groupByExpressions()[0];
+                if (logger.isDebugEnabled()) {
+                    logger.debug("Pushing down by groupBy + count on: {}", expr.describe());
+                }
+                readContext.pushDownGroupByCount(expr);
+            } else {
+                if (logger.isDebugEnabled()) {
+                    logger.debug("Pushing down count()");
+                }
+                readContext.pushDownCount();
             }
-            readContext.pushDownCount();
             return true;
         }
         return false;
     }
 
     @Override
     public boolean supportCompletePushDown(Aggregation aggregation) {
-        // Only a single "count()" call is supported so far. Will expand as we add support for other aggregations,
-        // including support for groupBy() + count().
         AggregateFunc[] expressions = aggregation.aggregateExpressions();
-        return expressions.length == 1 && expressions[0] instanceof CountStar && aggregation.groupByExpressions().length == 0;
+        if (expressions.length == 1 && expressions[0] instanceof CountStar) {
+            // If a count() is used, it's supported if there's no groupBy - i.e. just doing a count() by itself -
+            // and supported with a single groupBy - e.g. groupBy("column").count().
+            return aggregation.groupByExpressions().length < 2;
+        }
+        return false;
     }
 
     @Override
     public void pruneColumns(StructType requiredSchema) {
-        if (logger.isDebugEnabled()) {
-            logger.debug("Pushing down required schema: {}", requiredSchema.json());
+        if (requiredSchema.equals(readContext.getSchema())) {
+            if (logger.isDebugEnabled()) {
+                logger.debug("The schema to push down is equal to the existing schema, so not pushing it down.");
+            }
+        } else {
+            if (logger.isDebugEnabled()) {
+                logger.debug("Pushing down required schema: {}", requiredSchema.json());
+            }
+            readContext.pushDownRequiredSchema(requiredSchema);
         }
-        readContext.pushDownRequiredSchema(requiredSchema);
     }
 }

src/main/java/com/marklogic/spark/reader/PlanUtil.java

Lines changed: 54 additions & 15 deletions
@@ -4,10 +4,14 @@
 import com.fasterxml.jackson.databind.node.ArrayNode;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import com.marklogic.spark.reader.filter.OpticFilter;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.expressions.SortDirection;
 import org.apache.spark.sql.connector.expressions.SortOrder;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.List;
 import java.util.function.Consumer;
@@ -18,15 +22,24 @@
  */
 abstract class PlanUtil {
 
+    private final static Logger logger = LoggerFactory.getLogger(PlanUtil.class);
+
     private final static ObjectMapper objectMapper = new ObjectMapper();
 
     static ObjectNode buildGroupByCount() {
         return newOperation("group-by", args -> args
             .add(objectMapper.nullNode())
-            // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
             .addObject().put("ns", "op").put("fn", "count").putArray("args").add("count").add(objectMapper.nullNode()));
     }
 
+    static ObjectNode buildGroupByCount(String columnName) {
+        return newOperation("group-by", args -> {
+            populateSchemaCol(args.addObject(), columnName);
+            // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
+            args.addObject().put("ns", "op").put("fn", "count").putArray("args").add("count").add(objectMapper.nullNode());
+        });
+    }
+
     static ObjectNode buildLimit(int limit) {
         return newOperation("limit", args -> args.add(limit));
     }
@@ -37,30 +50,41 @@ static ObjectNode buildOffset(int offset) {
 
     static ObjectNode buildOrderBy(SortOrder sortOrder) {
         final String direction = SortDirection.ASCENDING.equals(sortOrder.direction()) ? "asc" : "desc";
-        final String columnName = sortOrder.expression().describe();
-        return newOperation("order-by", args -> args.addObject()
-            .put("ns", "op").put("fn", direction)
-            .putArray("args").addObject()
-            .put("ns", "op").put("fn", "col").putArray("args").add(columnName));
+        final String columnName = expressionToColumnName(sortOrder.expression());
+        return newOperation("order-by", args -> {
+            ArrayNode orderByArgs = args.addObject().put("ns", "op").put("fn", direction).putArray("args");
+            // This may be a bad hack to account for when the user does a groupBy/count/orderBy/limit, which does not
+            // seem like the correct approach - the Spark ScanBuilder javadocs indicate that it should be limit/orderBy
+            // instead. In the former scenario, we get "COUNT(*)" as the expression to order by, and we know that's not
+            // the column name.
+            if (logger.isDebugEnabled()) {
+                logger.debug("Adjusting `COUNT(*)` column to be `count`");
+            }
+            populateSchemaCol(orderByArgs.addObject(), "COUNT(*)".equals(columnName) ? "count" : columnName);
+        });
     }
 
     static ObjectNode buildSelect(StructType schema) {
         return newOperation("select", args -> {
             ArrayNode innerArgs = args.addArray();
             for (StructField field : schema.fields()) {
-                ArrayNode colArgs = innerArgs.addObject().put("ns", "op").put("fn", "schema-col").putArray("args");
-                String[] parts = field.name().split("\\.");
-                if (parts.length == 3) {
-                    colArgs.add(parts[0]).add(parts[1]).add(parts[2]);
-                } else if (parts.length == 2) {
-                    colArgs.add(objectMapper.nullNode()).add(parts[0]).add(parts[1]);
-                } else {
-                    colArgs.add(objectMapper.nullNode()).add(objectMapper.nullNode()).add(parts[0]);
-                }
+                populateSchemaCol(innerArgs.addObject(), field.name());
             }
         });
     }
 
+    private static void populateSchemaCol(ObjectNode node, String columnName) {
+        ArrayNode colArgs = node.put("ns", "op").put("fn", "schema-col").putArray("args");
+        String[] parts = columnName.split("\\.");
+        if (parts.length == 3) {
+            colArgs.add(parts[0]).add(parts[1]).add(parts[2]);
+        } else if (parts.length == 2) {
+            colArgs.add(objectMapper.nullNode()).add(parts[0]).add(parts[1]);
+        } else {
+            colArgs.add(objectMapper.nullNode()).add(objectMapper.nullNode()).add(parts[0]);
+        }
+    }
+
     static ObjectNode buildWhere(List<OpticFilter> opticFilters) {
         return newOperation("where", args -> {
             // If there's only one filter, can toss it into the "where" clause. Else, toss an "and" into the "where" and
@@ -78,4 +102,19 @@ private static ObjectNode newOperation(String name, Consumer<ArrayNode> withArgs
         withArgs.accept(operation.putArray("args"));
         return operation;
     }
+
+    static String expressionToColumnName(Expression expression) {
+        // The structure of an Expression isn't well-understood yet. But when it refers to a single column, the
+        // column name can be found in the below manner. Anything else is not supported yet.
+        NamedReference[] refs = expression.references();
+        if (refs == null || refs.length < 1) {
+            return expression.describe();
+        }
+        String[] fieldNames = refs[0].fieldNames();
+        if (fieldNames.length != 1) {
+            throw new IllegalArgumentException("Unsupported expression: " + expression + "; expecting expression " +
+                "to have exactly one field name.");
+        }
+        return fieldNames[0];
+    }
 }
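As an aside (editor's illustration, not part of the diff): given populateSchemaCol and newOperation above, the new buildGroupByCount(String) overload should serialize a view-qualified column such as "example.CitationID" to roughly the JSON sketched in the comments below. The exact shape is inferred from the code in this commit, not verified against MarkLogic.

package com.marklogic.spark.reader;

import com.fasterxml.jackson.databind.node.ObjectNode;

// Hypothetical sketch; placed in the same package because PlanUtil and its methods are package-private.
public class PlanUtilGroupBySketch {
    public static void main(String[] args) {
        ObjectNode op = PlanUtil.buildGroupByCount("example.CitationID");
        // Expected shape (an inference from newOperation/populateSchemaCol):
        // {"ns":"op","fn":"group-by","args":[
        //     {"ns":"op","fn":"schema-col","args":[null,"example","CitationID"]},
        //     {"ns":"op","fn":"count","args":["count",null]}]}
        System.out.println(op.toString());
    }
}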

src/main/java/com/marklogic/spark/reader/ReadContext.java

Lines changed: 39 additions & 5 deletions
@@ -30,8 +30,10 @@
 import com.marklogic.spark.Options;
 import com.marklogic.spark.reader.filter.OpticFilter;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.SortOrder;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -172,15 +174,44 @@ void pushDownTopN(SortOrder[] orders, int limit) {
     void pushDownCount() {
         if (planAnalysisFoundAtLeastOneRow()) {
             addOperatorToPlan(PlanUtil.buildGroupByCount());
-
-            // As will likely be the case for all aggregations, the schema needs to be modified. And the plan analysis is
-            // rebuilt to contain a single bucket, as the assumption is that MarkLogic can efficiently determine the count
-            // in a single call to /v1/rows, regardless of the number of matching rows.
+            // As will likely be the case for all aggregations, the schema needs to be modified.
             this.schema = new StructType().add("count", DataTypes.LongType);
-            this.planAnalysis = new PlanAnalysis(this.planAnalysis.boundedPlan);
+            modifyPlanAnalysisToUseSingleBucket();
+        }
+    }
+
+    void pushDownGroupByCount(Expression groupBy) {
+        if (planAnalysisFoundAtLeastOneRow()) {
+            final String columnName = PlanUtil.expressionToColumnName(groupBy);
+            addOperatorToPlan(PlanUtil.buildGroupByCount(columnName));
+
+            StructField columnField = null;
+            for (StructField field : this.schema.fields()) {
+                if (columnName.equals(field.name())) {
+                    columnField = field;
+                    break;
+                }
+            }
+            if (columnField == null) {
+                throw new IllegalArgumentException("Unable to find groupBy column in schema; groupBy expression: " + groupBy.describe());
+            }
+            this.schema = new StructType().add(columnField).add("count", DataTypes.LongType);
+            modifyPlanAnalysisToUseSingleBucket();
         }
     }
 
+    /**
+     * Used when the assumption is that MarkLogic can efficiently execute a plan in a single call to /v1/rows. This is
+     * typically done for "count()" operations. In such a scenario, returning 2 or more rows may produce an incorrect
+     * result as well - for example, for a "count()" call, only the first row will be reported as the count.
+     */
+    private void modifyPlanAnalysisToUseSingleBucket() {
+        if (logger.isDebugEnabled()) {
+            logger.debug("Modifying plan analysis to use a single bucket");
+        }
+        this.planAnalysis = new PlanAnalysis(this.planAnalysis.boundedPlan);
+    }
+
     void pushDownRequiredSchema(StructType requiredSchema) {
         if (planAnalysisFoundAtLeastOneRow()) {
             this.schema = requiredSchema;
@@ -202,6 +233,9 @@ private boolean planAnalysisFoundAtLeastOneRow() {
      * @param operator
      */
     private void addOperatorToPlan(ObjectNode operator) {
+        if (logger.isDebugEnabled()) {
+            logger.debug("Adding operator to plan: {}", operator);
+        }
         ArrayNode operators = (ArrayNode) planAnalysis.boundedPlan.get("$optic").get("args");
         operators.insert(operators.size() - 1, operator);
     }
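For reference (editor's sketch, not part of the diff): after pushDownGroupByCount succeeds, the context's schema becomes the groupBy column followed by a long "count" column. A self-contained illustration, assuming CitationID is exposed as an integer column by the underlying view:

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hypothetical sketch of the schema shape produced by pushDownGroupByCount; the
// CitationID type is an assumption about the underlying view.
public class GroupByCountSchemaSketch {
    public static void main(String[] args) {
        StructType pushedDownSchema = new StructType()
            .add("CitationID", DataTypes.IntegerType)
            .add("count", DataTypes.LongType);
        pushedDownSchema.printTreeString();
        // Expected output:
        // root
        //  |-- CitationID: integer (nullable = true)
        //  |-- count: long (nullable = true)
    }
}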

src/test/java/com/marklogic/spark/reader/PushDownCountTest.java

Lines changed: 0 additions & 20 deletions
@@ -26,24 +26,4 @@ void count() {
             "that regardless of the number of matching rows, MarkLogic can efficiently determine a count in a single " +
             "request.");
     }
-
-    @Test
-    void groupByAndCount() {
-        List<Row> rows = newDefaultReader()
-            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
-            .load()
-            .groupBy("CitationID")
-            .count()
-            .orderBy("CitationID")
-            .collectAsList();
-
-        assertEquals(15, countOfRowsReadFromMarkLogic, "groupBy + count is not yet being pushed down to MarkLogic; " +
-            "only count() by itself is being pushed down. So expecting all rows to be read for now.");
-
-        assertEquals(4, (long) rows.get(0).getAs("count"));
-        assertEquals(4, (long) rows.get(1).getAs("count"));
-        assertEquals(4, (long) rows.get(2).getAs("count"));
-        assertEquals(1, (long) rows.get(3).getAs("count"));
-        assertEquals(2, (long) rows.get(4).getAs("count"));
-    }
 }
src/test/java/com/marklogic/spark/reader/PushDownGroupByCountTest.java

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+package com.marklogic.spark.reader;
+
+import com.marklogic.spark.Options;
+import org.apache.spark.sql.Row;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class PushDownGroupByCountTest extends AbstractPushDownTest {
+
+    @Test
+    void groupByWithNoQualifier() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
+            .load()
+            .groupBy("CitationID")
+            .count()
+            .orderBy("CitationID")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+        assertEquals(1l, (long) rows.get(0).getAs("CitationID"));
+    }
+
+    @Test
+    void groupByWithView() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, "op.fromView('Medical', 'Authors', 'example')")
+            .load()
+            .groupBy("`example.CitationID`")
+            .count()
+            .orderBy("`example.CitationID`")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+        assertEquals(1l, (long) rows.get(0).getAs("example.CitationID"));
+    }
+
+    @Test
+    void groupByWithSchemaAndView() {
+        List<Row> rows = newDefaultReader()
+            .load()
+            .groupBy("`Medical.Authors.CitationID`")
+            .count()
+            .orderBy("`Medical.Authors.CitationID`")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+        assertEquals(1l, (long) rows.get(0).getAs("Medical.Authors.CitationID"));
+    }
+
+    @Test
+    void groupByCountLimitOrderBy() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
+            .load()
+            .groupBy("CitationID")
+            .count()
+            .limit(4)
+            // When the user puts the orderBy after limit, Spark doesn't push the orderBy down. Spark will instead
+            // apply the orderBy itself.
+            .orderBy("count")
+            .collectAsList();
+
+        assertEquals(4, rows.size());
+        assertEquals(4, countOfRowsReadFromMarkLogic);
+        assertEquals(4l, (long) rows.get(0).getAs("CitationID"));
+        assertEquals(1l, (long) rows.get(0).getAs("count"));
+    }
+
+    @Test
+    void groupByCountOrderByLimit() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
+            .load()
+            .groupBy("CitationID")
+            .count()
+            // If the user puts orderBy before limit, Spark will send "COUNT(*)" as the column name for the orderBy.
+            // The connector is expected to translate that into "count"; not sure how it should work otherwise. Spark
+            // is expected to push down the limit as well.
+            .orderBy("count")
+            .limit(4)
+            .collectAsList();
+
+        assertEquals(4, rows.size());
+        assertEquals(4, countOfRowsReadFromMarkLogic);
+        assertEquals(4l, (long) rows.get(0).getAs("CitationID"));
+        assertEquals(1l, (long) rows.get(0).getAs("count"));
+    }
+
+    private void verifyGroupByWasPushedDown(List<Row> rows) {
+        assertEquals(5, countOfRowsReadFromMarkLogic, "groupBy should be pushed down to MarkLogic when used with " +
+            "count, and since there are 5 CitationID values, 5 rows should be returned.");
+
+        assertEquals(4, (long) rows.get(0).getAs("count"));
+        assertEquals(4, (long) rows.get(1).getAs("count"));
+        assertEquals(4, (long) rows.get(2).getAs("count"));
+        assertEquals(1, (long) rows.get(3).getAs("count"));
+        assertEquals(2, (long) rows.get(4).getAs("count"));
+    }
+}
