
Commit 8379ebe

Merge pull request #158 from marklogic/feature/12257-push-limit
MLE-12257 Pushing down limit
2 parents b4ec55b + c5bdae5 · commit 8379ebe

6 files changed: +97 −21 lines changed

src/main/java/com/marklogic/spark/reader/document/DocumentBatch.java

Lines changed: 3 additions & 4 deletions

@@ -8,18 +8,17 @@
 import org.apache.spark.sql.connector.read.Batch;
 import org.apache.spark.sql.connector.read.InputPartition;
 import org.apache.spark.sql.connector.read.PartitionReaderFactory;
-import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 class DocumentBatch implements Batch {
 
     private static final Logger logger = LoggerFactory.getLogger(DocumentBatch.class);
 
-    private DocumentContext context;
+    private final DocumentContext context;
 
-    DocumentBatch(CaseInsensitiveStringMap options) {
-        this.context = new DocumentContext(options);
+    DocumentBatch(DocumentContext context) {
+        this.context = context;
     }
 
     /**

src/main/java/com/marklogic/spark/reader/document/DocumentContext.java

Lines changed: 10 additions & 0 deletions

@@ -13,6 +13,8 @@
 
 class DocumentContext extends ContextSupport {
 
+    private Integer limit;
+
     DocumentContext(CaseInsensitiveStringMap options) {
         super(options.asCaseSensitiveMap());
     }
@@ -72,4 +74,12 @@ int getPartitionsPerForest() {
         int defaultPartitionsPerForest = 4;
         return (int) getNumericOption(Options.READ_DOCUMENTS_PARTITIONS_PER_FOREST, defaultPartitionsPerForest, 1);
     }
+
+    void setLimit(Integer limit) {
+        this.limit = limit;
+    }
+
+    Integer getLimit() {
+        return limit;
+    }
 }

src/main/java/com/marklogic/spark/reader/document/DocumentScan.java

Lines changed: 4 additions & 5 deletions

@@ -3,14 +3,13 @@
 import org.apache.spark.sql.connector.read.Batch;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 class DocumentScan implements Scan {
 
-    private CaseInsensitiveStringMap options;
+    private final DocumentContext context;
 
-    DocumentScan(CaseInsensitiveStringMap options) {
-        this.options = options;
+    DocumentScan(DocumentContext context) {
+        this.context = context;
     }
 
     @Override
@@ -20,6 +19,6 @@ public StructType readSchema() {
 
     @Override
     public Batch toBatch() {
-        return new DocumentBatch(options);
+        return new DocumentBatch(context);
     }
 }

src/main/java/com/marklogic/spark/reader/document/DocumentScanBuilder.java

Lines changed: 19 additions & 4 deletions

@@ -2,18 +2,33 @@
 
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.connector.read.SupportsPushDownLimit;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
-class DocumentScanBuilder implements ScanBuilder {
+class DocumentScanBuilder implements ScanBuilder, SupportsPushDownLimit {
 
-    private CaseInsensitiveStringMap options;
+    private final DocumentContext context;
 
     DocumentScanBuilder(CaseInsensitiveStringMap options) {
-        this.options = options;
+        this.context = new DocumentContext(options);
     }
 
     @Override
     public Scan build() {
-        return new DocumentScan(options);
+        return new DocumentScan(context);
+    }
+
+    @Override
+    public boolean pushLimit(int limit) {
+        this.context.setLimit(limit);
+        return true;
+    }
+
+    @Override
+    public boolean isPartiallyPushed() {
+        // A partition reader can only ensure that it doesn't exceed the limit. In a worst case scenario, every reader
+        // will return "limit" rows. So must return true here to ensure that Spark reduces the dataset to the
+        // appropriate limit.
+        return true;
    }
 }
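
Note: pushLimit and isPartiallyPushed come from Spark's SupportsPushDownLimit interface. When a query ends in a LIMIT, Spark's optimizer calls pushLimit on the ScanBuilder before build(); because isPartiallyPushed() returns true, Spark still applies the final LIMIT to the combined output of all partitions. The following is a minimal sketch of a user-side query that would exercise this path; the connector short name, option keys, and connection details are illustrative assumptions, not values taken from this commit.

// Hedged sketch: a plain Spark job whose .limit(2) call triggers DocumentScanBuilder.pushLimit(2).
// The format name "marklogic" and the option keys/values below are assumptions for illustration.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class LimitPushDownExample {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .getOrCreate();

        Dataset<Row> docs = spark.read()
            .format("marklogic")                                                   // assumed connector short name
            .option("spark.marklogic.client.uri", "user:password@localhost:8000")  // assumed option key and value
            .option("spark.marklogic.read.documents.collections", "author")        // assumed option key
            .load()
            .limit(2); // Spark calls pushLimit(2) on the ScanBuilder during planning

        // Each partition reader returns at most 2 documents; Spark then applies the final LIMIT,
        // so this prints a number no greater than 2.
        System.out.println(docs.count());

        spark.stop();
    }
}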

src/main/java/com/marklogic/spark/reader/document/ForestReader.java

Lines changed: 17 additions & 8 deletions

@@ -42,31 +42,32 @@ class ForestReader implements PartitionReader<InternalRow> {
     private final StructuredQueryBuilder queryBuilder;
     private final Set<DocumentManager.Metadata> requestedMetadata;
     private final boolean contentWasRequested;
+    private final Integer limit;
 
     // Only used for logging.
     private final ForestPartition forestPartition;
     private long startTime;
 
     private DocumentPage currentDocumentPage;
 
+    // Used for logging and for ensuring a non-null limit is not exceeded.
     private int docCount;
 
-    ForestReader(ForestPartition forestPartition, DocumentContext documentContext) {
+    ForestReader(ForestPartition forestPartition, DocumentContext context) {
         if (logger.isDebugEnabled()) {
             logger.debug("Will read from partition: {}", forestPartition);
         }
         this.forestPartition = forestPartition;
+        this.limit = context.getLimit();
 
-        DatabaseClient client = documentContext.connectToMarkLogic();
-
-        SearchQueryDefinition query = documentContext.buildSearchQuery(client);
-        int batchSize = documentContext.getBatchSize();
-        this.uriBatcher = new UriBatcher(client, query, forestPartition, batchSize, false);
+        DatabaseClient client = context.connectToMarkLogic();
+        SearchQueryDefinition query = context.buildSearchQuery(client);
+        this.uriBatcher = new UriBatcher(client, query, forestPartition, context.getBatchSize(), false);
 
         this.documentManager = client.newDocumentManager();
         this.documentManager.setReadTransform(query.getResponseTransform());
-        this.contentWasRequested = documentContext.contentWasRequested();
-        this.requestedMetadata = documentContext.getRequestedMetadata();
+        this.contentWasRequested = context.contentWasRequested();
+        this.requestedMetadata = context.getRequestedMetadata();
         this.documentManager.setMetadataCategories(this.requestedMetadata);
         this.queryBuilder = client.newQueryManager().newStructuredQueryBuilder();
     }
@@ -76,6 +77,13 @@ public boolean next() {
         if (startTime == 0) {
            startTime = System.currentTimeMillis();
        }
+
+        if (limit != null && docCount >= limit) {
+            // No logging here as this block may never be hit, depending on whether Spark first detects that the limit
+            // has been reached.
+            return false;
+        }
+
         if (currentDocumentPage == null || !currentDocumentPage.hasNext()) {
             closeCurrentDocumentPage();
             List<String> uris = getNextBatchOfUris();
@@ -89,6 +97,7 @@ public boolean next() {
             }
             this.currentDocumentPage = readPage(uris);
         }
+
         return currentDocumentPage.hasNext();
     }
 
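
Note: the limit check added to next() only caps what a single ForestReader returns; it cannot coordinate with the readers handling other forest partitions. The standalone sketch below (not connector code; the class name and sample URIs are made up for illustration) shows why the combined output can still exceed the limit, which is the scenario isPartiallyPushed() accounts for.

// Hedged sketch, plain Java with no Spark or MarkLogic dependencies: three "readers" each honor
// a limit of 2 locally, yet together they emit 6 rows, so Spark must still apply the final LIMIT.
import java.util.ArrayList;
import java.util.List;

public class PartitionLimitSketch {

    // Mirrors the idea of the docCount check in ForestReader.next(): stop once "limit" rows have been produced.
    static List<String> readPartition(List<String> uris, Integer limit) {
        List<String> rows = new ArrayList<>();
        for (String uri : uris) {
            if (limit != null && rows.size() >= limit) {
                break;
            }
            rows.add(uri);
        }
        return rows;
    }

    public static void main(String[] args) {
        // Made-up URIs standing in for documents spread across three forest partitions.
        List<List<String>> partitions = List.of(
            List.of("/author/1.json", "/author/2.json", "/author/3.json"),
            List.of("/author/4.json", "/author/5.json"),
            List.of("/author/6.json", "/author/7.json", "/author/8.json")
        );

        int limit = 2;
        int total = 0;
        for (List<String> partition : partitions) {
            total += readPartition(partition, limit).size();
        }

        // Prints 6: each reader stayed within the limit, but the union did not, hence isPartiallyPushed() == true.
        System.out.println(total);
    }
}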

src/test/java/com/marklogic/spark/reader/document/PushDownLimitTest.java

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+package com.marklogic.spark.reader.document;
+
+import com.marklogic.spark.AbstractIntegrationTest;
+import com.marklogic.spark.Options;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class PushDownLimitTest extends AbstractIntegrationTest {
+
+    @Test
+    void two() {
+        long count = readAuthors().limit(2).count();
+        assertTrue(count <= 6, "With a limit of 2, each reader should read at most 2 docs; they can't do " +
+            "any fewer than that because each one has no idea how many documents any other reader will get. " +
+            "Unexpected count: " + count);
+    }
+
+    @Test
+    void zero() {
+        long count = readAuthors().limit(0).count();
+        assertEquals(0, count);
+    }
+
+    @Test
+    void limitIsMoreThanTotal() {
+        long count = readAuthors().limit(20).count();
+        assertEquals(15, count, "A limit greater than the number of matching documents has no impact on the results.");
+    }
+
+    private Dataset<Row> readAuthors() {
+        return newSparkSession().read()
+            .format(CONNECTOR_IDENTIFIER)
+            .option(Options.CLIENT_URI, makeClientUri())
+            .option(Options.READ_DOCUMENTS_COLLECTIONS, "author")
+            // Using a single partition to increase the chance that a reader will hit the limit.
+            .option(Options.READ_DOCUMENTS_PARTITIONS_PER_FOREST, 1)
+            .load();
+    }
+
+}
