Skip to content

feat(core): Add new index PartitionedHNSW for vectors #9469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 290 additions & 0 deletions posting/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"github.com/hypermodeinc/dgraph/v25/schema"
"github.com/hypermodeinc/dgraph/v25/tok"
"github.com/hypermodeinc/dgraph/v25/tok/hnsw"
tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index"

"github.com/hypermodeinc/dgraph/v25/types"
"github.com/hypermodeinc/dgraph/v25/x"
)
Expand Down Expand Up @@ -1361,6 +1363,291 @@
return prefixes, nil
}

// numCentroids is the target number of k-means centroids for partitioning
// vectors. NOTE(review): currently referenced only by the commented-out
// clustering code in rebuildVectorIndex, so golangci-lint flags it as
// unused; remove it together with that code, or keep it once clustering
// is re-enabled.
const numCentroids = 1000

func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error {
pk := x.ParsedKey{Attr: rb.Attr}

indexer, err := factorySpecs[0].CreateIndex(pk.Attr)
if err != nil {
return err
}

numVectorsToCheck := 100
lenFreq := make(map[int]int, numVectorsToCheck)
maxFreq := 0
dimension := 0
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{

Check failure on line 1380 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `MemLayerInstance.IterateDisk` is not checked
Prefix: pk.DataPrefix(),
ReadTs: rb.StartTs,
AllVersions: false,
Reverse: false,
CheckInclusion: func(uid uint64) error {
return nil
},
Function: func(l *List, pk x.ParsedKey) error {
val, err := l.Value(rb.StartTs)
if err != nil {
return err
}
inVec := types.BytesAsFloatArray(val.Value.([]byte))
lenFreq[len(inVec)] += 1
if lenFreq[len(inVec)] > maxFreq {
maxFreq = lenFreq[len(inVec)]
dimension = len(inVec)
}
numVectorsToCheck -= 1
if numVectorsToCheck <= 0 {
return ErrStopIteration
}
return nil
},
StartKey: x.DataKey(rb.Attr, 0),
})

fmt.Println("Selecting vector dimension to be:", dimension)

if indexer.NumSeedVectors() > 0 {
count := 0
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{

Check failure on line 1412 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `MemLayerInstance.IterateDisk` is not checked
Prefix: pk.DataPrefix(),
ReadTs: rb.StartTs,
AllVersions: false,
Reverse: false,
CheckInclusion: func(uid uint64) error {
return nil
},
Function: func(l *List, pk x.ParsedKey) error {
val, err := l.Value(rb.StartTs)
if err != nil {
return err
}
inVec := types.BytesAsFloatArray(val.Value.([]byte))
if len(inVec) != dimension {
return nil
}
count += 1
indexer.AddSeedVector(inVec)
if count == indexer.NumSeedVectors() {
return ErrStopIteration
}
return nil
},
StartKey: x.DataKey(rb.Attr, 0),
})
}

txns := make([]*Txn, indexer.NumThreads())
for i := range txns {
txns[i] = NewTxn(rb.StartTs)
}
caches := make([]tokIndex.CacheType, indexer.NumThreads())
for i := range caches {
caches[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
}

for pass_idx := range indexer.NumBuildPasses() {
fmt.Println("Building pass", pass_idx)

indexer.StartBuild(caches)

builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
edges := []*pb.DirectedEdge{}
val, err := pl.Value(rb.StartTs)
if err != nil {
return []*pb.DirectedEdge{}, err
}

inVec := types.BytesAsFloatArray(val.Value.([]byte))
if len(inVec) != dimension {
return []*pb.DirectedEdge{}, nil
}
indexer.BuildInsert(ctx, uid, inVec)

Check failure on line 1466 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `indexer.BuildInsert` is not checked
return edges, nil
}

err := builder.RunWithoutTemp(ctx)
if err != nil {
return err
}

indexer.EndBuild()
}

