Skip to content

Commit 8fff4e8

Browse files
authored
Merge pull request #45 from marklogic/feature/count-fix
DEVEXP-467 Fix for count(): do not push it down when a groupBy exists
2 parents c1bf751 + 016eb1e commit 8fff4e8

File tree

4 files changed

+29
-5
lines changed

4 files changed

+29
-5
lines changed

src/main/java/com/marklogic/spark/reader/MarkLogicScanBuilder.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,10 @@ public boolean pushAggregation(Aggregation aggregation) {
151151

152152
@Override
153153
public boolean supportCompletePushDown(Aggregation aggregation) {
154-
// Only a single "count()" call is supported so far. Will expand as we add support for other aggregations.
154+
// Only a single "count()" call is supported so far. Will expand as we add support for other aggregations,
155+
// including support for groupBy() + count().
155156
AggregateFunc[] expressions = aggregation.aggregateExpressions();
156-
return expressions != null && expressions.length == 1 && expressions[0] instanceof CountStar;
157+
return expressions.length == 1 && expressions[0] instanceof CountStar && aggregation.groupByExpressions().length == 0;
157158
}
158159

159160
@Override

src/main/java/com/marklogic/spark/reader/PlanUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ static ObjectNode buildGroupByCount() {
2424
return newOperation("group-by", args -> args
2525
.add(objectMapper.nullNode())
2626
// Using "null" is the equivalent of "count(*)" - it counts rows, not values.
27-
.addObject().put("ns", "op").put("fn", "count").putArray("args").add("Count").add(objectMapper.nullNode()));
27+
.addObject().put("ns", "op").put("fn", "count").putArray("args").add("count").add(objectMapper.nullNode()));
2828
}
2929

3030
static ObjectNode buildLimit(int limit) {

src/main/java/com/marklogic/spark/reader/ReadContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ void pushDownCount() {
176176
// As will likely be the case for all aggregations, the schema needs to be modified. And the plan analysis is
177177
// rebuilt to contain a single bucket, as the assumption is that MarkLogic can efficiently determine the count
178178
// in a single call to /v1/rows, regardless of the number of matching rows.
179-
this.schema = new StructType().add("Count", DataTypes.LongType);
179+
this.schema = new StructType().add("count", DataTypes.LongType);
180180
this.planAnalysis = new PlanAnalysis(this.planAnalysis.boundedPlan);
181181
}
182182
}

src/test/java/com/marklogic/spark/reader/PushDownCountTest.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
package com.marklogic.spark.reader;
22

33
import com.marklogic.spark.Options;
4+
import org.apache.spark.sql.Row;
45
import org.junit.jupiter.api.Test;
56

7+
import java.util.List;
8+
69
import static org.junit.jupiter.api.Assertions.assertEquals;
710

811
public class PushDownCountTest extends AbstractPushDownTest {
912

1013
@Test
11-
void test() {
14+
void count() {
1215
long count = newDefaultReader()
1316
.option(Options.READ_NUM_PARTITIONS, 2)
1417
.option(Options.READ_BATCH_SIZE, 1000)
@@ -23,4 +26,24 @@ void test() {
2326
"that regardless of the number of matching rows, MarkLogic can efficiently determine a count in a single " +
2427
"request.");
2528
}
29+
30+
@Test
31+
void groupByAndCount() {
32+
List<Row> rows = newDefaultReader()
33+
.option(Options.READ_OPTIC_DSL, QUERY_WITH_NO_QUALIFIER)
34+
.load()
35+
.groupBy("CitationID")
36+
.count()
37+
.orderBy("CitationID")
38+
.collectAsList();
39+
40+
assertEquals(15, countOfRowsReadFromMarkLogic, "groupBy + count is not yet being pushed down to MarkLogic; " +
41+
"only count() by itself is being pushed down. So expecting all rows to be read for now.");
42+
43+
assertEquals(4, (long) rows.get(0).getAs("count"));
44+
assertEquals(4, (long) rows.get(1).getAs("count"));
45+
assertEquals(4, (long) rows.get(2).getAs("count"));
46+
assertEquals(1, (long) rows.get(3).getAs("count"));
47+
assertEquals(2, (long) rows.get(4).getAs("count"));
48+
}
2649
}

0 commit comments

Comments
 (0)