Skip to content

Commit 07879a0

Browse files
committed
Add optimal partitioning
1 parent 0322dae commit 07879a0

11 files changed

+449
-40
lines changed

README.md

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,25 +92,48 @@ package.
9292
sort.Sort(lshensemble.BySize(domainRecords))
9393
```
9494

95-
Now you can use `BootstrapLshEnsemble`
96-
(or `BootstrapLshEnsemblePlus` for better accuracy at higher memory cost\*)
97-
to create an LSH Ensemble index. You need to
95+
Now you can use `BootstrapLshEnsembleOptimal`/`BootstrapLshEnsembleEquiDepth`
96+
(or `BootstrapLshEnsemblePlusOptimal`/`BootstrapLshEnsemblePlusEquiDepth`)
97+
for better accuracy at higher memory cost\*)
98+
to create an LSH Ensemble index.
99+
`BootstrapLshEnsembleOptimal` uses dynamic programming to create partitions that
100+
are optimal in the sense that the total number of false positives generated from
101+
all the partitions are minimized. This method can be
102+
a bit slower due to the dynamic programming overhead, however, it creates
103+
optimized partitions for any kind of data distribution.
104+
`BootstrapLshEnsembleEquiDepth` uses simple equi-depth -- same number of domains
105+
in every partition. This is method is described in the original
106+
[paper](http://www.vldb.org/pvldb/vol9/p1185-zhu.pdf) as suitable for power-law
107+
distributed domain sizes, which is common in real-world domains.
108+
You need to
98109
specify the number of partitions to use and some other parameters.
99110
The LSH parameter K (number of hash functions per band) is dynamically tuned at query-time,
100111
but the maximum value should be specified here.
101112

102-
\* See [explanation](#maxk-explanation) for the difference.
113+
\* See [explanation](#maxk-explanation) for the reason for the "Plus" version.
103114

104115
```go
105-
// set the number of partitions
116+
// Set the number of partitions
106117
numPart := 8
107118

108-
// set the maximum value for the MinHash LSH parameter K
119+
// Set the maximum value for the MinHash LSH parameter K
109120
// (number of hash functions per band).
110121
maxK := 4
111122

112-
// create index, you can also use BootstrapLshEnsemblePlus for better accuracy
113-
index, err := lshensemble.BootstrapLshEnsemble(numPart, numHash, maxK, len(domainRecords), lshensemble.Recs2Chan(domainRecords))
123+
// Create index using equi-depth partitioning
124+
// You can also use BootstrapLshEnsemblePlusEquiDepth for better accuracy
125+
index_eqd, err := lshensemble.BootstrapLshEnsembleEquiDepth(numPart, numHash, maxK,
126+
len(domainRecords), lshensemble.Recs2Chan(domainRecords))
127+
if err != nil {
128+
panic(err)
129+
}
130+
131+
// Create index using optimal partitioning
132+
// You can also use BootstrapLshEnsemblePlusOptimal for better accuracy
133+
index_opt, err := lshensemble.BootstrapLshEnsembleOptimal(numPart, numHash, maxK,
134+
func () <-chan *lshensemble.DomainRecord {
135+
return lshensemble.Recs2Chan(domainRecords);
136+
})
114137
if err != nil {
115138
panic(err)
116139
}
@@ -119,7 +142,7 @@ if err != nil {
119142
For better memory efficiency when the number of domains is large,
120143
it's wiser to use Golang channels and goroutines
121144
to pipeline the generation of the signatures, and then use disk-based sorting to sort the domain records.
122-
This is why `BootstrapLshEnsemble` accepts a channel of `*DomainRecord` as input.
145+
This is why `BootstrapLshEnsembleEquiDepth` accepts a channel of `*DomainRecord` as input.
123146
For a small number of domains, you simply use `Recs2Chan` to convert the sorted slice of `*DomainRecord`
124147
into a `chan *DomainRecord`.
125148
To help serializing the domain records to disk, you can use `SerializeSignature`
@@ -223,9 +246,11 @@ Essentially, we have less freedom in varying `L`, as
223246
In this library for LSH Ensemble, we provide both implmentations
224247
(LSH Forest and "vanilla" MinHash LSH ).
225248
Specifically,
226-
* `BootstrapLshEnsemble` builds the index using the LSH Forest implementation,
249+
* `BootstrapLshEnsembleEquiDepth` and `BootstrapLshEnsembleOptimal`
250+
build the index using the LSH Forest implementation,
227251
which use less memory but with a more restricted parameter space for optimization.
228-
* `BootstrapLshEnsemblePlus` builds the index using the "vanilla" MinHash LSH
252+
* `BootstrapLshEnsemblePlusEquiDepth` and `BootstrapLshEnsemblePlusOptimal`
253+
build the index using the "vanilla" MinHash LSH
229254
implementation (one LSH for every `K`), which uses more memory (bounded by `MaxK`)
230255
but with no restriction on `L`.
231256

accuracy_benchmark_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"time"
1111
)
1212

