
Commit 8824cd5

Improving groupBy pushdown to work for multiple column names
This turned out to be a simple enhancement, since `op.groupBy` already supports multiple column names being passed in.
1 parent 99f3cdf commit 8824cd5
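For context, here is a minimal sketch of the kind of read that this change lets the connector push down completely. The format name, option key, and connection details are assumptions modeled on the connector's typical usage, not values taken from this commit.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class GroupByPushdownSketch {
    public static void main(String[] args) {
        SparkSession session = SparkSession.builder().master("local[*]").getOrCreate();

        // Hypothetical options; a real read also needs connection and authentication settings.
        Dataset<Row> counts = session.read()
            .format("com.marklogic.spark")
            .option("spark.marklogic.read.opticQuery", "op.fromView('Medical', 'Authors', '')")
            .load()
            // Multiple groupBy columns plus count() can now be pushed down to MarkLogic
            // instead of being computed by Spark.
            .groupBy("CitationID", "Date")
            .count();

        counts.show();
    }
}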


4 files changed: 68 additions & 23 deletions


src/main/java/com/marklogic/spark/reader/MarkLogicScanBuilder.java

Lines changed: 6 additions & 9 deletions
@@ -145,11 +145,10 @@ public boolean pushAggregation(Aggregation aggregation) {
         }
         if (supportCompletePushDown(aggregation)) {
             if (aggregation.groupByExpressions().length > 0) {
-                Expression expr = aggregation.groupByExpressions()[0];
                 if (logger.isInfoEnabled()) {
-                    logger.info("Pushing down groupBy + count on: {}", expr.describe());
+                    logger.info("Pushing down groupBy + count on: {}", Arrays.asList(aggregation.groupByExpressions()));
                 }
-                readContext.pushDownGroupByCount(expr);
+                readContext.pushDownGroupByCount(aggregation.groupByExpressions());
             } else {
                 if (logger.isInfoEnabled()) {
                     logger.info("Pushing down count()");
@@ -167,12 +166,10 @@ public boolean supportCompletePushDown(Aggregation aggregation) {
             return false;
         }
         AggregateFunc[] expressions = aggregation.aggregateExpressions();
-        if (expressions.length == 1 && expressions[0] instanceof CountStar) {
-            // If a count() is used, it's supported if there's no groupBy - i.e. just doing a count() by itself -
-            // and supported with a single groupBy - e.g. groupBy("column").count().
-            return aggregation.groupByExpressions().length < 2;
-        }
-        return false;
+        // If a count() is used, it's supported if there's no groupBy - i.e. just doing a count() by itself -
+        // and supported with 1 to many groupBy's - e.g. groupBy("column", "someOtherColumn").count().
+        // Other aggregate functions will be supported in the near future.
+        return expressions.length == 1 && expressions[0] instanceof CountStar;
     }

     @Override
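To make the relaxed check concrete, a rough sketch of what the connector sees for a two-column groupBy follows; the expression shapes are assumptions about Spark's aggregate pushdown planning, not values captured from this commit.

// For dataset.groupBy("CitationID", "Date").count(), Spark is expected to call
// pushAggregation with an Aggregation whose groupByExpressions() holds two field
// references and whose aggregateExpressions() holds a single CountStar. With the
// change above, supportCompletePushDown no longer caps the number of groupBy
// expressions, so the whole aggregation is evaluated by MarkLogic rather than
// partially by Spark.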

src/main/java/com/marklogic/spark/reader/PlanUtil.java

Lines changed: 3 additions & 2 deletions
@@ -48,9 +48,10 @@ static ObjectNode buildGroupByCount() {
         });
     }

-    static ObjectNode buildGroupByCount(String columnName) {
+    static ObjectNode buildGroupByCount(List<String> columnNames) {
         return newOperation("group-by", args -> {
-            populateSchemaCol(args.addObject(), columnName);
+            ArrayNode columns = args.addArray();
+            columnNames.forEach(columnName -> populateSchemaCol(columns.addObject(), columnName));
             addCountArg(args);
         });
     }
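A quick sketch of how the new signature might be exercised; the column names are hypothetical, but the shape mirrors how ReadContext invokes this method below.

import java.util.Arrays;
import com.fasterxml.jackson.databind.node.ObjectNode;

// Builds a single "group-by" operation containing one schema-col entry per
// groupBy column, followed by the count argument (hypothetical column names).
ObjectNode groupByOperation = PlanUtil.buildGroupByCount(Arrays.asList("CitationID", "Date"));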

src/main/java/com/marklogic/spark/reader/ReadContext.java

Lines changed: 23 additions & 12 deletions
@@ -42,6 +42,8 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;

 /**
  * Captures state - all of which is serializable - that can be calculated at different times based on a user's inputs.
@@ -166,21 +168,30 @@ void pushDownCount() {
         modifyPlanAnalysisToUseSingleBucket();
     }

-    void pushDownGroupByCount(Expression groupBy) {
-        final String columnName = PlanUtil.expressionToColumnName(groupBy);
-        addOperatorToPlan(PlanUtil.buildGroupByCount(columnName));
+    void pushDownGroupByCount(Expression[] groupByExpressions) {
+        List<String> columnNames = Stream.of(groupByExpressions)
+            .map(groupBy -> PlanUtil.expressionToColumnName(groupBy))
+            .collect(Collectors.toList());

-        StructField columnField = null;
-        for (StructField field : this.schema.fields()) {
-            if (columnName.equals(field.name())) {
-                columnField = field;
-                break;
+        addOperatorToPlan(PlanUtil.buildGroupByCount(columnNames));
+
+        StructType newSchema = new StructType();
+
+        for (String columnName : columnNames) {
+            StructField columnField = null;
+            for (StructField field : this.schema.fields()) {
+                if (columnName.equals(field.name())) {
+                    columnField = field;
+                    break;
+                }
             }
+            if (columnField == null) {
+                throw new IllegalArgumentException("Unable to find groupBy column in schema; column name: " + columnName);
+            }
+            newSchema = newSchema.add(columnField);
         }
-        if (columnField == null) {
-            throw new IllegalArgumentException("Unable to find groupBy column in schema; groupBy expression: " + groupBy.describe());
-        }
-        this.schema = new StructType().add(columnField).add("count", DataTypes.LongType);
+
+        this.schema = newSchema.add("count", DataTypes.LongType);
         modifyPlanAnalysisToUseSingleBucket();
     }

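The net effect on the read schema is sketched below: the pushed-down schema keeps each groupBy column's original field, in groupBy order, and appends a long "count" column. The concrete column types are illustrative assumptions, not values from this commit.

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Schema produced for groupBy("CitationID", "Date").count(), assuming CitationID
// is a long and Date is a string in the original view schema (illustrative only).
StructType pushedDownSchema = new StructType()
    .add("CitationID", DataTypes.LongType)
    .add("Date", DataTypes.StringType)
    .add("count", DataTypes.LongType);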

src/test/java/com/marklogic/spark/reader/PushDownGroupByCountTest.java

Lines changed: 36 additions & 0 deletions
@@ -16,6 +16,7 @@
 package com.marklogic.spark.reader;

 import com.marklogic.spark.Options;
+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Row;
 import org.junit.jupiter.api.Test;

@@ -39,6 +40,24 @@ void groupByWithNoQualifier() {
         assertEquals(1l, (long) rows.get(0).getAs("CitationID"));
     }

+    @Test
+    void groupByMultipleColumns() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_QUERY, QUERY_WITH_NO_QUALIFIER)
+            .load()
+            .groupBy("CitationID", "Date")
+            .count()
+            .orderBy("CitationID")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+
+        assertEquals(1l, (long) rows.get(0).getAs("CitationID"));
+        assertEquals("2022-07-13", rows.get(0).getAs("Date").toString());
+        assertEquals(2l, (long) rows.get(1).getAs("CitationID"));
+        assertEquals("2022-05-11", rows.get(1).getAs("Date").toString());
+    }
+
     @Test
     void noRowsFound() {
         List<Row> rows = newDefaultReader()
@@ -80,6 +99,23 @@ void groupByWithSchemaAndView() {
         assertEquals(1l, (long) rows.get(0).getAs("Medical.Authors.CitationID"));
     }

+    @Test
+    void groupByMultipleColumnsAndSchemaAndView() {
+        List<Row> rows = newDefaultReader()
+            .load()
+            .groupBy("`Medical.Authors.CitationID`", "`Medical.Authors.Date`")
+            .count()
+            .orderBy("`Medical.Authors.CitationID`")
+            .collectAsList();
+
+        verifyGroupByWasPushedDown(rows);
+
+        verifyGroupByWasPushedDown(rows);
+        assertEquals(1l, (long) rows.get(0).getAs("Medical.Authors.CitationID"));
+        assertEquals("2022-07-13", rows.get(0).getAs("Medical.Authors.Date").toString());
+    }
+
+
     @Test
     void groupByCountLimitOrderBy() {
         List<Row> rows = newDefaultReader()
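One detail worth calling out in the second new test: the backticks around the qualified column names matter because Spark otherwise treats the dots in "Medical.Authors.CitationID" as nested-field access. A minimal illustration against a hypothetical Dataset<Row> named dataset:

// Without the backticks, Spark would try to resolve "Medical" as a struct column
// with nested "Authors" and "CitationID" fields and fail to find the flat column.
dataset.groupBy("`Medical.Authors.CitationID`", "`Medical.Authors.Date`").count();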