for pass_idx := range indexer.NumIndexPasses() {
fmt.Println("Indexing pass", pass_idx)

indexer.StartBuild(caches)

builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
edges := []*pb.DirectedEdge{}
val, err := pl.Value(rb.StartTs)
if err != nil {
return []*pb.DirectedEdge{}, err
}

inVec := types.BytesAsFloatArray(val.Value.([]byte))
if len(inVec) != dimension {
if pass_idx == 0 {
glog.Warningf("Skipping vector with invalid dimension uid: %d, dimension: %d", uid, len(inVec))
}
return []*pb.DirectedEdge{}, nil
}
indexer.BuildInsert(ctx, uid, inVec)

Check failure on line 1498 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `indexer.BuildInsert` is not checked
return edges, nil
}

err := builder.RunWithoutTemp(ctx)
if err != nil {
return err
}

for _, idx := range indexer.EndBuild() {
txns[idx].Update()
writer := NewTxnWriter(pstore)

x.ExponentialRetry(int(x.Config.MaxRetries),

Check failure on line 1511 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `x.ExponentialRetry` is not checked
20*time.Millisecond, func() error {
err := txns[idx].CommitToDisk(writer, rb.StartTs)
if err == badger.ErrBannedKey {
glog.Errorf("Error while writing to banned namespace.")
return nil
}
return err
})

txns[idx].cache.plists = nil
txns[idx] = nil
}
}

return nil

// MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
// Prefix: pk.DataPrefix(),
// ReadTs: rb.StartTs,
// AllVersions: false,
// Reverse: false,
// CheckInclusion: func(uid uint64) error {
// return nil
// },
// Function: func(l *List, pk x.ParsedKey) error {
// val, err := l.Value(rb.StartTs)
// if err != nil {
// return err
// }
// inVec := types.BytesAsFloatArray(val.Value.([]byte))
// vc.addSeedCentroid(inVec)
// if len(vc.centroids) == numCentroids {
// return ErrStopIteration
// }
// return nil
// },
// StartKey: x.DataKey(rb.Attr, 0),
// })

// vc.randomInit()

// fmt.Println("Clustering Vectors")
// for range 5 {
// builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
// builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
// edges := []*pb.DirectedEdge{}
// val, err := pl.Value(txn.StartTs)
// if err != nil {
// return []*pb.DirectedEdge{}, err
// }

// inVec := types.BytesAsFloatArray(val.Value.([]byte))
// vc.addVector(inVec)
// return edges, nil
// }

// err := builder.RunWithoutTemp(ctx)
// if err != nil {
// return err
// }

// vc.updateCentroids()
// }

// tcs := make([]*hnsw.TxnCache, vc.numCenters)
// txns := make([]*Txn, vc.numCenters)
// indexers := make([]index.VectorIndex[float32], vc.numCenters)
// for i := 0; i < vc.numCenters; i++ {
// txns[i] = NewTxn(rb.StartTs)
// tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
// indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i)
// if err != nil {
// return err
// }
// vc.mutexs[i] = &sync.Mutex{}
// indexers[i] = indexers_i
// }

// var edgesCreated atomic.Int64

// numPasses := vc.numCenters / 100
// for pass_idx := range numPasses {
// builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
// builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
// val, err := pl.Value(txn.StartTs)
// if err != nil {
// return []*pb.DirectedEdge{}, err
// }

// inVec := types.BytesAsFloatArray(val.Value.([]byte))
// idx := vc.findCentroid(inVec)
// if idx%numPasses != pass_idx {
// return []*pb.DirectedEdge{}, nil
// }
// vc.mutexs[idx].Lock()
// defer vc.mutexs[idx].Unlock()
// _, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec)
// if err != nil {
// return []*pb.DirectedEdge{}, err
// }

// edgesCreated.Add(int64(1))
// return nil, nil
// }

// err := builder.RunWithoutTemp(ctx)
// if err != nil {
// return err
// }