13-
func benchmark_accuracy(groundTruthFilename, queryResultFilename, outputFilename string) {
13+
func benchmarkAccuracy(groundTruthFilename, queryResultFilename, outputFilename string) {
1414
groundTruths := readQueryResultFile(groundTruthFilename)
1515
queryResults := readQueryResultFile(queryResultFilename)
1616
precisions := make([]float64, 0)
@@ -67,7 +67,7 @@ func recallPrecision(result, groundTruth queryResult) (recall, precision float64
6767
overlap := 0
6868
for id := range test {
6969
if _, found := truth[id]; found {
70-
overlap += 1
70+
overlap++
7171
}
7272
}
7373
recall = float64(overlap) / float64(len(truth))
@@ -82,6 +82,7 @@ func readQueryResultFile(queryResultFile string) []queryResult {
8282
panic(err)
8383
}
8484
scanner := bufio.NewScanner(file)
85+
scanner.Buffer(nil, 4096*1024*1024)
8586
for scanner.Scan() {
8687
raw := strings.Split(scanner.Text(), "\t")
8788
key := raw[0]

bootstrap.go

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,85 @@ package lshensemble
33
import "errors"
44

55
var (
6-
ErrDomainSizeOrder = errors.New("Domain records must be sorted in ascending order of size")
6+
errDomainSizeOrder = errors.New("Domain records must be sorted in ascending order of size")
77
)
88

9-
func bootstrap(index *LshEnsemble, totalNumDomains int, sortedDomains <-chan *DomainRecord) error {
9+
func bootstrapOptimalPartitions(domains <-chan *DomainRecord, numPart int) []Partition {
10+
sizes, counts := computeSizeDistribution(domains)
11+
partitions := optimalPartitions(sizes, counts, numPart)
12+
return partitions
13+
}
14+
15+
func bootstrapOptimal(index *LshEnsemble, sortedDomains <-chan *DomainRecord) error {
16+
var currPart int
17+
var currSize int
18+
for rec := range sortedDomains {
19+
if currSize > rec.Size {
20+
return errDomainSizeOrder
21+
}
22+
currSize = rec.Size
23+
if currSize > index.Partitions[currPart].Upper {
24+
currPart++
25+
}
26+
if currPart >= len(index.Partitions) ||
27+
!(index.Partitions[currPart].Lower <= currSize &&
28+
currSize <= index.Partitions[currPart].Upper) {
29+
return errors.New("Domain records does not match the existing partitions")
30+
}
31+
index.Add(rec.Key, rec.Signature, currPart)
32+
}
33+
index.Index()
34+
return nil
35+
}
36+
37+
// BootstrapLshEnsembleOptimal builds an index from domains using optimal
38+
// partitioning.
39+
// The returned index consists of MinHash LSH implemented using LshForest.
40+
// numPart is the number of partitions to create.
41+
// numHash is the number of hash functions in MinHash.
42+
// maxK is the maximum value for the MinHash parameter K - the number of hash
43+
// functions per "band".
44+
// sortedDomainFactory is factory function that returns a DomainRecord channel
45+
// emitting domains in sorted order by their sizes.
46+
func BootstrapLshEnsembleOptimal(numPart, numHash, maxK int,
47+
sortedDomainFactory func() <-chan *DomainRecord) (*LshEnsemble, error) {
48+
partitions := bootstrapOptimalPartitions(sortedDomainFactory(), numPart)
49+
index := NewLshEnsemble(partitions, numHash, maxK)
50+
err := bootstrapOptimal(index, sortedDomainFactory())
51+
if err != nil {
52+
return nil, err
53+
}
54+
return index, nil
55+
}
56+
57+
// BootstrapLshEnsemblePlusOptimal builds an index from domains using optimal
58+
// partitioning.
59+
// The returned index consists of MinHash LSH implemented using LshForestArray.
60+
// numPart is the number of partitions to create.
61+
// numHash is the number of hash functions in MinHash.
62+
// maxK is the maximum value for the MinHash parameter K - the number of hash
63+
// functions per "band".
64+
// sortedDomainFactory is factory function that returns a DomainRecord channel
65+
// emitting domains in sorted order by their sizes.
66+
func BootstrapLshEnsemblePlusOptimal(numPart, numHash, maxK int,
67+
sortedDomainFactory func() <-chan *DomainRecord) (*LshEnsemble, error) {
68+
partitions := bootstrapOptimalPartitions(sortedDomainFactory(), numPart)
69+
index := NewLshEnsemblePlus(partitions, numHash, maxK)
70+
err := bootstrapOptimal(index, sortedDomainFactory())
71+
if err != nil {
72+
return nil, err
73+
}
74+
return index, nil
75+
}
76+
77+
func bootstrapEquiDepth(index *LshEnsemble, totalNumDomains int, sortedDomains <-chan *DomainRecord) error {
1078
numPart := len(index.Partitions)
1179
depth := totalNumDomains / numPart
1280
var currDepth, currPart int
1381
var currSize int
1482
for rec := range sortedDomains {
1583
if currSize > rec.Size {
16-
return ErrDomainSizeOrder
84+
return errDomainSizeOrder
1785
}
1886
currSize = rec.Size
1987
index.Add(rec.Key, rec.Signature, currPart)
@@ -29,30 +97,36 @@ func bootstrap(index *LshEnsemble, totalNumDomains int, sortedDomains <-chan *Do
2997
return nil
3098
}
3199

32-
// BoostrapLshEnsemble builds an index from a channel of domains.
100+
// BootstrapLshEnsembleEquiDepth builds an index from a channel of domains
101+
// using equi-depth partitions -- partitions have approximately the same
102+
// number of domains.
33103
// The returned index consists of MinHash LSH implemented using LshForest.
34104
// numPart is the number of partitions to create.
35105
// numHash is the number of hash functions in MinHash.
36106
// maxK is the maximum value for the MinHash parameter K - the number of hash functions per "band".
37107
// sortedDomains is a DomainRecord channel emitting domains in sorted order by their sizes.
38-
func BootstrapLshEnsemble(numPart, numHash, maxK, totalNumDomains int, sortedDomains <-chan *DomainRecord) (*LshEnsemble, error) {
108+
func BootstrapLshEnsembleEquiDepth(numPart, numHash, maxK, totalNumDomains int,
109+
sortedDomains <-chan *DomainRecord) (*LshEnsemble, error) {
39110
index := NewLshEnsemble(make([]Partition, numPart), numHash, maxK)
40-
err := bootstrap(index, totalNumDomains, sortedDomains)
111+
err := bootstrapEquiDepth(index, totalNumDomains, sortedDomains)
41112
if err != nil {
42113
return nil, err
43114
}
44115
return index, nil
45116
}
46117

47-
// BoostrapLshEnsemblePlus builds an index from a channel of domains.
118+
// BootstrapLshEnsemblePlusEquiDepth builds an index from a channel of domains
119+
// using equi-depth partitions -- partitions have approximately the same
120+
// number of domains.
48121
// The returned index consists of MinHash LSH implemented using LshForestArray.
49122
// numPart is the number of partitions to create.
50123
// numHash is the number of hash functions in MinHash.
51124
// maxK is the maximum value for the MinHash parameter K - the number of hash functions per "band".
52125
// sortedDomains is a DomainRecord channel emitting domains in sorted order by their sizes.
53-
func BootstrapLshEnsemblePlus(numPart, numHash, maxK, totalNumDomains int, sortedDomains <-chan *DomainRecord) (*LshEnsemble, error) {
126+
func BootstrapLshEnsemblePlusEquiDepth(numPart, numHash, maxK,
127+
totalNumDomains int, sortedDomains <-chan *DomainRecord) (*LshEnsemble, error) {
54128
index := NewLshEnsemblePlus(make([]Partition, numPart), numHash, maxK)
55-
err := bootstrap(index, totalNumDomains, sortedDomains)
129+
err := bootstrapEquiDepth(index, totalNumDomains, sortedDomains)
56130
if err != nil {
57131
return nil, err
58132
}

cod_benchmark_test.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@ import (
1313
"time"
1414
)
1515

16-
const benchmarkSeed = 42
17-
const fracQuery = 0.01
18-
const minDomainSize = 10
16+
const (
17+
benchmarkSeed = 42
18+
fracQuery = 0.01
19+
minDomainSize = 10
20+
)
21+
22+
var (
23+
thresholds = []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}
24+
)
1925

2026
// Running this function requires a `_cod_domains` directory
2127
// in the current directory.
@@ -49,17 +55,19 @@ func Benchmark_CanadianOpenData(b *testing.B) {
4955
}
5056

5157
// Run benchmark
52-
log.Printf("Canadian Open Data benchmark threshold = %.2f", 0.5)
53-
benchmark_cod(rawDomains, queries, 0.5)
58+
for _, threshold := range thresholds {
59+
log.Printf("Canadian Open Data benchmark threshold = %.2f", threshold)
60+
benchmarkCOD(rawDomains, queries, threshold)
61+
}
5462
}
5563

56-
func benchmark_cod(rawDomains, queries []rawDomain, threshold float64) {
57-
linearscan_output := fmt.Sprintf("_cod_linearscan_threshold=%.2f", threshold)
58-
lshensemble_output := fmt.Sprintf("_cod_lshensemble_threshold=%.2f", threshold)
59-
accuracy_output := fmt.Sprintf("_cod_accuracy_threhsold=%.2f", threshold)
60-
benchmark_linearscan(rawDomains, queries, threshold, linearscan_output)
61-
benchmark_lshensemble(rawDomains, queries, threshold, lshensemble_output)
62-
benchmark_accuracy(linearscan_output, lshensemble_output, accuracy_output)
64+
func benchmarkCOD(rawDomains, queries []rawDomain, threshold float64) {
65+
linearscanOutput := fmt.Sprintf("_cod_linearscan_threshold=%.2f", threshold)
66+
lshensembleOutput := fmt.Sprintf("_cod_lshensemble_threshold=%.2f", threshold)
67+
accuracyOutput := fmt.Sprintf("_cod_accuracy_threshold=%.2f", threshold)
68+
benchmarkLinearscan(rawDomains, queries, threshold, linearscanOutput)
69+
benchmarkLshEnsemble(rawDomains, queries, threshold, lshensembleOutput)
70+
benchmarkAccuracy(linearscanOutput, lshensembleOutput, accuracyOutput)
6371
}
6472

6573
type rawDomain struct {

linearscan_benchmark_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import (
55
"time"
66
)
77

8-
func benchmark_linearscan(rawDomains []rawDomain, queries []rawDomain, threshold float64, outputFilename string) {
8+
func benchmarkLinearscan(rawDomains []rawDomain, queries []rawDomain,
9+
threshold float64, outputFilename string) {
910
log.Printf("Start Linear Scan with %d queries", len(queries))
1011
results := make(chan queryResult)
1112
go func() {

lshensemble.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package lshensemble
22

33
import (
4+
"errors"
45
"fmt"
56
"sync"
67
"time"
@@ -87,6 +88,19 @@ func (e *LshEnsemble) Add(key interface{}, sig []uint64, partInd int) {
8788
e.lshes[partInd].Add(key, sig)
8889
}
8990

91+
// Prepare adds a new domain to the index given its size, and partition will
92+
// be selected automatically. It could be more efficient to use Add().
93+
// The added domain won't be searchable until the Index() function is called.
94+
func (e *LshEnsemble) Prepare(key interface{}, sig []uint64, size int) error {
95+
for i := range e.Partitions {
96+
if size >= e.Partitions[i].Lower && size <= e.Partitions[i].Upper {
97+
e.Add(key, sig, i)
98+
break
99+
}
100+
}
101+
return errors.New("No matching partition found")
102+
}
103+
90104
// Index makes all added domains searchable.
91105
func (e *LshEnsemble) Index() {
92106
for i := range e.lshes {

lshensemble_benchmark_test.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,16 @@ import (
66
"time"
77
)
88

9-
func benchmark_lshensemble(rawDomains []rawDomain, rawQueries []rawDomain, threshold float64, outputFilename string) {
9+
const (
10+
numHash = 256
11+
numPart = 32
12+
maxK = 4
13+
// useOptimalPartitions = true
14+
useOptimalPartitions = false
15+
)
16+
17+
func benchmarkLshEnsemble(rawDomains []rawDomain, rawQueries []rawDomain,
18+
threshold float64, outputFilename string) {
1019
numHash := 256
1120
numPart := 32
1221
maxK := 4
@@ -27,8 +36,14 @@ func benchmark_lshensemble(rawDomains []rawDomain, rawQueries []rawDomain, thres
2736
// Indexing
2837
log.Print("Start building LSH Ensemble index")
2938
sort.Sort(BySize(domainRecords))
30-
index, _ := BootstrapLshEnsemblePlus(numPart, numHash, maxK, len(domainRecords),
31-
Recs2Chan(domainRecords))
39+
var index *LshEnsemble
40+
if useOptimalPartitions {
41+
index, _ = BootstrapLshEnsemblePlusOptimal(numPart, numHash, maxK,
42+
func() <-chan *DomainRecord { return Recs2Chan(domainRecords) })
43+
} else {
44+
index, _ = BootstrapLshEnsemblePlusEquiDepth(numPart, numHash, maxK,
45+
len(domainRecords), Recs2Chan(domainRecords))
46+
}
3247
log.Print("Finished building LSH Ensemble index")
3348
// Querying
3449
log.Printf("Start querying LSH Ensemble index with %d queries", len(queries))

0 commit comments

Comments
 (0)