diff --git a/src/python_bindings/bindings.cpp b/src/python_bindings/bindings.cpp index 7d5be1e8e..d951510e7 100644 --- a/src/python_bindings/bindings.cpp +++ b/src/python_bindings/bindings.cpp @@ -234,7 +234,9 @@ class PyVecSimIndex { class PyHNSWLibIndex : public PyVecSimIndex { private: - std::shared_mutex indexGuard; // to protect parallel operations on the index. + std::shared_ptr + indexGuard; // to protect parallel operations on the index. Make sure to release the GIL + // while locking the mutex. template // size_t/double for KNN/range queries. using QueryFunc = std::function; @@ -256,9 +258,10 @@ class PyHNSWLibIndex : public PyVecSimIndex { if (ind >= n_queries) { break; } - indexGuard.lock_shared(); - results[ind] = queryFunc((const char *)items.data(ind), param, query_params); - indexGuard.unlock_shared(); + { + std::shared_lock lock(*indexGuard); + results[ind] = queryFunc((const char *)items.data(ind), param, query_params); + } } }; std::thread thread_objs[n_threads]; @@ -279,12 +282,14 @@ class PyHNSWLibIndex : public PyVecSimIndex { VecSimParams params = {.algo = VecSimAlgo_HNSWLIB, .algoParams = {.hnswParams = HNSWParams{hnsw_params}}}; this->index = std::shared_ptr(VecSimIndex_New(¶ms), VecSimIndex_Free); + this->indexGuard = std::make_shared(); } // @params is required only in V1. explicit PyHNSWLibIndex(const std::string &location) { this->index = std::shared_ptr(HNSWFactory::NewIndex(location), VecSimIndex_Free); + this->indexGuard = std::make_shared(); } void setDefaultEf(size_t ef) { @@ -401,15 +406,16 @@ class PyHNSWLibIndex : public PyVecSimIndex { break; } if (ind % block_size != 0) { - indexGuard.lock_shared(); + // Read lock for normal operations + indexGuard->lock_shared(); exclusive = false; } else { - // Lock exclusively if we are performing resizing due to a new block. - indexGuard.lock(); + // Exclusive lock for block resizing operations + indexGuard->lock(); } barrier.unlock(); this->addVectorInternal((const char *)data.data(ind), labels.at(ind)); - exclusive ? indexGuard.unlock() : indexGuard.unlock_shared(); + exclusive ? indexGuard->unlock() : indexGuard->unlock_shared(); } }; std::thread thread_objs[n_threads]; @@ -457,12 +463,15 @@ class PyHNSWLibIndex : public PyVecSimIndex { } PyBatchIterator createBatchIterator(const py::object &input, VecSimQueryParams *query_params) override { + py::array query(input); - auto del = [&](VecSimBatchIterator *pyBatchIter) { + py::gil_scoped_release py_gil; + // Passing indexGuardPtr by value, so that the refCount of the mutex + auto del = [indexGuardPtr = this->indexGuard](VecSimBatchIterator *pyBatchIter) { VecSimBatchIterator_Free(pyBatchIter); - this->indexGuard.unlock_shared(); + indexGuardPtr->unlock_shared(); }; - indexGuard.lock_shared(); + indexGuard->lock_shared(); auto py_batch_ptr = std::shared_ptr( VecSimBatchIterator_New(index.get(), (const char *)query.data(0), query_params), del); return PyBatchIterator(index, py_batch_ptr);