// for idx := range vc.counts {
// if idx%numPasses != pass_idx {
// continue
// }
// txns[idx].Update()
// writer := NewTxnWriter(pstore)

// x.ExponentialRetry(int(x.Config.MaxRetries),
// 20*time.Millisecond, func() error {
// err := txns[idx].CommitToDisk(writer, rb.StartTs)
// if err == badger.ErrBannedKey {
// glog.Errorf("Error while writing to banned namespace.")
// return nil
// }
// return err
// })

// txns[idx].cache.plists = nil
// txns[idx] = nil
// tcs[idx] = nil
// indexers[idx] = nil
// }

// fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses)
// }

// return nil
}

// rebuildTokIndex rebuilds index for a given attribute.
// We commit mutations with startTs and ignore the errors.
func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
Expand Down Expand Up @@ -1392,6 +1679,9 @@
}

runForVectors := (len(factorySpecs) != 0)
if runForVectors {
return rebuildVectorIndex(ctx, factorySpecs, rb)
}

pk := x.ParsedKey{Attr: rb.Attr}
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
Expand Down
2 changes: 1 addition & 1 deletion schema/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func parseTokenOrVectorIndexSpec(
tokenizer, has := tok.GetTokenizer(tokenOrFactoryName)
if !has {
return tokenOrFactoryName, nil, false,
next.Errorf("Invalid tokenizer %s", next.Val)
next.Errorf("Invalid tokenizer 1 %s", next.Val)
}
tokenizerType, ok := types.TypeForName(tokenizer.Type())
x.AssertTrue(ok) // Type is validated during tokenizer loading.
Expand Down
4 changes: 4 additions & 0 deletions tok/hnsw/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ func euclideanDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) {
return applyDistanceFunction(a, b, floatBits, "euclidean distance", vek32.Distance, vek.Distance)
}

// EuclideanDistanceSq is the exported counterpart of euclideanDistanceSq,
// applying vek32.Distance / vek.Distance (chosen by floatBits) to vectors
// a and b via applyDistanceFunction.
// NOTE(review): despite the "Sq" suffix, this delegates to the same
// "euclidean distance" computation as euclideanDistanceSq — confirm whether
// the returned value is actually squared.
func EuclideanDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) {
	return applyDistanceFunction(a, b, floatBits, "euclidean distance", vek32.Distance, vek.Distance)
}

// Used for distance, since shorter distance is better
func insortPersistentHeapAscending[T c.Float](
slice []minPersistentHeapElement[T],
Expand Down
11 changes: 11 additions & 0 deletions tok/hnsw/persistent_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ func (hf *persistentIndexFactory[T]) AllowedOptions() opt.AllowedOptions {
return retVal
}

// UpdateIndexSplit rebinds the entry, vector, and dead-node key prefixes of
// the given vector index to the partition identified by split, suffixing
// each key family with the split id so that multiple partitions of the same
// predicate stay disjoint in the store. It returns an error if vi is not a
// *persistentHNSW index.
func UpdateIndexSplit[T c.Float](vi index.VectorIndex[T], split int) error {
	// Don't shadow the package name "hnsw" with the local variable.
	ph, ok := vi.(*persistentHNSW[T])
	if !ok {
		return errors.New("index is not a persistent HNSW index")
	}
	ph.vecEntryKey = ConcatStrings(ph.pred, fmt.Sprintf("%s_%d", VecEntry, split))
	ph.vecKey = ConcatStrings(ph.pred, fmt.Sprintf("%s_%d", VecKeyword, split))
	ph.vecDead = ConcatStrings(ph.pred, fmt.Sprintf("%s_%d", VecDead, split))
	return nil
}

// Create is an implementation of the IndexFactory interface function, invoked by an HNSWIndexFactory
// instance. It takes in a string name and a VectorSource implementation, and returns a VectorIndex and error
// flag. It creates an HNSW instance using the index name and populates other parts of the HNSW struct such as
Expand Down
Loading
Loading