Skip to content

Commit 5bdb785

Browse files
authored
Merge pull request #369 from marklogic/feature/batch-size-tweak
Tweaked calculation of buckets per partition
2 parents 48c5d6e + 8b42d4b commit 5bdb785

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

marklogic-spark-connector/src/main/java/com/marklogic/spark/reader/optic/PlanAnalyzer.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,32 @@ PlanAnalysis analyzePlan(AbstractWriteHandle userPlan, long userPartitionCount,
3434
return new PlanAnalysis((ObjectNode) viewInfo.get("modifiedPlan"), partitions);
3535
}
3636

37-
private List<PlanAnalysis.Partition> calculatePartitions(long rowCount, long userPartitionCount, long userBatchSize) {
37+
static List<PlanAnalysis.Partition> calculatePartitions(long rowCount, long userPartitionCount, long userBatchSize) {
3838
final long batchSize = userBatchSize > 0 ? userBatchSize : Long.parseLong("-1");
39-
long bucketCount = (rowCount / userPartitionCount) / batchSize;
40-
if (bucketCount < 1) {
41-
bucketCount = 1;
42-
}
39+
40+
long bucketsPerPartition = calculateBucketsPerPartition(rowCount, userPartitionCount, batchSize);
4341
long partitionSize = Long.divideUnsigned(-1, userPartitionCount);
4442
long nextLowerBound = 0;
4543

4644
List<PlanAnalysis.Partition> partitions = new ArrayList<>();
4745
for (int i = 1; i <= userPartitionCount; i++) {
4846
long upperBound = (i == userPartitionCount) ? -1 : nextLowerBound + partitionSize;
49-
partitions.add(new PlanAnalysis.Partition(i, nextLowerBound, upperBound, bucketCount, partitionSize));
47+
partitions.add(new PlanAnalysis.Partition(i, nextLowerBound, upperBound, bucketsPerPartition, partitionSize));
5048
nextLowerBound = nextLowerBound + partitionSize + 1;
5149
}
5250
return partitions;
5351
}
52+
53+
/**
54+
* The number of buckets per partition is always the same, as the random distribution of row IDs means we don't know
55+
* how rows will be distributed across buckets.
56+
*/
57+
private static long calculateBucketsPerPartition(long rowCount, long userPartitionCount, long batchSize) {
58+
double rawBucketsPerPartition = ((double) rowCount / userPartitionCount) / batchSize;
59+
// ceil is used here to ensure that given the batch size, a bucket typically will not have more rows in it
60+
// than the batch size. That's not guaranteed, as row IDs could have a distribution such that many rows are in
61+
// one particular bucket.
62+
long bucketsPerPartition = (long) Math.ceil(rawBucketsPerPartition);
63+
return bucketsPerPartition < 1 ? 1 : bucketsPerPartition;
64+
}
5465
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.reader.optic;
5+
6+
import org.junit.jupiter.params.ParameterizedTest;
7+
import org.junit.jupiter.params.provider.CsvSource;
8+
9+
import java.util.List;
10+
11+
import static org.junit.jupiter.api.Assertions.assertEquals;
12+
13+
class CalculatePartitionsTest {
14+
15+
@ParameterizedTest
16+
@CsvSource({
17+
"1,0,1,1",
18+
"2,0,2,2",
19+
"1,5000,1,2",
20+
"1,5001,1,2",
21+
"1,6666,1,2",
22+
"1,6667,1,2",
23+
"1,9999,1,2",
24+
"1,10000,1,1",
25+
"1,10001,1,1",
26+
"3,3000,3,6"
27+
})
28+
void test(long userPartitionCount, long batchSize, int expectedPartitionCount, int expectedBucketCount) {
29+
long rowCount = 10000;
30+
List<PlanAnalysis.Partition> partitions = PlanAnalyzer.calculatePartitions(rowCount, userPartitionCount, batchSize);
31+
int bucketCount = 0;
32+
for (PlanAnalysis.Partition partition : partitions) {
33+
bucketCount += partition.getBuckets().size();
34+
}
35+
36+
assertEquals(expectedPartitionCount, partitions.size(), "Unexpected number of partitions");
37+
assertEquals(expectedBucketCount, bucketCount, "Unexpected number of buckets");
38+
}
39+
}

0 commit comments

Comments
 (0)