From 46eaef47db8db7fe237c4a5f2bf670cd8f988cb1 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Fri, 13 Jun 2025 03:21:49 +0530 Subject: [PATCH 01/20] perf(vector): Improve hnsw by sharding vectors --- posting/index.go | 201 ++++++++++++++++++++++++++++++++- tok/hnsw/persistent_factory.go | 21 ++-- tok/hnsw/persistent_hnsw.go | 41 +++++++ tok/index/index.go | 7 +- tok/index_factory.go | 14 ++- worker/task.go | 40 ++++--- 6 files changed, 293 insertions(+), 31 deletions(-) diff --git a/posting/index.go b/posting/index.go index 87b1bcfd015..abb98ff4233 100644 --- a/posting/index.go +++ b/posting/index.go @@ -15,6 +15,7 @@ import ( "math" "os" "strings" + "sync" "sync/atomic" "time" "unsafe" @@ -33,8 +34,11 @@ import ( "github.com/hypermodeinc/dgraph/v25/schema" "github.com/hypermodeinc/dgraph/v25/tok" "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + "github.com/hypermodeinc/dgraph/v25/tok/index" "github.com/hypermodeinc/dgraph/v25/types" "github.com/hypermodeinc/dgraph/v25/x" + + "github.com/viterin/vek/vek32" ) var emptyCountParams countParams @@ -162,7 +166,7 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) // retrieve vector from inUuid save as inVec inVec := types.BytesAsFloatArray(data[0].Value.([]byte)) tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs) - indexer, err := info.factorySpecs[0].CreateIndex(attr) + indexer, err := info.factorySpecs[0].CreateIndex(attr, 0) if err != nil { return []*pb.DirectedEdge{}, err } @@ -1361,6 +1365,198 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) { return prefixes, nil } +type vectorCentroids struct { + dimension int + numCenters int + + centroids [][]float32 + counts []int64 + weights [][]float32 + mutexs []*sync.Mutex +} + +func (vc *vectorCentroids) findCentroid(input []float32) int { + minIdx := 0 + minDist := math.MaxFloat32 + for i, centroid := range vc.centroids { + dist := vek32.Distance(centroid, input) + if float64(dist) < minDist { + minDist = float64(dist) + minIdx = i + } + } + return minIdx +} + +func (vc *vectorCentroids) addVector(vec []float32) { + idx := vc.findCentroid(vec) + vc.mutexs[idx].Lock() + defer vc.mutexs[idx].Unlock() + for i := 0; i < vc.dimension; i++ { + vc.weights[idx][i] += vec[i] + } + vc.counts[idx]++ +} + +func (vc *vectorCentroids) updateCentroids() { + for i := 0; i < vc.numCenters; i++ { + for j := 0; j < vc.dimension; j++ { + vc.centroids[i][j] = vc.weights[i][j] / float32(vc.counts[i]) + vc.weights[i][j] = 0 + } + fmt.Printf("%d, ", vc.counts[i]) + vc.counts[i] = 0 + } + fmt.Println() +} + +func (vc *vectorCentroids) randomInit() { + vc.dimension = len(vc.centroids[0]) + vc.numCenters = len(vc.centroids) + vc.centroids = make([][]float32, vc.numCenters) + vc.counts = make([]int64, vc.numCenters) + vc.weights = make([][]float32, vc.numCenters) + vc.mutexs = make([]*sync.Mutex, vc.numCenters) + for i := 0; i < vc.numCenters; i++ { + vc.weights[i] = make([]float32, vc.dimension) + vc.counts[i] = 0 + vc.mutexs[i] = &sync.Mutex{} + } +} + +func (vc *vectorCentroids) addSeedCentroid(vec []float32) { + vc.centroids = append(vc.centroids, vec) +} + +const numCentroids = 1000 + +func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error { + pk := x.ParsedKey{Attr: rb.Attr} + vc := &vectorCentroids{} + + MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ + Prefix: pk.DataPrefix(), + ReadTs: rb.StartTs, + AllVersions: false, + Reverse: false, + CheckInclusion: func(uid uint64) error { + return nil + }, + Function: func(l *List, pk x.ParsedKey) error { + val, err := l.Value(rb.StartTs) + if err != nil { + return err + } + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + vc.addSeedCentroid(inVec) + if len(vc.centroids) == numCentroids { + return ErrStopIteration + } + return nil + }, + StartKey: x.DataKey(rb.Attr, 0), + }) + + vc.randomInit() + + fmt.Println("Clustering Vectors") + for range 5 { + builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { + edges := []*pb.DirectedEdge{} + val, err := pl.Value(txn.StartTs) + if err != nil { + return []*pb.DirectedEdge{}, err + } + + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + vc.addVector(inVec) + return edges, nil + } + + err := builder.RunWithoutTemp(ctx) + if err != nil { + return err + } + + vc.updateCentroids() + } + + tcs := make([]*hnsw.TxnCache, vc.numCenters) + txns := make([]*Txn, vc.numCenters) + indexers := make([]index.VectorIndex[float32], vc.numCenters) + for i := 0; i < vc.numCenters; i++ { + txns[i] = NewTxn(rb.StartTs) + tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs) + indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i) + if err != nil { + return err + } + vc.mutexs[i] = &sync.Mutex{} + indexers[i] = indexers_i + } + + var edgesCreated atomic.Int64 + + numPasses := vc.numCenters / 100 + for pass_idx := range numPasses { + builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { + val, err := pl.Value(txn.StartTs) + if err != nil { + return []*pb.DirectedEdge{}, err + } + + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + idx := vc.findCentroid(inVec) + if idx%numPasses != pass_idx { + return []*pb.DirectedEdge{}, nil + } + vc.mutexs[idx].Lock() + defer vc.mutexs[idx].Unlock() + _, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec) + if err != nil { + return []*pb.DirectedEdge{}, err + } + + edgesCreated.Add(int64(1)) + return nil, nil + } + + err := builder.RunWithoutTemp(ctx) + if err != nil { + return err + } + + for idx := range vc.counts { + if idx%numPasses != pass_idx { + continue + } + txns[idx].Update() + writer := NewTxnWriter(pstore) + + x.ExponentialRetry(int(x.Config.MaxRetries), + 20*time.Millisecond, func() error { + err := txns[idx].CommitToDisk(writer, rb.StartTs) + if err == badger.ErrBannedKey { + glog.Errorf("Error while writing to banned namespace.") + return nil + } + return err + }) + + txns[idx].cache.plists = nil + txns[idx] = nil + tcs[idx] = nil + indexers[idx] = nil + } + + fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses) + } + + return nil +} + // rebuildTokIndex rebuilds index for a given attribute. // We commit mutations with startTs and ignore the errors. func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { @@ -1392,6 +1588,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { } runForVectors := (len(factorySpecs) != 0) + if runForVectors { + return rebuildVectorIndex(ctx, factorySpecs, rb) + } pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} diff --git a/tok/hnsw/persistent_factory.go b/tok/hnsw/persistent_factory.go index ff4c622f218..76c9eaa6f5f 100644 --- a/tok/hnsw/persistent_factory.go +++ b/tok/hnsw/persistent_factory.go @@ -87,25 +87,27 @@ func (hf *persistentIndexFactory[T]) AllowedOptions() opt.AllowedOptions { func (hf *persistentIndexFactory[T]) Create( name string, o opt.Options, - floatBits int) (index.VectorIndex[T], error) { + floatBits int, + split int) (index.VectorIndex[T], error) { hf.mu.Lock() defer hf.mu.Unlock() - return hf.createWithLock(name, o, floatBits) + return hf.createWithLock(name, o, floatBits, split) } func (hf *persistentIndexFactory[T]) createWithLock( name string, o opt.Options, - floatBits int) (index.VectorIndex[T], error) { - if !hf.isNameAvailableWithLock(name) { + floatBits int, + split int) (index.VectorIndex[T], error) { + if !hf.isNameAvailableWithLock(fmt.Sprintf("%s-%d", name, split)) { err := errors.New("index with name " + name + " already exists") return nil, err } retVal := &persistentHNSW[T]{ pred: name, - vecEntryKey: ConcatStrings(name, VecEntry), - vecKey: ConcatStrings(name, VecKeyword), - vecDead: ConcatStrings(name, VecDead), + vecEntryKey: ConcatStrings(name, VecEntry, fmt.Sprintf("_%d", split)), + vecKey: ConcatStrings(name, VecKeyword, fmt.Sprintf("_%d", split)), + vecDead: ConcatStrings(name, VecDead, fmt.Sprintf("_%d", split)), floatBits: floatBits, nodeAllEdges: map[uint64][][]uint64{}, } @@ -152,7 +154,8 @@ func (hf *persistentIndexFactory[T]) removeWithLock(name string) error { func (hf *persistentIndexFactory[T]) CreateOrReplace( name string, o opt.Options, - floatBits int) (index.VectorIndex[T], error) { + floatBits int, + split int) (index.VectorIndex[T], error) { hf.mu.Lock() defer hf.mu.Unlock() vi, err := hf.findWithLock(name) @@ -165,5 +168,5 @@ func (hf *persistentIndexFactory[T]) CreateOrReplace( return nil, err } } - return hf.createWithLock(name, o, floatBits) + return hf.createWithLock(name, o, floatBits, split) } diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index e13ddddaf89..4b0d3101cc4 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -8,6 +8,7 @@ package hnsw import ( "context" "fmt" + "sort" "strings" "time" @@ -254,6 +255,46 @@ func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, quer return r.Neighbors, err } +type resultRow[T c.Float] struct { + uid uint64 + dist T +} + +func (ph *persistentHNSW[T]) MergeResults(ctx context.Context, c index.CacheType, list []uint64, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + var result []resultRow[T] + + for i := range list { + var vec []T + err := ph.getVecFromUid(list[i], c, &vec) + if err != nil { + return nil, err + } + + dist, err := ph.simType.distanceScore(vec, query, ph.floatBits) + if err != nil { + return nil, err + } + result = append(result, resultRow[T]{ + uid: list[i], + dist: dist, + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].dist < result[j].dist + }) + + uids := []uint64{} + for i := range maxResults { + if i > len(result) { + break + } + uids = append(uids, result[i].uid) + } + + return uids, nil +} + // SearchWithUid searches the hnsw graph for the nearest neighbors of the query uid // and returns the traversal path and the nearest neighbors func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, queryUid uint64, diff --git a/tok/index/index.go b/tok/index/index.go index e0a62255ce1..e00fb440932 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -39,7 +39,7 @@ type IndexFactory[T c.Float] interface { // same object. // The set of vectors to use in the index process is defined by // source. - Create(name string, o opts.Options, floatBits int) (VectorIndex[T], error) + Create(name string, o opts.Options, floatBits int, split int) (VectorIndex[T], error) // Find is expected to retrieve the VectorIndex corresponding with the // name. If it attempts to find a name that does not exist, the VectorIndex @@ -56,7 +56,7 @@ type IndexFactory[T c.Float] interface { // CreateOrReplace will create a new index -- as defined by the Create // function -- if it does not yet exist, otherwise, it will replace any // index with the given name. - CreateOrReplace(name string, o opts.Options, floatBits int) (VectorIndex[T], error) + CreateOrReplace(name string, o opts.Options, floatBits int, split int) (VectorIndex[T], error) } // SearchFilter defines a predicate function that we will use to determine @@ -93,6 +93,9 @@ type OptionalIndexSupport[T c.Float] interface { type VectorIndex[T c.Float] interface { OptionalIndexSupport[T] + MergeResults(ctx context.Context, c CacheType, list []uint64, query []T, maxResults int, + filter SearchFilter[T]) ([]uint64, error) + // Search will find the uids for a given set of vectors based on the // input query, limiting to the specified maximum number of results. // The filter parameter indicates that we might discard certain parameters diff --git a/tok/index_factory.go b/tok/index_factory.go index abef317b952..d67610bce1d 100644 --- a/tok/index_factory.go +++ b/tok/index_factory.go @@ -45,7 +45,7 @@ func (fcs *FactoryCreateSpec) Name() string { return fcs.factory.Name() + fcs.factory.GetOptions(fcs.opts) } -func (fcs *FactoryCreateSpec) CreateIndex(name string) (index.VectorIndex[float32], error) { +func (fcs *FactoryCreateSpec) CreateIndex(name string, split int) (index.VectorIndex[float32], error) { if fcs == nil || fcs.factory == nil { return nil, errors.Errorf( @@ -61,7 +61,7 @@ func (fcs *FactoryCreateSpec) CreateIndex(name string) (index.VectorIndex[float3 // has the downside of not allowing us to reuse the pre-existing // index. // nil VectorSource at the moment. - return fcs.factory.CreateOrReplace(name, fcs.opts, 32) + return fcs.factory.CreateOrReplace(name, fcs.opts, 32, split) } func createIndexFactory(f index.IndexFactory[float32]) IndexFactory { @@ -79,8 +79,9 @@ func (f *indexFactory) AllowedOptions() opts.AllowedOptions { func (f *indexFactory) Create( name string, o opts.Options, - floatBits int) (index.VectorIndex[float32], error) { - return f.delegate.Create(name, o, floatBits) + floatBits int, + split int) (index.VectorIndex[float32], error) { + return f.delegate.Create(name, o, floatBits, split) } func (f *indexFactory) Find(name string) (index.VectorIndex[float32], error) { return f.delegate.Find(name) @@ -91,8 +92,9 @@ func (f *indexFactory) Remove(name string) error { func (f *indexFactory) CreateOrReplace( name string, o opts.Options, - floatBits int) (index.VectorIndex[float32], error) { - return f.delegate.CreateOrReplace(name, o, floatBits) + floatBits int, + split int) (index.VectorIndex[float32], error) { + return f.delegate.CreateOrReplace(name, o, floatBits, split) } func (f *indexFactory) GetOptions(o opts.Options) string { diff --git a/worker/task.go b/worker/task.go index 92c1d02350f..0ecd371d842 100644 --- a/worker/task.go +++ b/worker/task.go @@ -12,6 +12,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/golang/glog" @@ -360,20 +361,33 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er posting.NewViLocalCache(qs.cache), args.q.ReadTs, ) - indexer, err := cspec.CreateIndex(args.q.Attr) - if err != nil { - return err - } - var nnUids []uint64 - if srcFn.vectorInfo != nil { - nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) - } else { - nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) - } - if err != nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { + var nnUids []uint64 + var wg sync.WaitGroup + wg.Add(1000) + var mutex sync.Mutex + for i := range 1000 { + go func(idx int) { + nnuids := make([]uint64, 0) + indexer, _ := cspec.CreateIndex(args.q.Attr, i) + if srcFn.vectorInfo != nil { + nnuids, _ = indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) + } else { + nnuids, _ = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, + int(numNeighbors), index.AcceptAll[float32]) + } + mutex.Lock() + nnUids = append(nnUids, nnuids...) + mutex.Unlock() + wg.Done() + }(i) + } + wg.Wait() + indexer, _ := cspec.CreateIndex(args.q.Attr, 0) + nnUids, err = indexer.MergeResults(ctx, qc, nnUids, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) + if err != nil { return err } sort.Slice(nnUids, func(i, j int) bool { return nnUids[i] < nnUids[j] }) From 7e3aa8e493a98bae4da113cd1b2e82d2c5c33674 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Wed, 25 Jun 2025 02:32:21 +0530 Subject: [PATCH 02/20] added changes --- posting/index.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/posting/index.go b/posting/index.go index abb98ff4233..e8e70bcf24f 100644 --- a/posting/index.go +++ b/posting/index.go @@ -1413,7 +1413,6 @@ func (vc *vectorCentroids) updateCentroids() { func (vc *vectorCentroids) randomInit() { vc.dimension = len(vc.centroids[0]) vc.numCenters = len(vc.centroids) - vc.centroids = make([][]float32, vc.numCenters) vc.counts = make([]int64, vc.numCenters) vc.weights = make([][]float32, vc.numCenters) vc.mutexs = make([]*sync.Mutex, vc.numCenters) @@ -1433,6 +1432,7 @@ const numCentroids = 1000 func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error { pk := x.ParsedKey{Attr: rb.Attr} vc := &vectorCentroids{} + vc.centroids = make([][]float32, 0) MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ Prefix: pk.DataPrefix(), From d651ad9b577d33099ce47fefe850a8cd8afcfa87 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Wed, 25 Jun 2025 22:38:40 +0530 Subject: [PATCH 03/20] added changes --- posting/index.go | 307 +++++++++++--------- tok/hnsw/helper.go | 4 + tok/hnsw/persistent_factory.go | 32 +- tok/hnsw/persistent_hnsw.go | 36 +++ tok/index/index.go | 24 +- tok/index_factory.go | 14 +- tok/kmeans/kmeans.go | 135 +++++++++ tok/partitioned_hnsw/partitioned_factory.go | 160 ++++++++++ tok/partitioned_hnsw/partitioned_hnsw.go | 193 ++++++++++++ worker/task.go | 32 +- 10 files changed, 759 insertions(+), 178 deletions(-) create mode 100644 tok/kmeans/kmeans.go create mode 100644 tok/partitioned_hnsw/partitioned_factory.go create mode 100644 tok/partitioned_hnsw/partitioned_hnsw.go diff --git a/posting/index.go b/posting/index.go index e8e70bcf24f..667915ae000 100644 --- a/posting/index.go +++ b/posting/index.go @@ -15,7 +15,6 @@ import ( "math" "os" "strings" - "sync" "sync/atomic" "time" "unsafe" @@ -34,11 +33,10 @@ import ( "github.com/hypermodeinc/dgraph/v25/schema" "github.com/hypermodeinc/dgraph/v25/tok" "github.com/hypermodeinc/dgraph/v25/tok/hnsw" - "github.com/hypermodeinc/dgraph/v25/tok/index" + tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index" + "github.com/hypermodeinc/dgraph/v25/types" "github.com/hypermodeinc/dgraph/v25/x" - - "github.com/viterin/vek/vek32" ) var emptyCountParams countParams @@ -166,7 +164,7 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) // retrieve vector from inUuid save as inVec inVec := types.BytesAsFloatArray(data[0].Value.([]byte)) tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs) - indexer, err := info.factorySpecs[0].CreateIndex(attr, 0) + indexer, err := info.factorySpecs[0].CreateIndex(attr) if err != nil { return []*pb.DirectedEdge{}, err } @@ -1365,112 +1363,67 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) { return prefixes, nil } -type vectorCentroids struct { - dimension int - numCenters int +const numCentroids = 1000 - centroids [][]float32 - counts []int64 - weights [][]float32 - mutexs []*sync.Mutex -} +func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error { + pk := x.ParsedKey{Attr: rb.Attr} -func (vc *vectorCentroids) findCentroid(input []float32) int { - minIdx := 0 - minDist := math.MaxFloat32 - for i, centroid := range vc.centroids { - dist := vek32.Distance(centroid, input) - if float64(dist) < minDist { - minDist = float64(dist) - minIdx = i - } + indexer, err := factorySpecs[0].CreateIndex(pk.Attr) + if err != nil { + return err } - return minIdx -} -func (vc *vectorCentroids) addVector(vec []float32) { - idx := vc.findCentroid(vec) - vc.mutexs[idx].Lock() - defer vc.mutexs[idx].Unlock() - for i := 0; i < vc.dimension; i++ { - vc.weights[idx][i] += vec[i] + if indexer.NumSeedVectors() > 0 { + count := 0 + MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ + Prefix: pk.DataPrefix(), + ReadTs: rb.StartTs, + AllVersions: false, + Reverse: false, + CheckInclusion: func(uid uint64) error { + return nil + }, + Function: func(l *List, pk x.ParsedKey) error { + val, err := l.Value(rb.StartTs) + if err != nil { + return err + } + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + count += 1 + indexer.AddSeedVector(inVec) + if count == indexer.NumSeedVectors() { + return ErrStopIteration + } + return nil + }, + StartKey: x.DataKey(rb.Attr, 0), + }) } - vc.counts[idx]++ -} -func (vc *vectorCentroids) updateCentroids() { - for i := 0; i < vc.numCenters; i++ { - for j := 0; j < vc.dimension; j++ { - vc.centroids[i][j] = vc.weights[i][j] / float32(vc.counts[i]) - vc.weights[i][j] = 0 - } - fmt.Printf("%d, ", vc.counts[i]) - vc.counts[i] = 0 + txns := make([]*Txn, indexer.NumThreads()) + for i := range txns { + txns[i] = NewTxn(rb.StartTs) } - fmt.Println() -} - -func (vc *vectorCentroids) randomInit() { - vc.dimension = len(vc.centroids[0]) - vc.numCenters = len(vc.centroids) - vc.counts = make([]int64, vc.numCenters) - vc.weights = make([][]float32, vc.numCenters) - vc.mutexs = make([]*sync.Mutex, vc.numCenters) - for i := 0; i < vc.numCenters; i++ { - vc.weights[i] = make([]float32, vc.dimension) - vc.counts[i] = 0 - vc.mutexs[i] = &sync.Mutex{} + caches := make([]tokIndex.CacheType, indexer.NumThreads()) + for i := range caches { + caches[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs) } -} -func (vc *vectorCentroids) addSeedCentroid(vec []float32) { - vc.centroids = append(vc.centroids, vec) -} + for pass_idx := range indexer.NumBuildPasses() { + fmt.Println("Building pass", pass_idx) -const numCentroids = 1000 + indexer.StartBuild(caches) -func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error { - pk := x.ParsedKey{Attr: rb.Attr} - vc := &vectorCentroids{} - vc.centroids = make([][]float32, 0) - - MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ - Prefix: pk.DataPrefix(), - ReadTs: rb.StartTs, - AllVersions: false, - Reverse: false, - CheckInclusion: func(uid uint64) error { - return nil - }, - Function: func(l *List, pk x.ParsedKey) error { - val, err := l.Value(rb.StartTs) - if err != nil { - return err - } - inVec := types.BytesAsFloatArray(val.Value.([]byte)) - vc.addSeedCentroid(inVec) - if len(vc.centroids) == numCentroids { - return ErrStopIteration - } - return nil - }, - StartKey: x.DataKey(rb.Attr, 0), - }) - - vc.randomInit() - - fmt.Println("Clustering Vectors") - for range 5 { builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { edges := []*pb.DirectedEdge{} - val, err := pl.Value(txn.StartTs) + val, err := pl.Value(rb.StartTs) if err != nil { return []*pb.DirectedEdge{}, err } inVec := types.BytesAsFloatArray(val.Value.([]byte)) - vc.addVector(inVec) + indexer.BuildInsert(ctx, uid, inVec) return edges, nil } @@ -1479,48 +1432,25 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp return err } - vc.updateCentroids() + indexer.EndBuild() } - tcs := make([]*hnsw.TxnCache, vc.numCenters) - txns := make([]*Txn, vc.numCenters) - indexers := make([]index.VectorIndex[float32], vc.numCenters) - for i := 0; i < vc.numCenters; i++ { - txns[i] = NewTxn(rb.StartTs) - tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs) - indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i) - if err != nil { - return err - } - vc.mutexs[i] = &sync.Mutex{} - indexers[i] = indexers_i - } + for pass_idx := range indexer.NumIndexPasses() { + fmt.Println("Indexing pass", pass_idx) - var edgesCreated atomic.Int64 + indexer.StartBuild(caches) - numPasses := vc.numCenters / 100 - for pass_idx := range numPasses { builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { - val, err := pl.Value(txn.StartTs) + edges := []*pb.DirectedEdge{} + val, err := pl.Value(rb.StartTs) if err != nil { return []*pb.DirectedEdge{}, err } inVec := types.BytesAsFloatArray(val.Value.([]byte)) - idx := vc.findCentroid(inVec) - if idx%numPasses != pass_idx { - return []*pb.DirectedEdge{}, nil - } - vc.mutexs[idx].Lock() - defer vc.mutexs[idx].Unlock() - _, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec) - if err != nil { - return []*pb.DirectedEdge{}, err - } - - edgesCreated.Add(int64(1)) - return nil, nil + indexer.BuildInsert(ctx, uid, inVec) + return edges, nil } err := builder.RunWithoutTemp(ctx) @@ -1528,10 +1458,7 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp return err } - for idx := range vc.counts { - if idx%numPasses != pass_idx { - continue - } + for _, idx := range indexer.EndBuild() { txns[idx].Update() writer := NewTxnWriter(pstore) @@ -1547,14 +1474,132 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp txns[idx].cache.plists = nil txns[idx] = nil - tcs[idx] = nil - indexers[idx] = nil } - - fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses) } return nil + + // MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ + // Prefix: pk.DataPrefix(), + // ReadTs: rb.StartTs, + // AllVersions: false, + // Reverse: false, + // CheckInclusion: func(uid uint64) error { + // return nil + // }, + // Function: func(l *List, pk x.ParsedKey) error { + // val, err := l.Value(rb.StartTs) + // if err != nil { + // return err + // } + // inVec := types.BytesAsFloatArray(val.Value.([]byte)) + // vc.addSeedCentroid(inVec) + // if len(vc.centroids) == numCentroids { + // return ErrStopIteration + // } + // return nil + // }, + // StartKey: x.DataKey(rb.Attr, 0), + // }) + + // vc.randomInit() + + // fmt.Println("Clustering Vectors") + // for range 5 { + // builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + // builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { + // edges := []*pb.DirectedEdge{} + // val, err := pl.Value(txn.StartTs) + // if err != nil { + // return []*pb.DirectedEdge{}, err + // } + + // inVec := types.BytesAsFloatArray(val.Value.([]byte)) + // vc.addVector(inVec) + // return edges, nil + // } + + // err := builder.RunWithoutTemp(ctx) + // if err != nil { + // return err + // } + + // vc.updateCentroids() + // } + + // tcs := make([]*hnsw.TxnCache, vc.numCenters) + // txns := make([]*Txn, vc.numCenters) + // indexers := make([]index.VectorIndex[float32], vc.numCenters) + // for i := 0; i < vc.numCenters; i++ { + // txns[i] = NewTxn(rb.StartTs) + // tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs) + // indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i) + // if err != nil { + // return err + // } + // vc.mutexs[i] = &sync.Mutex{} + // indexers[i] = indexers_i + // } + + // var edgesCreated atomic.Int64 + + // numPasses := vc.numCenters / 100 + // for pass_idx := range numPasses { + // builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + // builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { + // val, err := pl.Value(txn.StartTs) + // if err != nil { + // return []*pb.DirectedEdge{}, err + // } + + // inVec := types.BytesAsFloatArray(val.Value.([]byte)) + // idx := vc.findCentroid(inVec) + // if idx%numPasses != pass_idx { + // return []*pb.DirectedEdge{}, nil + // } + // vc.mutexs[idx].Lock() + // defer vc.mutexs[idx].Unlock() + // _, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec) + // if err != nil { + // return []*pb.DirectedEdge{}, err + // } + + // edgesCreated.Add(int64(1)) + // return nil, nil + // } + + // err := builder.RunWithoutTemp(ctx) + // if err != nil { + // return err + // } + + // for idx := range vc.counts { + // if idx%numPasses != pass_idx { + // continue + // } + // txns[idx].Update() + // writer := NewTxnWriter(pstore) + + // x.ExponentialRetry(int(x.Config.MaxRetries), + // 20*time.Millisecond, func() error { + // err := txns[idx].CommitToDisk(writer, rb.StartTs) + // if err == badger.ErrBannedKey { + // glog.Errorf("Error while writing to banned namespace.") + // return nil + // } + // return err + // }) + + // txns[idx].cache.plists = nil + // txns[idx] = nil + // tcs[idx] = nil + // indexers[idx] = nil + // } + + // fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses) + // } + + // return nil } // rebuildTokIndex rebuilds index for a given attribute. diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 477f5bc9b27..c3b78c7a488 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -114,6 +114,10 @@ func euclideanDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) { return applyDistanceFunction(a, b, floatBits, "euclidean distance", vek32.Distance, vek.Distance) } +func EuclideanDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) { + return applyDistanceFunction(a, b, floatBits, "euclidean distance", vek32.Distance, vek.Distance) +} + // Used for distance, since shorter distance is better func insortPersistentHeapAscending[T c.Float]( slice []minPersistentHeapElement[T], diff --git a/tok/hnsw/persistent_factory.go b/tok/hnsw/persistent_factory.go index 76c9eaa6f5f..4bc13b48ea6 100644 --- a/tok/hnsw/persistent_factory.go +++ b/tok/hnsw/persistent_factory.go @@ -78,6 +78,17 @@ func (hf *persistentIndexFactory[T]) AllowedOptions() opt.AllowedOptions { return retVal } +func UpdateIndexSplit[T c.Float](vi index.VectorIndex[T], split int) error { + hnsw, ok := vi.(*persistentHNSW[T]) + if !ok { + return errors.New("index is not a persistent HNSW index") + } + hnsw.vecEntryKey = ConcatStrings(hnsw.pred, fmt.Sprintf("%s_%d", VecEntry, split)) + hnsw.vecKey = ConcatStrings(hnsw.pred, fmt.Sprintf("%s_%d", VecKeyword, split)) + hnsw.vecDead = ConcatStrings(hnsw.pred, fmt.Sprintf("%s_%d", VecDead, split)) + return nil +} + // Create is an implementation of the IndexFactory interface function, invoked by an HNSWIndexFactory // instance. It takes in a string name and a VectorSource implementation, and returns a VectorIndex and error // flag. It creates an HNSW instance using the index name and populates other parts of the HNSW struct such as @@ -87,27 +98,25 @@ func (hf *persistentIndexFactory[T]) AllowedOptions() opt.AllowedOptions { func (hf *persistentIndexFactory[T]) Create( name string, o opt.Options, - floatBits int, - split int) (index.VectorIndex[T], error) { + floatBits int) (index.VectorIndex[T], error) { hf.mu.Lock() defer hf.mu.Unlock() - return hf.createWithLock(name, o, floatBits, split) + return hf.createWithLock(name, o, floatBits) } func (hf *persistentIndexFactory[T]) createWithLock( name string, o opt.Options, - floatBits int, - split int) (index.VectorIndex[T], error) { - if !hf.isNameAvailableWithLock(fmt.Sprintf("%s-%d", name, split)) { + floatBits int) (index.VectorIndex[T], error) { + if !hf.isNameAvailableWithLock(name) { err := errors.New("index with name " + name + " already exists") return nil, err } retVal := &persistentHNSW[T]{ pred: name, - vecEntryKey: ConcatStrings(name, VecEntry, fmt.Sprintf("_%d", split)), - vecKey: ConcatStrings(name, VecKeyword, fmt.Sprintf("_%d", split)), - vecDead: ConcatStrings(name, VecDead, fmt.Sprintf("_%d", split)), + vecEntryKey: ConcatStrings(name, VecEntry), + vecKey: ConcatStrings(name, VecKeyword), + vecDead: ConcatStrings(name, VecDead), floatBits: floatBits, nodeAllEdges: map[uint64][][]uint64{}, } @@ -154,8 +163,7 @@ func (hf *persistentIndexFactory[T]) removeWithLock(name string) error { func (hf *persistentIndexFactory[T]) CreateOrReplace( name string, o opt.Options, - floatBits int, - split int) (index.VectorIndex[T], error) { + floatBits int) (index.VectorIndex[T], error) { hf.mu.Lock() defer hf.mu.Unlock() vi, err := hf.findWithLock(name) @@ -168,5 +176,5 @@ func (hf *persistentIndexFactory[T]) CreateOrReplace( return nil, err } } - return hf.createWithLock(name, o, floatBits, split) + return hf.createWithLock(name, o, floatBits) } diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 4b0d3101cc4..f1b0e4f3d84 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -33,6 +33,7 @@ type persistentHNSW[T c.Float] struct { // layer for uuid 65443. The result will be a neighboring uuid. nodeAllEdges map[uint64][][]uint64 deadNodes map[uint64]struct{} + cache index.CacheType } func GetPersistantOptions[T c.Float](o opt.Options) string { @@ -112,6 +113,41 @@ func (ph *persistentHNSW[T]) applyOptions(o opt.Options) error { return nil } +func (ph *persistentHNSW[T]) NumBuildPasses() int { + return 0 +} + +func (ph *persistentHNSW[T]) NumIndexPasses() int { + return 1 +} + +func (ph *persistentHNSW[T]) NumSeedVectors() int { + return 0 +} + +func (ph *persistentHNSW[T]) StartBuild(caches []index.CacheType) { + ph.nodeAllEdges = make(map[uint64][][]uint64) + ph.cache = caches[0] +} + +func (ph *persistentHNSW[T]) EndBuild() []int { + ph.nodeAllEdges = nil + ph.cache = nil + return []int{0} +} + +func (ph *persistentHNSW[T]) NumThreads() int { + return 1 +} + +func (ph *persistentHNSW[T]) BuildInsert(ctx context.Context, uid uint64, vec []T) error { + _, err := ph.Insert(ctx, ph.cache, uid, vec) + return err +} + +func (ph *persistentHNSW[T]) AddSeedVector(vec []T) { +} + func (ph *persistentHNSW[T]) emptyFinalResultWithError(e error) ( *index.SearchPathResult, error) { return index.NewSearchPathResult(), e diff --git a/tok/index/index.go b/tok/index/index.go index e00fb440932..503c2b66faa 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -39,7 +39,7 @@ type IndexFactory[T c.Float] interface { // same object. // The set of vectors to use in the index process is defined by // source. - Create(name string, o opts.Options, floatBits int, split int) (VectorIndex[T], error) + Create(name string, o opts.Options, floatBits int) (VectorIndex[T], error) // Find is expected to retrieve the VectorIndex corresponding with the // name. If it attempts to find a name that does not exist, the VectorIndex @@ -56,7 +56,7 @@ type IndexFactory[T c.Float] interface { // CreateOrReplace will create a new index -- as defined by the Create // function -- if it does not yet exist, otherwise, it will replace any // index with the given name. - CreateOrReplace(name string, o opts.Options, floatBits int, split int) (VectorIndex[T], error) + CreateOrReplace(name string, o opts.Options, floatBits int) (VectorIndex[T], error) } // SearchFilter defines a predicate function that we will use to determine @@ -89,6 +89,17 @@ type OptionalIndexSupport[T c.Float] interface { filter SearchFilter[T]) (*SearchPathResult, error) } +type VectorPartitionStrat[T c.Float] interface { + FindIndexForSearch(vec []T) ([]int, error) + FindIndexForInsert(vec []T) (int, error) + NumPasses() int + NumSeedVectors() int + StartBuildPass() + EndBuildPass() + AddSeedVector(vec []T) + AddVector(vec []T) error +} + // A VectorIndex can be used to Search for vectors and add vectors to an index. type VectorIndex[T c.Float] interface { OptionalIndexSupport[T] @@ -119,6 +130,15 @@ type VectorIndex[T c.Float] interface { // Insert will add a vector and uuid into the existing VectorIndex. If // uuid already exists, it should throw an error to not insert duplicate uuids Insert(ctx context.Context, c CacheType, uuid uint64, vec []T) ([]*KeyValue, error) + + BuildInsert(ctx context.Context, uuid uint64, vec []T) error + AddSeedVector(vec []T) + NumBuildPasses() int + NumIndexPasses() int + NumSeedVectors() int + StartBuild(caches []CacheType) + EndBuild() []int + NumThreads() int } // A Txn is an interface representation of a persistent storage transaction, diff --git a/tok/index_factory.go b/tok/index_factory.go index d67610bce1d..abef317b952 100644 --- a/tok/index_factory.go +++ b/tok/index_factory.go @@ -45,7 +45,7 @@ func (fcs *FactoryCreateSpec) Name() string { return fcs.factory.Name() + fcs.factory.GetOptions(fcs.opts) } -func (fcs *FactoryCreateSpec) CreateIndex(name string, split int) (index.VectorIndex[float32], error) { +func (fcs *FactoryCreateSpec) CreateIndex(name string) (index.VectorIndex[float32], error) { if fcs == nil || fcs.factory == nil { return nil, errors.Errorf( @@ -61,7 +61,7 @@ func (fcs *FactoryCreateSpec) CreateIndex(name string, split int) (index.VectorI // has the downside of not allowing us to reuse the pre-existing // index. // nil VectorSource at the moment. - return fcs.factory.CreateOrReplace(name, fcs.opts, 32, split) + return fcs.factory.CreateOrReplace(name, fcs.opts, 32) } func createIndexFactory(f index.IndexFactory[float32]) IndexFactory { @@ -79,9 +79,8 @@ func (f *indexFactory) AllowedOptions() opts.AllowedOptions { func (f *indexFactory) Create( name string, o opts.Options, - floatBits int, - split int) (index.VectorIndex[float32], error) { - return f.delegate.Create(name, o, floatBits, split) + floatBits int) (index.VectorIndex[float32], error) { + return f.delegate.Create(name, o, floatBits) } func (f *indexFactory) Find(name string) (index.VectorIndex[float32], error) { return f.delegate.Find(name) @@ -92,9 +91,8 @@ func (f *indexFactory) Remove(name string) error { func (f *indexFactory) CreateOrReplace( name string, o opts.Options, - floatBits int, - split int) (index.VectorIndex[float32], error) { - return f.delegate.CreateOrReplace(name, o, floatBits, split) + floatBits int) (index.VectorIndex[float32], error) { + return f.delegate.CreateOrReplace(name, o, floatBits) } func (f *indexFactory) GetOptions(o opts.Options) string { diff --git a/tok/kmeans/kmeans.go b/tok/kmeans/kmeans.go new file mode 100644 index 00000000000..00768b1f006 --- /dev/null +++ b/tok/kmeans/kmeans.go @@ -0,0 +1,135 @@ +package kmeans + +import ( + "fmt" + "math" + "sync" + + c "github.com/hypermodeinc/dgraph/v25/tok/constraints" + "github.com/hypermodeinc/dgraph/v25/tok/index" +) + +type Kmeans[T c.Float] struct { + floatBits int + centroids *vectorCentroids[T] +} + +func CreateKMeans[T c.Float](floatBits int, distFunc func(a, b []T, floatBits int) (T, error)) index.VectorPartitionStrat[T] { + return &Kmeans[T]{ + floatBits: floatBits, + centroids: &vectorCentroids[T]{ + distFunc: distFunc, + floatBits: floatBits, + }, + } +} + +func (km *Kmeans[T]) AddSeedVector(vec []T) { + km.centroids.addSeedCentroid(vec) +} + +func (km *Kmeans[T]) AddVector(vec []T) error { + return km.centroids.addVector(vec) +} + +func (km *Kmeans[T]) FindIndexForSearch(vec []T) ([]int, error) { + res := make([]int, len(km.centroids.centroids)) + for i := range res { + res[i] = i + } + return res, nil +} + +func (km *Kmeans[T]) FindIndexForInsert(vec []T) (int, error) { + return km.centroids.findCentroid(vec) +} + +func (km *Kmeans[T]) NumPasses() int { + return 5 +} + +func (km *Kmeans[T]) NumSeedVectors() int { + return 1000 +} + +func (km *Kmeans[T]) StartBuildPass() { + if km.centroids.weights == nil { + km.centroids.randomInit() + } +} + +func (km *Kmeans[T]) EndBuildPass() { + km.centroids.updateCentroids() +} + +type vectorCentroids[T c.Float] struct { + dimension int + numCenters int + + distFunc func(a, b []T, floatBits int) (T, error) + + centroids [][]T + counts []int64 + weights [][]T + mutexs []*sync.Mutex + floatBits int +} + +func (vc *vectorCentroids[T]) findCentroid(input []T) (int, error) { + minIdx := 0 + minDist := math.MaxFloat32 + for i, centroid := range vc.centroids { + dist, err := vc.distFunc(centroid, input, vc.floatBits) + if err != nil { + return 0, err + } + if float64(dist) < minDist { + minDist = float64(dist) + minIdx = i + } + } + return minIdx, nil +} + +func (vc *vectorCentroids[T]) addVector(vec []T) error { + idx, err := vc.findCentroid(vec) + if err != nil { + return err + } + vc.mutexs[idx].Lock() + defer vc.mutexs[idx].Unlock() + for i := 0; i < vc.dimension; i++ { + vc.weights[idx][i] += vec[i] + } + vc.counts[idx]++ + return nil +} + +func (vc *vectorCentroids[T]) updateCentroids() { + for i := 0; i < vc.numCenters; i++ { + for j := 0; j < vc.dimension; j++ { + vc.centroids[i][j] = vc.weights[i][j] / T(vc.counts[i]) + vc.weights[i][j] = 0 + } + fmt.Printf("%d, ", vc.counts[i]) + vc.counts[i] = 0 + } + fmt.Println() +} + +func (vc *vectorCentroids[T]) randomInit() { + vc.dimension = len(vc.centroids[0]) + vc.numCenters = len(vc.centroids) + vc.counts = make([]int64, vc.numCenters) + vc.weights = make([][]T, vc.numCenters) + vc.mutexs = make([]*sync.Mutex, vc.numCenters) + for i := 0; i < vc.numCenters; i++ { + vc.weights[i] = make([]T, vc.dimension) + vc.counts[i] = 0 + vc.mutexs[i] = &sync.Mutex{} + } +} + +func (vc *vectorCentroids[T]) addSeedCentroid(vec []T) { + vc.centroids = append(vc.centroids, vec) +} diff --git a/tok/partitioned_hnsw/partitioned_factory.go b/tok/partitioned_hnsw/partitioned_factory.go new file mode 100644 index 00000000000..28dc608a362 --- /dev/null +++ b/tok/partitioned_hnsw/partitioned_factory.go @@ -0,0 +1,160 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package partitioned_hnsw + +import ( + "errors" + "fmt" + "sync" + + c "github.com/hypermodeinc/dgraph/v25/tok/constraints" + "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + "github.com/hypermodeinc/dgraph/v25/tok/index" + opt "github.com/hypermodeinc/dgraph/v25/tok/options" +) + +const ( + NumClustersOpt string = "numClusters" + PartitionStratOpt string = "partitionStratOpt" + PartitionedHNSW string = "partionedHNSW" +) + +type partitionedHNSWIndexFactory[T c.Float] struct { + indexMap map[string]index.VectorIndex[T] + floatBits int + mu sync.RWMutex +} + +// CreateFactory creates an instance of the private struct persistentIndexFactory. +// NOTE: if T and floatBits do not match in # of bits, there will be consequences. +func CreateFactory[T c.Float](floatBits int) index.IndexFactory[T] { + return &partitionedHNSWIndexFactory[T]{ + indexMap: map[string]index.VectorIndex[T]{}, + floatBits: floatBits, + } +} + +// Implements NamedFactory interface for use as a plugin. +func (hf *partitionedHNSWIndexFactory[T]) Name() string { return PartitionedHNSW } + +func (hf *partitionedHNSWIndexFactory[T]) GetOptions(o opt.Options) string { + return hnsw.GetPersistantOptions[T](o) +} + +func (hf *partitionedHNSWIndexFactory[T]) isNameAvailableWithLock(name string) bool { + _, nameUsed := hf.indexMap[name] + return !nameUsed +} + +// hf.AllowedOptions() allows persistentIndexFactory to implement the +// IndexFactory interface (see vector-indexer/index/index.go for details). +// We define here options for exponent, maxLevels, efSearch, efConstruction, +// and metric. +func (hf *partitionedHNSWIndexFactory[T]) AllowedOptions() opt.AllowedOptions { + retVal := opt.NewAllowedOptions() + retVal.AddIntOption(hnsw.ExponentOpt). + AddIntOption(hnsw.MaxLevelsOpt). + AddIntOption(hnsw.EfConstructionOpt). + AddIntOption(hnsw.EfSearchOpt). + AddIntOption(NumClustersOpt). + AddStringOption(PartitionStratOpt) + getSimFunc := func(optValue string) (any, error) { + if optValue != hnsw.Euclidean && optValue != hnsw.Cosine && optValue != hnsw.DotProd { + return nil, errors.New(fmt.Sprintf("Can't create a vector index for %s", optValue)) + } + return hnsw.GetSimType[T](optValue, hf.floatBits), nil + } + + retVal.AddCustomOption(hnsw.MetricOpt, getSimFunc) + return retVal +} + +// Create is an implementation of the IndexFactory interface function, invoked by an HNSWIndexFactory +// instance. It takes in a string name and a VectorSource implementation, and returns a VectorIndex and error +// flag. It creates an HNSW instance using the index name and populates other parts of the HNSW struct such as +// multFactor, maxLevels, efConstruction, maxNeighbors, and efSearch using struct parameters. +// It then populates the HNSW graphs using the InsertChunk function until there are no more items to populate. +// Finally, the function adds the name and hnsw object to the in memory map and returns the object. +func (hf *partitionedHNSWIndexFactory[T]) Create( + name string, + o opt.Options, + floatBits int) (index.VectorIndex[T], error) { + hf.mu.Lock() + defer hf.mu.Unlock() + return hf.createWithLock(name, o, floatBits) +} + +func (hf *partitionedHNSWIndexFactory[T]) createWithLock( + name string, + o opt.Options, + floatBits int) (index.VectorIndex[T], error) { + if !hf.isNameAvailableWithLock(name) { + err := errors.New("index with name " + name + " already exists") + return nil, err + } + retVal := &partitionedHNSW[T]{ + pred: name, + floatBits: floatBits, + clusterMap: map[int]index.VectorIndex[T]{}, + } + err := retVal.applyOptions(o) + if err != nil { + return nil, err + } + hf.indexMap[name] = retVal + return retVal, nil +} + +// Find is an implementation of the IndexFactory interface function, invoked by an persistentIndexFactory +// instance. It returns the VectorIndex corresponding with a string name using the in memory map. +func (hf *partitionedHNSWIndexFactory[T]) Find(name string) (index.VectorIndex[T], error) { + hf.mu.RLock() + defer hf.mu.RUnlock() + return hf.findWithLock(name) +} + +func (hf *partitionedHNSWIndexFactory[T]) findWithLock(name string) (index.VectorIndex[T], error) { + vecInd := hf.indexMap[name] + return vecInd, nil +} + +// Remove is an implementation of the IndexFactory interface function, invoked by an persistentIndexFactory +// instance. It removes the VectorIndex corresponding with a string name using the in memory map. +func (hf *partitionedHNSWIndexFactory[T]) Remove(name string) error { + hf.mu.Lock() + defer hf.mu.Unlock() + return hf.removeWithLock(name) +} + +func (hf *partitionedHNSWIndexFactory[T]) removeWithLock(name string) error { + delete(hf.indexMap, name) + return nil +} + +// CreateOrReplace is an implementation of the IndexFactory interface funciton, +// invoked by an persistentIndexFactory. It checks if a VectorIndex +// correpsonding with name exists. If it does, it removes it, and replaces it +// via the Create function using the passed VectorSource. If the VectorIndex +// does not exist, it creates that VectorIndex corresponding with the name using +// the VectorSource. +func (hf *partitionedHNSWIndexFactory[T]) CreateOrReplace( + name string, + o opt.Options, + floatBits int) (index.VectorIndex[T], error) { + hf.mu.Lock() + defer hf.mu.Unlock() + vi, err := hf.findWithLock(name) + if err != nil { + return nil, err + } + if vi != nil { + err = hf.removeWithLock(name) + if err != nil { + return nil, err + } + } + return hf.createWithLock(name, o, floatBits) +} diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go new file mode 100644 index 00000000000..5dcc6dd564f --- /dev/null +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -0,0 +1,193 @@ +// CreateFactory creates an instance of the private struct persistentIndexFactory. +// NOTE: if T and floatBits do not match in # of bits, there will be consequences. + +package partitioned_hnsw + +import ( + "context" + "errors" + "sync" + + c "github.com/hypermodeinc/dgraph/v25/tok/constraints" + hnsw "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + "github.com/hypermodeinc/dgraph/v25/tok/index" + "github.com/hypermodeinc/dgraph/v25/tok/kmeans" + opt "github.com/hypermodeinc/dgraph/v25/tok/options" +) + +type partitionedHNSW[T c.Float] struct { + floatBits int + pred string + + clusterMap map[int]index.VectorIndex[T] + numClusters int + factory index.IndexFactory[T] + partition index.VectorPartitionStrat[T] + + hnswOptions opt.Options + partitionStrat string + + caches []index.CacheType + buildPass int +} + +func (ph *partitionedHNSW[T]) applyOptions(o opt.Options) error { + if o.Specifies(NumClustersOpt) { + ph.numClusters, _, _ = opt.GetOpt(o, NumClustersOpt, 1000) + } + + if o.Specifies(PartitionStratOpt) { + ph.partitionStrat, _, _ = opt.GetOpt(o, PartitionStratOpt, "kmeans") + } + + if ph.partitionStrat != "kmeans" && ph.partitionStrat != "query" { + return errors.New("partition strategy must be kmeans or query") + } + + if ph.partitionStrat == "kmeans" { + ph.partition = kmeans.CreateKMeans(ph.floatBits, hnsw.EuclideanDistanceSq[T]) + } + + ph.buildPass = 0 + ph.hnswOptions = o + ph.factory = hnsw.CreateFactory[T](ph.floatBits) + for i := range ph.numClusters { + vi, err := ph.factory.Create(ph.pred, ph.hnswOptions, ph.floatBits) + if err != nil { + return err + } + err = hnsw.UpdateIndexSplit(vi, i) + if err != nil { + return err + } + ph.clusterMap[i] = vi + } + return nil +} + +func (ph *partitionedHNSW[T]) AddSeedVector(vec []T) { + ph.partition.AddSeedVector(vec) +} + +func (ph *partitionedHNSW[T]) BuildInsert(ctx context.Context, uuid uint64, vec []T) error { + passIdx := ph.buildPass - ph.partition.NumPasses() + if passIdx < 0 { + return ph.partition.AddVector(vec) + } + index, err := ph.partition.FindIndexForInsert(vec) + if err != nil { + return err + } + if index%NUM_PASSES != passIdx { + return nil + } + return ph.clusterMap[index].BuildInsert(ctx, uuid, vec) +} + +const NUM_PASSES = 10 + +func (ph *partitionedHNSW[T]) NumBuildPasses() int { + return ph.partition.NumPasses() +} + +func (ph *partitionedHNSW[T]) NumIndexPasses() int { + return NUM_PASSES +} + +func (ph *partitionedHNSW[T]) NumThreads() int { + return NUM_PASSES +} + +func (ph *partitionedHNSW[T]) NumSeedVectors() int { + return ph.partition.NumSeedVectors() +} + +func (ph *partitionedHNSW[T]) StartBuild(caches []index.CacheType) { + ph.caches = caches + if ph.buildPass <= ph.partition.NumPasses() { + ph.partition.StartBuildPass() + return + } + + for i := range ph.clusterMap { + if i%NUM_PASSES != (ph.buildPass - ph.partition.NumPasses()) { + continue + } + ph.clusterMap[i].StartBuild([]index.CacheType{ph.caches[i]}) + } +} + +func (ph *partitionedHNSW[T]) EndBuild() []int { + res := []int{} + + if ph.buildPass > ph.partition.NumPasses() { + for i := range ph.clusterMap { + if i%NUM_PASSES != (ph.buildPass - ph.partition.NumPasses()) { + continue + } + ph.clusterMap[i].EndBuild() + res = append(res, i) + } + } + + ph.buildPass += 1 + + if len(res) > 0 { + return res + } + + if ph.buildPass <= ph.partition.NumPasses() { + ph.partition.EndBuildPass() + } + return []int{} +} + +func (ph *partitionedHNSW[T]) Insert(ctx context.Context, txn index.CacheType, uid uint64, vec []T) ([]*index.KeyValue, error) { + index, err := ph.partition.FindIndexForInsert(vec) + if err != nil { + return nil, err + } + return ph.clusterMap[index].Insert(ctx, txn, uid, vec) +} + +func (ph *partitionedHNSW[T]) Search(ctx context.Context, txn index.CacheType, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + indexes, err := ph.partition.FindIndexForSearch(query) + if err != nil { + return nil, err + } + res := []uint64{} + mutex := &sync.Mutex{} + var wg sync.WaitGroup + for _, index := range indexes { + wg.Add(1) + go func(i int) { + defer wg.Done() + ids, err := ph.clusterMap[i].Search(ctx, txn, query, maxResults, filter) + if err != nil { + return + } + mutex.Lock() + res = append(res, ids...) + mutex.Unlock() + }(index) + } + wg.Wait() + return ph.clusterMap[0].MergeResults(ctx, txn, res, query, maxResults, filter) +} + +func (ph *partitionedHNSW[T]) SearchWithPath(ctx context.Context, txn index.CacheType, query []T, maxResults int, filter index.SearchFilter[T]) (*index.SearchPathResult, error) { + indexes, err := ph.partition.FindIndexForSearch(query) + if err != nil { + return nil, err + } + return ph.clusterMap[indexes[0]].SearchWithPath(ctx, txn, query, maxResults, filter) +} + +func (ph *partitionedHNSW[T]) SearchWithUid(ctx context.Context, txn index.CacheType, uid uint64, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + // #TODO + return ph.clusterMap[0].SearchWithUid(ctx, txn, uid, maxResults, filter) +} + +func (ph *partitionedHNSW[T]) MergeResults(ctx context.Context, txn index.CacheType, list []uint64, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + return ph.clusterMap[0].MergeResults(ctx, txn, list, query, maxResults, filter) +} diff --git a/worker/task.go b/worker/task.go index 0ecd371d842..1b97bc98702 100644 --- a/worker/task.go +++ b/worker/task.go @@ -12,7 +12,6 @@ import ( "sort" "strconv" "strings" - "sync" "time" "github.com/golang/glog" @@ -362,34 +361,17 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er args.q.ReadTs, ) - var nnUids []uint64 - var wg sync.WaitGroup - wg.Add(1000) - var mutex sync.Mutex - for i := range 1000 { - go func(idx int) { - nnuids := make([]uint64, 0) - indexer, _ := cspec.CreateIndex(args.q.Attr, i) - if srcFn.vectorInfo != nil { - nnuids, _ = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) - } else { - nnuids, _ = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) - } - mutex.Lock() - nnUids = append(nnUids, nnuids...) - mutex.Unlock() - wg.Done() - }(i) - } - wg.Wait() - indexer, _ := cspec.CreateIndex(args.q.Attr, 0) - nnUids, err = indexer.MergeResults(ctx, qc, nnUids, srcFn.vectorInfo, + indexer, err := cspec.CreateIndex(args.q.Attr) + if err != nil { + return err + } + + nnUids, err := indexer.Search(ctx, qc, srcFn.vectorInfo, int(numNeighbors), index.AcceptAll[float32]) if err != nil { return err } + sort.Slice(nnUids, func(i, j int) bool { return nnUids[i] < nnUids[j] }) args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: nnUids}) return nil From bbaa7fbed2b7f241aafd4e0b405daa1d5a93e05a Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Wed, 25 Jun 2025 22:44:58 +0530 Subject: [PATCH 04/20] added changes --- tok/tok.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tok/tok.go b/tok/tok.go index c74c7a9d10b..e20a647c43b 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -21,6 +21,7 @@ import ( "github.com/hypermodeinc/dgraph/v25/protos/pb" "github.com/hypermodeinc/dgraph/v25/tok/hnsw" opts "github.com/hypermodeinc/dgraph/v25/tok/options" + "github.com/hypermodeinc/dgraph/v25/tok/partitioned_hnsw" "github.com/hypermodeinc/dgraph/v25/types" "github.com/hypermodeinc/dgraph/v25/x" ) @@ -85,6 +86,7 @@ var indexFactories = make(map[string]IndexFactory) func init() { registerTokenizer(BigFloatTokenizer{}) registerIndexFactory(createIndexFactory(hnsw.CreateFactory[float32](32))) + registerIndexFactory(createIndexFactory(partitioned_hnsw.CreateFactory[float32](32))) registerTokenizer(GeoTokenizer{}) registerTokenizer(IntTokenizer{}) registerTokenizer(FloatTokenizer{}) From c8951dc9402bfd897a75cfab17ad72be573a3038 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Wed, 25 Jun 2025 22:54:58 +0530 Subject: [PATCH 05/20] added changes --- schema/parse.go | 2 +- tok/tok.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/schema/parse.go b/schema/parse.go index 776596d292b..1fd4d64128d 100644 --- a/schema/parse.go +++ b/schema/parse.go @@ -306,7 +306,7 @@ func parseTokenOrVectorIndexSpec( tokenizer, has := tok.GetTokenizer(tokenOrFactoryName) if !has { return tokenOrFactoryName, nil, false, - next.Errorf("Invalid tokenizer %s", next.Val) + next.Errorf("Invalid tokenizer 1 %s", next.Val) } tokenizerType, ok := types.TypeForName(tokenizer.Type()) x.AssertTrue(ok) // Type is validated during tokenizer loading. diff --git a/tok/tok.go b/tok/tok.go index e20a647c43b..b7f6d3976c0 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -7,6 +7,7 @@ package tok import ( "encoding/binary" + "fmt" "math/big" "plugin" "strings" @@ -156,6 +157,7 @@ func GetTokenizer(name string) (Tokenizer, bool) { // GetIndexFactory returns IndexFactory given name. func GetIndexFactory(name string) (IndexFactory, bool) { + fmt.Println("HERE GET INDEX FACTORY", indexFactories) f, found := indexFactories[name] return f, found } From a87231a20709f33a9ee946f7387025f20884be7b Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Wed, 25 Jun 2025 22:55:56 +0530 Subject: [PATCH 06/20] added changes --- tok/tok.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tok/tok.go b/tok/tok.go index b7f6d3976c0..3931a27972d 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -157,7 +157,7 @@ func GetTokenizer(name string) (Tokenizer, bool) { // GetIndexFactory returns IndexFactory given name. func GetIndexFactory(name string) (IndexFactory, bool) { - fmt.Println("HERE GET INDEX FACTORY", indexFactories) + fmt.Println("HERE GET INDEX FACTORY", indexFactories, name) f, found := indexFactories[name] return f, found } From 874638aca22d4afe236cc533fef14483847f020d Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Wed, 25 Jun 2025 22:57:03 +0530 Subject: [PATCH 07/20] added changes --- tok/partitioned_hnsw/partitioned_factory.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tok/partitioned_hnsw/partitioned_factory.go b/tok/partitioned_hnsw/partitioned_factory.go index 28dc608a362..3d803a3fa6c 100644 --- a/tok/partitioned_hnsw/partitioned_factory.go +++ b/tok/partitioned_hnsw/partitioned_factory.go @@ -19,7 +19,7 @@ import ( const ( NumClustersOpt string = "numClusters" PartitionStratOpt string = "partitionStratOpt" - PartitionedHNSW string = "partionedHNSW" + PartitionedHNSW string = "partionedhnsw" ) type partitionedHNSWIndexFactory[T c.Float] struct { From 09f4093f049385c35e5b0b1912cb972235a3436f Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Wed, 25 Jun 2025 23:00:13 +0530 Subject: [PATCH 08/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index 5dcc6dd564f..49a6d296c50 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -32,13 +32,8 @@ type partitionedHNSW[T c.Float] struct { } func (ph *partitionedHNSW[T]) applyOptions(o opt.Options) error { - if o.Specifies(NumClustersOpt) { - ph.numClusters, _, _ = opt.GetOpt(o, NumClustersOpt, 1000) - } - - if o.Specifies(PartitionStratOpt) { - ph.partitionStrat, _, _ = opt.GetOpt(o, PartitionStratOpt, "kmeans") - } + ph.numClusters, _, _ = opt.GetOpt(o, NumClustersOpt, 1000) + ph.partitionStrat, _, _ = opt.GetOpt(o, PartitionStratOpt, "kmeans") if ph.partitionStrat != "kmeans" && ph.partitionStrat != "query" { return errors.New("partition strategy must be kmeans or query") From 7b14cac9514a345b0e03315b8abf70a1511c57f1 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 06:45:41 +0530 Subject: [PATCH 09/20] added changes --- tok/hnsw/persistent_hnsw.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index f1b0e4f3d84..729a8ee1b03 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -141,7 +141,20 @@ func (ph *persistentHNSW[T]) NumThreads() int { } func (ph *persistentHNSW[T]) BuildInsert(ctx context.Context, uid uint64, vec []T) error { - _, err := ph.Insert(ctx, ph.cache, uid, vec) + newPh := &persistentHNSW[T]{ + maxLevels: ph.maxLevels, + efConstruction: ph.efConstruction, + efSearch: ph.efSearch, + pred: ph.pred, + vecEntryKey: ph.vecEntryKey, + vecKey: ph.vecKey, + vecDead: ph.vecDead, + simType: ph.simType, + floatBits: ph.floatBits, + nodeAllEdges: make(map[uint64][][]uint64), + cache: ph.cache, + } + _, err := newPh.Insert(ctx, ph.cache, uid, vec) return err } From 1cf52d89024c8b4fc6ac95673c26029e2a86a5e1 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:03:01 +0530 Subject: [PATCH 10/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index 49a6d296c50..a42a44e7ff4 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -21,7 +21,6 @@ type partitionedHNSW[T c.Float] struct { clusterMap map[int]index.VectorIndex[T] numClusters int - factory index.IndexFactory[T] partition index.VectorPartitionStrat[T] hnswOptions opt.Options @@ -45,9 +44,9 @@ func (ph *partitionedHNSW[T]) applyOptions(o opt.Options) error { ph.buildPass = 0 ph.hnswOptions = o - ph.factory = hnsw.CreateFactory[T](ph.floatBits) for i := range ph.numClusters { - vi, err := ph.factory.Create(ph.pred, ph.hnswOptions, ph.floatBits) + factory := hnsw.CreateFactory[T](ph.floatBits) + vi, err := factory.Create(ph.pred, ph.hnswOptions, ph.floatBits) if err != nil { return err } From b66cfc6d2d3838d9f98646d8615fa5e3be2c6cee Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:05:17 +0530 Subject: [PATCH 11/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index a42a44e7ff4..d94e3d976da 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -89,7 +89,7 @@ func (ph *partitionedHNSW[T]) NumIndexPasses() int { } func (ph *partitionedHNSW[T]) NumThreads() int { - return NUM_PASSES + return ph.numClusters } func (ph *partitionedHNSW[T]) NumSeedVectors() int { From 8a2dad2544fe02648797612f18df44e84307a35c Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:17:16 +0530 Subject: [PATCH 12/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index d94e3d976da..cada5ff5b7b 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -6,6 +6,7 @@ package partitioned_hnsw import ( "context" "errors" + "fmt" "sync" c "github.com/hypermodeinc/dgraph/v25/tok/constraints" @@ -158,6 +159,7 @@ func (ph *partitionedHNSW[T]) Search(ctx context.Context, txn index.CacheType, q defer wg.Done() ids, err := ph.clusterMap[i].Search(ctx, txn, query, maxResults, filter) if err != nil { + fmt.Println("Error", err) return } mutex.Lock() From 9c5d1756b781162376ee531c59f53878cf5cf240 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:19:02 +0530 Subject: [PATCH 13/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index cada5ff5b7b..bcb214c76ce 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -163,11 +163,13 @@ func (ph *partitionedHNSW[T]) Search(ctx context.Context, txn index.CacheType, q return } mutex.Lock() + fmt.Println("Addign result:", ids) res = append(res, ids...) mutex.Unlock() }(index) } wg.Wait() + fmt.Println("Result:", res) return ph.clusterMap[0].MergeResults(ctx, txn, res, query, maxResults, filter) } From 7b15cac1ee0f92befb16301fd67d799aafd2f637 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:21:04 +0530 Subject: [PATCH 14/20] added changes --- tok/kmeans/kmeans.go | 2 +- tok/partitioned_hnsw/partitioned_hnsw.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tok/kmeans/kmeans.go b/tok/kmeans/kmeans.go index 00768b1f006..6f6693d9300 100644 --- a/tok/kmeans/kmeans.go +++ b/tok/kmeans/kmeans.go @@ -33,7 +33,7 @@ func (km *Kmeans[T]) AddVector(vec []T) error { } func (km *Kmeans[T]) FindIndexForSearch(vec []T) ([]int, error) { - res := make([]int, len(km.centroids.centroids)) + res := make([]int, km.NumSeedVectors()) for i := range res { res[i] = i } diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index bcb214c76ce..31e2ed7d149 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -169,7 +169,7 @@ func (ph *partitionedHNSW[T]) Search(ctx context.Context, txn index.CacheType, q }(index) } wg.Wait() - fmt.Println("Result:", res) + fmt.Println("Result:", res, indexes) return ph.clusterMap[0].MergeResults(ctx, txn, res, query, maxResults, filter) } From 9afdb8c4e0a5129f2d299e7aaa133aeb14fc3837 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:23:07 +0530 Subject: [PATCH 15/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index 31e2ed7d149..d94e3d976da 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -6,7 +6,6 @@ package partitioned_hnsw import ( "context" "errors" - "fmt" "sync" c "github.com/hypermodeinc/dgraph/v25/tok/constraints" @@ -159,17 +158,14 @@ func (ph *partitionedHNSW[T]) Search(ctx context.Context, txn index.CacheType, q defer wg.Done() ids, err := ph.clusterMap[i].Search(ctx, txn, query, maxResults, filter) if err != nil { - fmt.Println("Error", err) return } mutex.Lock() - fmt.Println("Addign result:", ids) res = append(res, ids...) mutex.Unlock() }(index) } wg.Wait() - fmt.Println("Result:", res, indexes) return ph.clusterMap[0].MergeResults(ctx, txn, res, query, maxResults, filter) } From b2700d8bbf5755a2c99ccc5f53c3163a6e651488 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:40:54 +0530 Subject: [PATCH 16/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 11 ++++++++--- tok/tok.go | 2 -- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index d94e3d976da..cc7f22f20fc 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -26,8 +26,9 @@ type partitionedHNSW[T c.Float] struct { hnswOptions opt.Options partitionStrat string - caches []index.CacheType - buildPass int + caches []index.CacheType + buildPass int + buildSyncMaps map[int]*sync.Mutex } func (ph *partitionedHNSW[T]) applyOptions(o opt.Options) error { @@ -75,7 +76,10 @@ func (ph *partitionedHNSW[T]) BuildInsert(ctx context.Context, uuid uint64, vec if index%NUM_PASSES != passIdx { return nil } - return ph.clusterMap[index].BuildInsert(ctx, uuid, vec) + ph.buildSyncMaps[index].Lock() + defer ph.buildSyncMaps[index].Unlock() + _, err = ph.clusterMap[index].Insert(ctx, ph.caches[index], uuid, vec) + return err } const NUM_PASSES = 10 @@ -107,6 +111,7 @@ func (ph *partitionedHNSW[T]) StartBuild(caches []index.CacheType) { if i%NUM_PASSES != (ph.buildPass - ph.partition.NumPasses()) { continue } + ph.buildSyncMaps[i] = &sync.Mutex{} ph.clusterMap[i].StartBuild([]index.CacheType{ph.caches[i]}) } } diff --git a/tok/tok.go b/tok/tok.go index 3931a27972d..e20a647c43b 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -7,7 +7,6 @@ package tok import ( "encoding/binary" - "fmt" "math/big" "plugin" "strings" @@ -157,7 +156,6 @@ func GetTokenizer(name string) (Tokenizer, bool) { // GetIndexFactory returns IndexFactory given name. func GetIndexFactory(name string) (IndexFactory, bool) { - fmt.Println("HERE GET INDEX FACTORY", indexFactories, name) f, found := indexFactories[name] return f, found } From 347d86fe06199b1d3901c3d0cc1b9f4c5cb6d575 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:42:23 +0530 Subject: [PATCH 17/20] added changes --- tok/partitioned_hnsw/partitioned_factory.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tok/partitioned_hnsw/partitioned_factory.go b/tok/partitioned_hnsw/partitioned_factory.go index 3d803a3fa6c..8c925ca003a 100644 --- a/tok/partitioned_hnsw/partitioned_factory.go +++ b/tok/partitioned_hnsw/partitioned_factory.go @@ -96,9 +96,10 @@ func (hf *partitionedHNSWIndexFactory[T]) createWithLock( return nil, err } retVal := &partitionedHNSW[T]{ - pred: name, - floatBits: floatBits, - clusterMap: map[int]index.VectorIndex[T]{}, + pred: name, + floatBits: floatBits, + clusterMap: map[int]index.VectorIndex[T]{}, + buildSyncMaps: map[int]*sync.Mutex{}, } err := retVal.applyOptions(o) if err != nil { From 9c24bd1cb1fb4bb41d42207b6210ce79818f113e Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:44:48 +0530 Subject: [PATCH 18/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index cc7f22f20fc..1897679d2f7 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -108,10 +108,10 @@ func (ph *partitionedHNSW[T]) StartBuild(caches []index.CacheType) { } for i := range ph.clusterMap { + ph.buildSyncMaps[i] = &sync.Mutex{} if i%NUM_PASSES != (ph.buildPass - ph.partition.NumPasses()) { continue } - ph.buildSyncMaps[i] = &sync.Mutex{} ph.clusterMap[i].StartBuild([]index.CacheType{ph.caches[i]}) } } From 91b443b0d0646fa6a337925373909f6d48f114cc Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Thu, 26 Jun 2025 07:59:03 +0530 Subject: [PATCH 19/20] added changes --- tok/partitioned_hnsw/partitioned_hnsw.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go index 1897679d2f7..b36258beaff 100644 --- a/tok/partitioned_hnsw/partitioned_hnsw.go +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -102,7 +102,7 @@ func (ph *partitionedHNSW[T]) NumSeedVectors() int { func (ph *partitionedHNSW[T]) StartBuild(caches []index.CacheType) { ph.caches = caches - if ph.buildPass <= ph.partition.NumPasses() { + if ph.buildPass < ph.partition.NumPasses() { ph.partition.StartBuildPass() return } @@ -119,7 +119,7 @@ func (ph *partitionedHNSW[T]) StartBuild(caches []index.CacheType) { func (ph *partitionedHNSW[T]) EndBuild() []int { res := []int{} - if ph.buildPass > ph.partition.NumPasses() { + if ph.buildPass >= ph.partition.NumPasses() { for i := range ph.clusterMap { if i%NUM_PASSES != (ph.buildPass - ph.partition.NumPasses()) { continue @@ -135,7 +135,7 @@ func (ph *partitionedHNSW[T]) EndBuild() []int { return res } - if ph.buildPass <= ph.partition.NumPasses() { + if ph.buildPass < ph.partition.NumPasses() { ph.partition.EndBuildPass() } return []int{} From 86b2df294b71ed0d16b267f6ceaed4cbd90a81e2 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Sat, 5 Jul 2025 01:58:30 +0530 Subject: [PATCH 20/20] added filter for vector lengths --- posting/index.go | 46 ++++++++++++++++++++++++++++++++++++++++++++ tok/kmeans/kmeans.go | 9 +++++++++ 2 files changed, 55 insertions(+) diff --git a/posting/index.go b/posting/index.go index 667915ae000..c5c9f30f66f 100644 --- a/posting/index.go +++ b/posting/index.go @@ -1373,6 +1373,40 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp return err } + numVectorsToCheck := 100 + lenFreq := make(map[int]int, numVectorsToCheck) + maxFreq := 0 + dimension := 0 + MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ + Prefix: pk.DataPrefix(), + ReadTs: rb.StartTs, + AllVersions: false, + Reverse: false, + CheckInclusion: func(uid uint64) error { + return nil + }, + Function: func(l *List, pk x.ParsedKey) error { + val, err := l.Value(rb.StartTs) + if err != nil { + return err + } + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + lenFreq[len(inVec)] += 1 + if lenFreq[len(inVec)] > maxFreq { + maxFreq = lenFreq[len(inVec)] + dimension = len(inVec) + } + numVectorsToCheck -= 1 + if numVectorsToCheck <= 0 { + return ErrStopIteration + } + return nil + }, + StartKey: x.DataKey(rb.Attr, 0), + }) + + fmt.Println("Selecting vector dimension to be:", dimension) + if indexer.NumSeedVectors() > 0 { count := 0 MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ @@ -1389,6 +1423,9 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp return err } inVec := types.BytesAsFloatArray(val.Value.([]byte)) + if len(inVec) != dimension { + return nil + } count += 1 indexer.AddSeedVector(inVec) if count == indexer.NumSeedVectors() { @@ -1423,6 +1460,9 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp } inVec := types.BytesAsFloatArray(val.Value.([]byte)) + if len(inVec) != dimension { + return []*pb.DirectedEdge{}, nil + } indexer.BuildInsert(ctx, uid, inVec) return edges, nil } @@ -1449,6 +1489,12 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp } inVec := types.BytesAsFloatArray(val.Value.([]byte)) + if len(inVec) != dimension { + if pass_idx == 0 { + glog.Warningf("Skipping vector with invalid dimension uid: %d, dimension: %d", uid, len(inVec)) + } + return []*pb.DirectedEdge{}, nil + } indexer.BuildInsert(ctx, uid, inVec) return edges, nil } diff --git a/tok/kmeans/kmeans.go b/tok/kmeans/kmeans.go index 6f6693d9300..5e4fa01deae 100644 --- a/tok/kmeans/kmeans.go +++ b/tok/kmeans/kmeans.go @@ -7,6 +7,7 @@ import ( c "github.com/hypermodeinc/dgraph/v25/tok/constraints" "github.com/hypermodeinc/dgraph/v25/tok/index" + "github.com/hypermodeinc/dgraph/v25/x" ) type Kmeans[T c.Float] struct { @@ -106,8 +107,13 @@ func (vc *vectorCentroids[T]) addVector(vec []T) error { } func (vc *vectorCentroids[T]) updateCentroids() { + x.AssertTrue(len(vc.centroids) == vc.numCenters) + x.AssertTrue(len(vc.counts) == vc.numCenters) + x.AssertTrue(len(vc.weights) == vc.numCenters) for i := 0; i < vc.numCenters; i++ { for j := 0; j < vc.dimension; j++ { + x.AssertTrue(len(vc.centroids[i]) == vc.dimension) + x.AssertTrue(len(vc.weights[i]) == vc.dimension) vc.centroids[i][j] = vc.weights[i][j] / T(vc.counts[i]) vc.weights[i][j] = 0 } @@ -119,6 +125,9 @@ func (vc *vectorCentroids[T]) updateCentroids() { func (vc *vectorCentroids[T]) randomInit() { vc.dimension = len(vc.centroids[0]) + for i := range vc.centroids { + x.AssertTrue(len(vc.centroids[i]) == vc.dimension) + } vc.numCenters = len(vc.centroids) vc.counts = make([]int64, vc.numCenters) vc.weights = make([][]T, vc.numCenters)