
Commit 824f79f

Merge pull request #74 from marklogic/feature/485-disable-aggregates

DEVEXP-485 Can now disable push down of aggregates

2 parents f1f4b2f + 61122f2

5 files changed: +60 −4

docs/configuration.md

Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ information on how data is read from MarkLogic.
 | spark.marklogic.read.opticQuery | Required; the Optic DSL query to run for retrieving rows; must use `op.fromView` as the accessor. |
 | spark.marklogic.read.numPartitions | The number of Spark partitions to create; defaults to `spark.default.parallelism`. |
 | spark.marklogic.read.batchSize | Approximate number of rows to retrieve in each call to MarkLogic; defaults to 10000. |
-
+| spark.marklogic.read.pushDownAggregates | Whether to push down aggregate operations to MarkLogic; defaults to `true`. Set to `false` to prevent aggregates from being pushed down to MarkLogic. |
 ## Write options
 
 These options control how the connector writes data to MarkLogic. See [the guide on writing](writing.md) for more
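As a quick illustration of the new option in use, here is a minimal sketch of a Java read with push down disabled. The connection URI, Optic view names, and column name are placeholders rather than values from this commit; only the two `spark.marklogic.read.*` option keys come from the table above.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DisableAggregatePushDownExample {

    public static void main(String[] args) {
        SparkSession session = SparkSession.builder()
            .master("local[*]")
            .getOrCreate();

        Dataset<Row> rows = session.read()
            .format("com.marklogic.spark")
            // Placeholder connection details; substitute your own.
            .option("spark.marklogic.client.uri", "user:password@localhost:8003")
            // Empty third argument so columns are not qualified by schema/view name.
            .option("spark.marklogic.read.opticQuery", "op.fromView('example', 'employee', '')")
            // New in this commit: defaults to true; false keeps aggregation in Spark.
            .option("spark.marklogic.read.pushDownAggregates", "false")
            .load();

        // With push down disabled, this groupBy/count is computed by Spark, not MarkLogic.
        rows.groupBy("department").count().show();
    }
}
```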

docs/reading.md

Lines changed: 7 additions & 0 deletions

@@ -160,6 +160,13 @@ The following results are returned:
 +-----+-----------+-----+
 ```
 
+### Disabling push down of aggregates
+
+If you run into any issues with aggregates being pushed down to MarkLogic, you can set the
+`spark.marklogic.read.pushDownAggregates` option to `false`. If doing so produces what appears to be a different and
+correct result, please [file an issue with this project](https://github.com/marklogic/marklogic-spark-connector/issues).
+
+
 ## Tuning performance
 
 The primary factor affecting how quickly the connector can retrieve rows is MarkLogic's ability to process your Optic
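Before filing an issue, it can help to confirm the discrepancy by running the same aggregation with push down enabled and disabled and diffing the two results. A hedged sketch of that comparison follows; the option keys are from this commit, while the connection URI, view names, and columns are illustrative (the column names mirror the test data used elsewhere in this commit).

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ComparePushDownResults {

    public static void main(String[] args) {
        SparkSession session = SparkSession.builder().master("local[*]").getOrCreate();

        // The same grouped average, computed once by MarkLogic and once by Spark.
        Dataset<Row> byMarkLogic = readRows(session, "true").groupBy("CitationID").avg("LuckyNumber");
        Dataset<Row> bySpark = readRows(session, "false").groupBy("CitationID").avg("LuckyNumber");

        // Any rows printed by either statement indicate the two computations disagree.
        byMarkLogic.exceptAll(bySpark).show();
        bySpark.exceptAll(byMarkLogic).show();
    }

    private static Dataset<Row> readRows(SparkSession session, String pushDownAggregates) {
        return session.read()
            .format("com.marklogic.spark")
            // Placeholder connection details and Optic query; substitute your own.
            .option("spark.marklogic.client.uri", "user:password@localhost:8003")
            .option("spark.marklogic.read.opticQuery", "op.fromView('Medical', 'Authors', '')")
            .option("spark.marklogic.read.pushDownAggregates", pushDownAggregates)
            .load();
    }
}
```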

src/main/java/com/marklogic/spark/Options.java

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ public interface Options {
     String READ_OPTIC_QUERY = "spark.marklogic.read.opticQuery";
     String READ_NUM_PARTITIONS = "spark.marklogic.read.numPartitions";
     String READ_BATCH_SIZE = "spark.marklogic.read.batchSize";
+    String READ_PUSH_DOWN_AGGREGATES = "spark.marklogic.read.pushDownAggregates";
 
     String WRITE_BATCH_SIZE = "spark.marklogic.write.batchSize";
     String WRITE_THREAD_COUNT = "spark.marklogic.write.threadCount";

src/main/java/com/marklogic/spark/reader/MarkLogicScanBuilder.java

Lines changed: 15 additions & 3 deletions

@@ -15,9 +15,11 @@
  */
 package com.marklogic.spark.reader;
 
+import com.marklogic.spark.Options;
 import com.marklogic.spark.reader.filter.FilterFactory;
 import com.marklogic.spark.reader.filter.OpticFilter;
 import org.apache.spark.sql.connector.expressions.SortOrder;
+import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
 import org.apache.spark.sql.connector.expressions.aggregate.Avg;
 import org.apache.spark.sql.connector.expressions.aggregate.Count;
@@ -49,10 +51,10 @@ public class MarkLogicScanBuilder implements ScanBuilder, SupportsPushDownFilter
 
     private final static Logger logger = LoggerFactory.getLogger(MarkLogicScanBuilder.class);
 
-    private ReadContext readContext;
+    private final ReadContext readContext;
     private List<Filter> pushedFilters;
 
-    private final static Set<Class> SUPPORTED_AGGREGATE_FUNCTIONS = new HashSet() {{
+    private final static Set<Class<? extends AggregateFunc>> SUPPORTED_AGGREGATE_FUNCTIONS = new HashSet() {{
         add(Avg.class);
         add(Count.class);
         add(CountStar.class);
@@ -164,7 +166,7 @@ public boolean isPartiallyPushed() {
      */
     @Override
     public boolean supportCompletePushDown(Aggregation aggregation) {
-        if (readContext.planAnalysisFoundNoRows()) {
+        if (readContext.planAnalysisFoundNoRows() || pushDownAggregatesIsDisabled()) {
             return false;
         }
 
@@ -190,6 +192,12 @@ public boolean pushAggregation(Aggregation aggregation) {
         if (readContext.planAnalysisFoundNoRows() || hasUnsupportedAggregateFunction(aggregation)) {
             return false;
         }
+
+        if (pushDownAggregatesIsDisabled()) {
+            logger.info("Push down of aggregates is disabled; Spark will handle all aggregations.");
+            return false;
+        }
+
         logger.info("Pushing down aggregation: {}", describeAggregation(aggregation));
         readContext.pushDownAggregation(aggregation);
         return true;
@@ -224,4 +232,8 @@ private String describeAggregation(Aggregation aggregation) {
             Arrays.asList(aggregation.groupByExpressions()),
             Arrays.asList(aggregation.aggregateExpressions()));
     }
+
+    private boolean pushDownAggregatesIsDisabled() {
+        return "false".equalsIgnoreCase(readContext.getProperties().get(Options.READ_PUSH_DOWN_AGGREGATES));
+    }
 }
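For context on why both methods consult the new flag: in Spark's DataSource V2 API, `supportCompletePushDown` tells Spark whether pushed-down results can be used as-is, and `pushAggregation` returning `false` makes Spark read raw rows and perform the aggregation itself. Below is a minimal sketch of that contract, independent of this connector's classes; the class name and properties map are illustrative.

```java
import java.util.Map;

import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.SupportsPushDownAggregates;

// SupportsPushDownAggregates extends ScanBuilder, so build() must be implemented too.
public class ExampleScanBuilder implements SupportsPushDownAggregates {

    private final Map<String, String> properties; // user-supplied read options

    public ExampleScanBuilder(Map<String, String> properties) {
        this.properties = properties;
    }

    @Override
    public boolean supportCompletePushDown(Aggregation aggregation) {
        // false = Spark must not treat the data source's results as the final aggregation.
        return !pushDownDisabled();
    }

    @Override
    public boolean pushAggregation(Aggregation aggregation) {
        // false = Spark reads raw rows and computes the aggregation on its side.
        return !pushDownDisabled();
    }

    @Override
    public Scan build() {
        throw new UnsupportedOperationException("Sketch only; a real builder returns a Scan.");
    }

    private boolean pushDownDisabled() {
        return "false".equalsIgnoreCase(properties.get("spark.marklogic.read.pushDownAggregates"));
    }
}
```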
src/test/java/com/marklogic/spark/reader/DisablePushDownAggregatesTest.java

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+package com.marklogic.spark.reader;
+
+import com.marklogic.spark.Options;
+import org.apache.spark.sql.Row;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class DisablePushDownAggregatesTest extends AbstractPushDownTest {
+
+    @Test
+    void disabled() {
+        List<Row> rows = newDefaultReader()
+            .option(Options.READ_OPTIC_QUERY, QUERY_WITH_NO_QUALIFIER)
+            .option(Options.READ_PUSH_DOWN_AGGREGATES, false)
+            .load()
+            .groupBy("CitationID")
+            .avg("LuckyNumber")
+            .orderBy("CitationID")
+            .collectAsList();
+
+        assertEquals(5, rows.size());
+        assertEquals(15, countOfRowsReadFromMarkLogic, "Because push down of aggregates is disabled, all 15 author " +
+            "rows should have been read from MarkLogic.");
+
+        // Averages should still be calculated correctly by Spark.
+        String columnName = "avg(LuckyNumber)";
+        assertEquals(2.5, (double) rows.get(0).getAs(columnName));
+        assertEquals(6.5, (double) rows.get(1).getAs(columnName));
+        assertEquals(10.5, (double) rows.get(2).getAs(columnName));
+        assertEquals(13.0, (double) rows.get(3).getAs(columnName));
+        assertEquals(14.5, (double) rows.get(4).getAs(columnName));
+    }
+}
