Commit 5cc8771

Fixes for count/orderBy on qualified column names
And... I implemented groupBy + count while I was in here fixing things, as it was trivial to do. I also tweaked some logging based on testing. But the main thing is the addition of tests verifying that column names work regardless of whether they have no qualifier, a view qualifier only, or a schema + view qualifier.
1 parent d91a92f commit 5cc8771
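
To make the fix concrete: after this commit, a Spark groupBy + count like the sketch below is pushed down to MarkLogic rather than computed by Spark, and the grouping column may be unqualified, view-qualified, or schema-and-view-qualified. This is only a minimal sketch derived from the new tests; it assumes a SparkSession named "spark" and the usual imports, and the format name and client URI option shown here are assumptions, not taken from this commit.

    // Hypothetical reader setup - the format name and client URI option are assumptions;
    // the Options.READ_OPTIC_DSL constant and the Optic query mirror the new tests.
    Dataset<Row> rows = spark.read()
        .format("com.marklogic.spark")
        .option("spark.marklogic.client.uri", "user:password@localhost:8000")
        .option(Options.READ_OPTIC_DSL, "op.fromView('Medical', 'Authors', 'example')")
        .load()
        // Pushed down as an Optic group-by, so MarkLogic returns one row per CitationID
        // instead of every matching row; unqualified and schema.view-qualified names work too.
        .groupBy("`example.CitationID`")
        .count()
        .orderBy("`example.CitationID`");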

File tree

7 files changed: +264 -52 lines changed


src/main/java/com/marklogic/spark/reader/MarkLogicPartitionReader.java

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ public boolean next() {
         if (rowIterator.hasNext()) {
             return true;
         } else {
-            if (logger.isTraceEnabled()) {
-                logger.trace("Count of rows for partition {} and bucket {}: {}", this.partition,
+            if (logger.isDebugEnabled()) {
+                logger.debug("Count of rows for partition {} and bucket {}: {}", this.partition,
                     this.partition.buckets.get(nextBucketIndex - 1), currentBucketRowCount);
             }
             currentBucketRowCount = 0;

src/main/java/com/marklogic/spark/reader/MarkLogicScanBuilder.java

Lines changed: 27 additions & 9 deletions
@@ -17,6 +17,7 @@
 
 import com.marklogic.spark.reader.filter.FilterFactory;
 import com.marklogic.spark.reader.filter.OpticFilter;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.SortOrder;
 import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;

@@ -140,28 +141,45 @@ public boolean pushOffset(int offset) {
     @Override
     public boolean pushAggregation(Aggregation aggregation) {
         if (supportCompletePushDown(aggregation)) {
-            if (logger.isDebugEnabled()) {
-                logger.debug("Pushing down count()");
+            if (aggregation.groupByExpressions().length > 0) {
+                Expression expr = aggregation.groupByExpressions()[0];
+                if (logger.isDebugEnabled()) {
+                    logger.debug("Pushing down by groupBy + count on: {}", expr.describe());
+                }
+                readContext.pushDownGroupByCount(expr);
+            } else {
+                if (logger.isDebugEnabled()) {
+                    logger.debug("Pushing down count()");
+                }
+                readContext.pushDownCount();
             }
-            readContext.pushDownCount();
             return true;
         }
         return false;
     }
 
     @Override
     public boolean supportCompletePushDown(Aggregation aggregation) {
-        // Only a single "count()" call is supported so far. Will expand as we add support for other aggregations,
-        // including support for groupBy() + count().
         AggregateFunc[] expressions = aggregation.aggregateExpressions();
-        return expressions.length == 1 && expressions[0] instanceof CountStar && aggregation.groupByExpressions().length == 0;
+        if (expressions.length == 1 && expressions[0] instanceof CountStar) {
+            // If a count() is used, it's supported if there's no groupBy - i.e. just doing a count() by itself -
+            // and supported with a single groupBy - e.g. groupBy("column").count().
+            return aggregation.groupByExpressions().length < 2;
+        }
+        return false;
     }
 
     @Override
     public void pruneColumns(StructType requiredSchema) {
-        if (logger.isDebugEnabled()) {
-            logger.debug("Pushing down required schema: {}", requiredSchema.json());
+        if (requiredSchema.equals(readContext.getSchema())) {
+            if (logger.isDebugEnabled()) {
+                logger.debug("The schema to push down is equal to the existing schema, so not pushing it down.");
+            }
+        } else {
+            if (logger.isDebugEnabled()) {
+                logger.debug("Pushing down required schema: {}", requiredSchema.json());
+            }
+            readContext.pushDownRequiredSchema(requiredSchema);
         }
-        readContext.pushDownRequiredSchema(requiredSchema);
     }
 }

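For context on the supportCompletePushDown change above, the rule is now: exactly one CountStar aggregate, with at most one groupBy expression. A rough sketch of what that accepts and rejects, assuming a Dataset<Row> named "df" loaded through this connector (the "LastName" column is just a hypothetical second grouping column, not taken from this commit):

    // Completely pushed down after this commit:
    df.count();                                    // single CountStar, no groupBy
    df.groupBy("CitationID").count();              // single CountStar, one groupBy expression
    // Not pushed down - two groupBy expressions, so supportCompletePushDown returns false:
    df.groupBy("CitationID", "LastName").count();
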
src/main/java/com/marklogic/spark/reader/PlanUtil.java

Lines changed: 54 additions & 15 deletions
@@ -4,10 +4,14 @@
 import com.fasterxml.jackson.databind.node.ArrayNode;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import com.marklogic.spark.reader.filter.OpticFilter;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.expressions.SortDirection;
 import org.apache.spark.sql.connector.expressions.SortOrder;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.List;
 import java.util.function.Consumer;

@@ -18,15 +22,24 @@
  */
 abstract class PlanUtil {
 
+    private final static Logger logger = LoggerFactory.getLogger(PlanUtil.class);
+
     private final static ObjectMapper objectMapper = new ObjectMapper();
 
     static ObjectNode buildGroupByCount() {
         return newOperation("group-by", args -> args
             .add(objectMapper.nullNode())
-            // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
             .addObject().put("ns", "op").put("fn", "count").putArray("args").add("count").add(objectMapper.nullNode()));
     }
 
+    static ObjectNode buildGroupByCount(String columnName) {
+        return newOperation("group-by", args -> {
+            populateSchemaCol(args.addObject(), columnName);
+            // Using "null" is the equivalent of "count(*)" - it counts rows, not values.
+            args.addObject().put("ns", "op").put("fn", "count").putArray("args").add("count").add(objectMapper.nullNode());
+        });
+    }
+
     static ObjectNode buildLimit(int limit) {
         return newOperation("limit", args -> args.add(limit));
     }

@@ -37,30 +50,41 @@ static ObjectNode buildOffset(int offset) {
 
     static ObjectNode buildOrderBy(SortOrder sortOrder) {
         final String direction = SortDirection.ASCENDING.equals(sortOrder.direction()) ? "asc" : "desc";
-        final String columnName = sortOrder.expression().describe();
-        return newOperation("order-by", args -> args.addObject()
-            .put("ns", "op").put("fn", direction)
-            .putArray("args").addObject()
-            .put("ns", "op").put("fn", "col").putArray("args").add(columnName));
+        final String columnName = expressionToColumnName(sortOrder.expression());
+        return newOperation("order-by", args -> {
+            ArrayNode orderByArgs = args.addObject().put("ns", "op").put("fn", direction).putArray("args");
+            // This may be a bad hack to account for when the user does a groupBy/count/orderBy/limit, which does not
+            // seem like the correct approach - the Spark ScanBuilder javadocs indicate that it should be limit/orderBy
+            // instead. In the former scenario, we get "COUNT(*)" as the expression to order by, and we know that's not
+            // the column name.
+            if (logger.isDebugEnabled()) {
+                logger.debug("Adjusting `COUNT(*)` column to be `count`");
+            }
+            populateSchemaCol(orderByArgs.addObject(), "COUNT(*)".equals(columnName) ? "count" : columnName);
+        });
     }
 
     static ObjectNode buildSelect(StructType schema) {
         return newOperation("select", args -> {
             ArrayNode innerArgs = args.addArray();
             for (StructField field : schema.fields()) {
-                ArrayNode colArgs = innerArgs.addObject().put("ns", "op").put("fn", "schema-col").putArray("args");
-                String[] parts = field.name().split("\\.");
-                if (parts.length == 3) {
-                    colArgs.add(parts[0]).add(parts[1]).add(parts[2]);
-                } else if (parts.length == 2) {
-                    colArgs.add(objectMapper.nullNode()).add(parts[0]).add(parts[1]);
-                } else {
-                    colArgs.add(objectMapper.nullNode()).add(objectMapper.nullNode()).add(parts[0]);
-                }
+                populateSchemaCol(innerArgs.addObject(), field.name());
             }
         });
     }
 
+    private static void populateSchemaCol(ObjectNode node, String columnName) {
+        ArrayNode colArgs = node.put("ns", "op").put("fn", "schema-col").putArray("args");
+        String[] parts = columnName.split("\\.");
+        if (parts.length == 3) {
+            colArgs.add(parts[0]).add(parts[1]).add(parts[2]);
+        } else if (parts.length == 2) {
+            colArgs.add(objectMapper.nullNode()).add(parts[0]).add(parts[1]);
+        } else {
+            colArgs.add(objectMapper.nullNode()).add(objectMapper.nullNode()).add(parts[0]);
+        }
+    }
+
     static ObjectNode buildWhere(List<OpticFilter> opticFilters) {
         return newOperation("where", args -> {
             // If there's only one filter, can toss it into the "where" clause. Else, toss an "and" into the "where" and

@@ -78,4 +102,19 @@ private static ObjectNode newOperation(String name, Consumer<ArrayNode> withArgs
         withArgs.accept(operation.putArray("args"));
         return operation;
     }
+
+    static String expressionToColumnName(Expression expression) {
+        // The structure of an Expression isn't well-understood yet. But when it refers to a single column, the
+        // column name can be found in the below manner. Anything else is not supported yet.
+        NamedReference[] refs = expression.references();
+        if (refs == null || refs.length < 1) {
+            return expression.describe();
+        }
+        String[] fieldNames = refs[0].fieldNames();
+        if (fieldNames.length != 1) {
+            throw new IllegalArgumentException("Unsupported expression: " + expression + "; expecting expression " +
+                "to have exactly one field name.");
+        }
+        return fieldNames[0];
+    }
 }

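To make the new populateSchemaCol logic concrete, here is a small self-contained sketch (not part of the commit) that reproduces its splitting behavior with Jackson and prints the op.schemaCol arguments produced for the three column-name shapes the new tests exercise:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.fasterxml.jackson.databind.node.ArrayNode;
    import com.fasterxml.jackson.databind.node.ObjectNode;

    public class SchemaColSketch {

        private static final ObjectMapper objectMapper = new ObjectMapper();

        // Mirrors the splitting logic in PlanUtil.populateSchemaCol: a column name may be
        // "column", "view.column", or "schema.view.column"; missing qualifiers become nulls.
        private static ObjectNode schemaCol(String columnName) {
            ObjectNode node = objectMapper.createObjectNode();
            ArrayNode colArgs = node.put("ns", "op").put("fn", "schema-col").putArray("args");
            String[] parts = columnName.split("\\.");
            if (parts.length == 3) {
                colArgs.add(parts[0]).add(parts[1]).add(parts[2]);
            } else if (parts.length == 2) {
                colArgs.add(objectMapper.nullNode()).add(parts[0]).add(parts[1]);
            } else {
                colArgs.add(objectMapper.nullNode()).add(objectMapper.nullNode()).add(parts[0]);
            }
            return node;
        }

        public static void main(String[] args) {
            // {"ns":"op","fn":"schema-col","args":[null,null,"CitationID"]}
            System.out.println(schemaCol("CitationID"));
            // {"ns":"op","fn":"schema-col","args":[null,"example","CitationID"]}
            System.out.println(schemaCol("example.CitationID"));
            // {"ns":"op","fn":"schema-col","args":["Medical","Authors","CitationID"]}
            System.out.println(schemaCol("Medical.Authors.CitationID"));
        }
    }
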
src/main/java/com/marklogic/spark/reader/ReadContext.java

Lines changed: 39 additions & 5 deletions
@@ -30,8 +30,10 @@
 import com.marklogic.spark.Options;
 import com.marklogic.spark.reader.filter.OpticFilter;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.SortOrder;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -172,15 +174,44 @@ void pushDownTopN(SortOrder[] orders, int limit) {
     void pushDownCount() {
         if (planAnalysisFoundAtLeastOneRow()) {
             addOperatorToPlan(PlanUtil.buildGroupByCount());
-
-            // As will likely be the case for all aggregations, the schema needs to be modified. And the plan analysis is
-            // rebuilt to contain a single bucket, as the assumption is that MarkLogic can efficiently determine the count
-            // in a single call to /v1/rows, regardless of the number of matching rows.
+            // As will likely be the case for all aggregations, the schema needs to be modified.
             this.schema = new StructType().add("count", DataTypes.LongType);
-            this.planAnalysis = new PlanAnalysis(this.planAnalysis.boundedPlan);
+            modifyPlanAnalysisToUseSingleBucket();
+        }
+    }
+
+    void pushDownGroupByCount(Expression groupBy) {
+        if (planAnalysisFoundAtLeastOneRow()) {
+            final String columnName = PlanUtil.expressionToColumnName(groupBy);
+            addOperatorToPlan(PlanUtil.buildGroupByCount(columnName));
+
+            StructField columnField = null;
+            for (StructField field : this.schema.fields()) {
+                if (columnName.equals(field.name())) {
+                    columnField = field;
+                    break;
+                }
+            }
+            if (columnField == null) {
+                throw new IllegalArgumentException("Unable to find groupBy column in schema; groupBy expression: " + groupBy.describe());
+            }
+            this.schema = new StructType().add(columnField).add("count", DataTypes.LongType);
+            modifyPlanAnalysisToUseSingleBucket();
         }
     }
 
+    /**
+     * Used when the assumption is that MarkLogic can efficiently execute a plan in a single call to /v1/rows. This is
+     * typically done for "count()" operations. In such a scenario, returning 2 or more rows may produce an incorrect
+     * result as well - for example, for a "count()" call, only the first row will be reported as the count.
+     */
+    private void modifyPlanAnalysisToUseSingleBucket() {
+        if (logger.isDebugEnabled()) {
+            logger.debug("Modifying plan analysis to use a single bucket");
+        }
+        this.planAnalysis = new PlanAnalysis(this.planAnalysis.boundedPlan);
+    }
+
     void pushDownRequiredSchema(StructType requiredSchema) {
         if (planAnalysisFoundAtLeastOneRow()) {
             this.schema = requiredSchema;

@@ -202,6 +233,9 @@ private boolean planAnalysisFoundAtLeastOneRow() {
      * @param operator
      */
     private void addOperatorToPlan(ObjectNode operator) {
+        if (logger.isDebugEnabled()) {
+            logger.debug("Adding operator to plan: {}", operator);
+        }
         ArrayNode operators = (ArrayNode) planAnalysis.boundedPlan.get("$optic").get("args");
         operators.insert(operators.size() - 1, operator);
     }

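The addOperatorToPlan change above only adds logging, but its splicing behavior is what the group-by pushdown relies on: the new operator is inserted just before the last element of the serialized plan's $optic.args array. A minimal sketch of that splicing, using a heavily simplified stand-in for the real bounded plan (the operator names here are placeholders, not what an actual serialized Optic plan contains):

    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.fasterxml.jackson.databind.node.ArrayNode;
    import com.fasterxml.jackson.databind.node.ObjectNode;

    public class AddOperatorSketch {

        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();

            // Heavily simplified stand-in for the boundedPlan held by ReadContext; only the
            // "$optic.args" shape matters here, and the operator names are placeholders.
            ObjectNode plan = (ObjectNode) mapper.readTree(
                "{\"$optic\":{\"ns\":\"op\",\"fn\":\"operators\",\"args\":[" +
                    "{\"ns\":\"op\",\"fn\":\"from-view\"},{\"ns\":\"op\",\"fn\":\"final-operator\"}]}}");

            // Mirrors addOperatorToPlan: splice the new operator in just before the last arg,
            // so whatever operator the plan ends with stays last.
            ObjectNode groupBy = mapper.createObjectNode().put("ns", "op").put("fn", "group-by");
            ArrayNode operators = (ArrayNode) plan.get("$optic").get("args");
            operators.insert(operators.size() - 1, groupBy);

            // Prints: from-view, group-by, final-operator
            for (int i = 0; i < operators.size(); i++) {
                System.out.println(operators.get(i).get("fn").asText());
            }
        }
    }
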
src/test/java/com/marklogic/spark/reader/PushDownCountTest.java

Lines changed: 0 additions & 20 deletions
@@ -26,24 +26,4 @@ void count() {
             "that regardless of the number of matching rows, MarkLogic can efficiently determine a count in a single " +
             "request.");
     }
-
-    @Test
-    void groupByAndCount() {
-        List<Row> rows = newDefaultReader()
-            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
-            .load()
-            .groupBy("CitationID")
-            .count()
-            .orderBy("CitationID")
-            .collectAsList();
-
-        assertEquals(15, countOfRowsReadFromMarkLogic, "groupBy + count is not yet being pushed down to MarkLogic; " +
-            "only count() by itself is being pushed down. So expecting all rows to be read for now.");
-
-        assertEquals(4, (long) rows.get(0).getAs("count"));
-        assertEquals(4, (long) rows.get(1).getAs("count"));
-        assertEquals(4, (long) rows.get(2).getAs("count"));
-        assertEquals(1, (long) rows.get(3).getAs("count"));
-        assertEquals(2, (long) rows.get(4).getAs("count"));
-    }
 }

src/test/java/com/marklogic/spark/reader/PushDownGroupByCountTest.java

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+package com.marklogic.spark.reader;
+
+import com.marklogic.spark.Options;
+import org.apache.spark.sql.Row;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class PushDownGroupByCountTest extends AbstractPushDownTest {
+
+    @Test
+    void groupByWithNoQualifier() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
+            .load()
+            .groupBy("CitationID")
+            .count()
+            .orderBy("CitationID")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+        assertEquals(1l, (long) rows.get(0).getAs("CitationID"));
+    }
+
+    @Test
+    void groupByWithView() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, "op.fromView('Medical', 'Authors', 'example')")
+            .load()
+            .groupBy("`example.CitationID`")
+            .count()
+            .orderBy("`example.CitationID`")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+        assertEquals(1l, (long) rows.get(0).getAs("example.CitationID"));
+    }
+
+    @Test
+    void groupByWithSchemaAndView() {
+        List<Row> rows = newDefaultReader()
+            .load()
+            .groupBy("`Medical.Authors.CitationID`")
+            .count()
+            .orderBy("`Medical.Authors.CitationID`")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+        assertEquals(1l, (long) rows.get(0).getAs("Medical.Authors.CitationID"));
+    }
+
+    @Test
+    void groupByCountLimitOrderBy() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
+            .load()
+            .groupBy("CitationID")
+            .count()
+            .limit(4)
+            // When the user puts the orderBy after limit, Spark doesn't push the orderBy down. Spark will instead
+            // apply the orderBy itself.
+            .orderBy("count")
+            .collectAsList();
+
+        assertEquals(4, rows.size());
+        assertEquals(4, countOfRowsReadFromMarkLogic);
+        assertEquals(4l, (long) rows.get(0).getAs("CitationID"));
+        assertEquals(1l, (long) rows.get(0).getAs("count"));
+    }
+
+    @Test
+    void groupByCountOrderByLimit() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
+            .load()
+            .groupBy("CitationID")
+            .count()
+            // If the user puts orderBy before limit, Spark will send "COUNT(*)" as the column name for the orderBy.
+            // The connector is expected to translate that into "count"; not sure how it should work otherwise. Spark
+            // is expected to push down the limit as well.
+            .orderBy("count")
+            .limit(4)
+            .collectAsList();
+
+        assertEquals(4, rows.size());
+        assertEquals(4, countOfRowsReadFromMarkLogic);
+        assertEquals(4l, (long) rows.get(0).getAs("CitationID"));
+        assertEquals(1l, (long) rows.get(0).getAs("count"));
+    }
+
+    private void verifyGroupByWasPushedDown(List<Row> rows) {
+        assertEquals(5, countOfRowsReadFromMarkLogic, "groupBy should be pushed down to MarkLogic when used with " +
+            "count, and since there are 5 CitationID values, 5 rows should be returned.");
+
+        assertEquals(4, (long) rows.get(0).getAs("count"));
+        assertEquals(4, (long) rows.get(1).getAs("count"));
+        assertEquals(4, (long) rows.get(2).getAs("count"));
+        assertEquals(1, (long) rows.get(3).getAs("count"));
+        assertEquals(2, (long) rows.get(4).getAs("count"));
+    }
+}
