diff --git a/.vscode/settings.json b/.vscode/settings.json index 604bfa2c..282c11dc 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -122,7 +122,8 @@ "xtree": "cpp", "xutility": "cpp", "execution": "cpp", - "text_encoding": "cpp" + "text_encoding": "cpp", + "__functional_03": "cpp" }, "cSpell.words": [ "allclose", @@ -151,6 +152,7 @@ "FAISS", "fbin", "furo", + "geospatial", "googleanalytics", "groundtruth", "hashable", @@ -170,6 +172,7 @@ "longlong", "memmap", "MSVC", + "Multimodal", "Napi", "ndarray", "NDCG", @@ -208,7 +211,10 @@ "usecases", "Vardanian", "vectorize", - "Xunit" + "Vincenty", + "Wasmer", + "Xunit", + "Yuga" ], "autoDocstring.docstringFormat": "sphinx", "java.configuration.updateBuildConfiguration": "interactive", @@ -225,5 +231,11 @@ "editor.formatOnSave": true, "editor.defaultFormatter": "golang.go" }, + "editor.tabSize": 4, + "editor.insertSpaces": true, + "prettier.singleQuote": true, + "prettier.tabWidth": 4, + "prettier.useTabs": false, "dotnet.defaultSolution": "csharp/Cloud.Unum.USearch.sln" } \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5870e29f..6caf9570 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -226,10 +226,10 @@ nvm install 20 Testing: ```sh -npm install -g typescript -npm install -npm run build-js -npm test +npm install -g typescript # Install TypeScript globally +npm install # Compile `javascript/lib.cpp` +npm run build-js # Generate JS from TS +npm test # Run the test suite ``` To compile for AWS Lambda you'd need to recompile the binding. diff --git a/README.md b/README.md index c7d9c2cc..b33b1b6a 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ This can result in __20x cost reduction__ on AWS and other public clouds. index.save("index.usearch") loaded_copy = index.load("index.usearch") -view = Index.restore("index.usearch", view=True) +view = Index.restore("index.usearch", view=True, ...) other_view = Index(ndim=..., metric=...) 
other_view.view("index.usearch") @@ -528,7 +528,11 @@ index = Index(ndim=ndim, metric=CompiledMetric( - [x] ClickHouse: [C++](https://github.com/ClickHouse/ClickHouse/pull/53447), [docs](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/annindexes#usearch). - [x] DuckDB: [post](https://duckdb.org/2024/05/03/vector-similarity-search-vss.html). +- [x] ScyllaDB: [Rust](https://github.com/scylladb/vector-store), [presentation](https://www.slideshare.net/slideshow/vector-search-with-scylladb-by-szymon-wasik/276571548). +- [x] TiDB & TiFlash: [C++](https://github.com/pingcap/tiflash), [announcement](https://www.pingcap.com/article/introduce-vector-search-indexes-in-tidb/). +- [x] YugaByte: [C++](https://github.com/yugabyte/yugabyte-db/blob/366b9f5e3c4df3a1a17d553db41d6dc50146f488/src/yb/vector_index/usearch_wrapper.cc). - [x] Google: [UniSim](https://github.com/google/unisim), [RetSim](https://arxiv.org/abs/2311.17264) paper. +- [x] MemGraph: [C++](https://github.com/memgraph/memgraph/blob/784dd8520f65050d033aea8b29446e84e487d091/src/storage/v2/indices/vector_index.cpp), [announcement](https://memgraph.com/blog/simplify-data-retrieval-memgraph-vector-search). - [x] LanternDB: [C++](https://github.com/lanterndata/lantern), [Rust](https://github.com/lanterndata/lantern_extras), [docs](https://lantern.dev/blog/hnsw-index-creation). - [x] LangChain: [Python](https://github.com/langchain-ai/langchain/releases/tag/v0.0.257) and [JavaScript](https://github.com/hwchase17/langchainjs/releases/tag/0.0.125). - [x] Microsoft Semantic Kernel: [Python](https://github.com/microsoft/semantic-kernel/releases/tag/python-0.3.9.dev) and C#. 
diff --git a/cpp/test.cpp b/cpp/test.cpp index bcafbc8c..f11abb4e 100644 --- a/cpp/test.cpp +++ b/cpp/test.cpp @@ -877,7 +877,7 @@ void test_absurd(std::size_t dimensions, std::size_t connectivity, std::size_t e template void test_exact_search(std::size_t dataset_count, std::size_t queries_count, std::size_t wanted_count) { std::size_t dimensions = 32; - metric_punned_t metric(dimensions, metric_kind_t::cos_k); + metric_punned_t metric(dimensions, metric_kind_t::cos_k, scalar_kind()); std::random_device rd; std::mt19937 gen(rd()); @@ -886,9 +886,9 @@ void test_exact_search(std::size_t dataset_count, std::size_t queries_count, std std::generate(dataset.begin(), dataset.end(), [&] { return static_cast(dis(gen)); }); exact_search_t search; - auto results = search( // - (byte_t const*)dataset.data(), dataset_count, dimensions * sizeof(float), // - (byte_t const*)dataset.data(), queries_count, dimensions * sizeof(float), // + auto results = search( // + (byte_t const*)dataset.data(), dataset_count, dimensions * sizeof(scalar_at), // + (byte_t const*)dataset.data(), queries_count, dimensions * sizeof(scalar_at), // wanted_count, metric); for (std::size_t i = 0; i < results.size(); ++i) @@ -1098,6 +1098,51 @@ template void test_replacing_update() { expect_eq(final_search[2].member.key, 44); } +/** + * Tests the filtered search functionality of the index. 
 + */ +void test_filtered_search() { + constexpr std::size_t dataset_count = 2048; + constexpr std::size_t dimensions = 32; + metric_punned_t metric(dimensions, metric_kind_t::cos_k); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + using vector_of_vectors_t = std::vector<std::vector<float>>; + + vector_of_vectors_t vector_of_vectors(dataset_count); + for (auto& vector : vector_of_vectors) { + vector.resize(dimensions); + std::generate(vector.begin(), vector.end(), [&] { return dis(gen); }); + } + + index_dense_t index = index_dense_t::make(metric); + index.reserve(dataset_count); + for (std::size_t idx = 0; idx < dataset_count; ++idx) + index.add(idx, vector_of_vectors[idx].data()); + expect_eq(index.size(), dataset_count); + + { + auto predicate = [](index_dense_t::key_t key) { return key != 0; }; + auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate); + expect_eq(10, results.size()); // ! Should not contain 0 + for (std::size_t i = 0; i != results.size(); ++i) + expect(0 != results[i].member.key); + } + { + auto predicate = [](index_dense_t::key_t) { return false; }; + auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate); + expect_eq(0, results.size()); // ! Should be empty: the predicate rejects every key + } + { + auto predicate = [](index_dense_t::key_t key) { return key == 10; }; + auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate); + expect_eq(1, results.size()); // !
Should not contain 0 + expect_eq(10, results[0].member.key); + } +} + int main(int, char**) { test_uint40(); test_cosine(10, 10); @@ -1174,5 +1219,6 @@ int main(int, char**) { test_sets(set_size, 20, 30); test_strings(); + test_filtered_search(); return 0; } diff --git a/include/usearch/index.hpp b/include/usearch/index.hpp index 3922ae23..acea6385 100644 --- a/include/usearch/index.hpp +++ b/include/usearch/index.hpp @@ -2183,6 +2183,7 @@ class index_gt { */ struct usearch_align_m context_t { top_candidates_t top_candidates{}; + top_candidates_t top_for_refine{}; next_candidates_t next_candidates{}; visits_hash_set_t visits{}; std::default_random_engine level_generator{}; @@ -2498,6 +2499,13 @@ class index_gt { if (nodes_) std::memcpy(new_nodes.data(), nodes_.data(), sizeof(node_t) * size()); + // Pre-reserve the capacity for `top_for_refine`, which always contains at most one more + // element than the connectivity factors. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + for (std::size_t i = 0; i != new_contexts.size(); ++i) + if (!new_contexts[i].top_for_refine.reserve(connectivity_max + 1)) + return false; + limits_ = limits; nodes_capacity_ = limits.members; nodes_ = std::move(new_nodes); @@ -3179,17 +3187,11 @@ class index_gt { std::size_t memory_usage_per_node(level_t level) const noexcept { return node_bytes_(level); } - double inverse_log_connectivity() const { - return pre_.inverse_log_connectivity; - } + double inverse_log_connectivity() const { return pre_.inverse_log_connectivity; } - std::size_t neighbors_base_bytes() const { - return pre_.neighbors_base_bytes; - } + std::size_t neighbors_base_bytes() const { return pre_.neighbors_base_bytes; } - std::size_t neighbors_bytes() const { - return pre_.neighbors_bytes; - } + std::size_t neighbors_bytes() const { return pre_.neighbors_bytes; } #if defined(USEARCH_USE_PRAGMA_REGION) #pragma endregion @@ -3790,7 +3792,7 @@ class index_gt { metric_at&& metric, 
compressed_slot_t new_slot, candidates_view_t new_neighbors, value_at&& value, level_t level, context_t& context) usearch_noexcept_m { - top_candidates_t& top = context.top_candidates; + top_candidates_t& top_for_refine = context.top_for_refine; std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; // Reverse links from the neighbors: @@ -3817,19 +3819,16 @@ class index_gt { continue; } - // To fit a new connection we need to drop an existing one. - top.clear(); - usearch_assert_m((top.capacity() >= (close_header.size() + 1)), - "The memory must have been reserved in `add`"); - top.insert_reserved({context.measure(value, citerator_at(close_slot), metric), new_slot}); + top_for_refine.clear(); + top_for_refine.insert_reserved({context.measure(value, citerator_at(close_slot), metric), new_slot}); for (compressed_slot_t successor_slot : close_header) - top.insert_reserved( + top_for_refine.insert_reserved( {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot}); // Export the results: close_header.clear(); - candidates_view_t top_view = - refine_(metric, connectivity_max, top, context, context.computed_distances_in_reverse_refines); + candidates_view_t top_view = refine_(metric, connectivity_max, top_for_refine, context, + context.computed_distances_in_reverse_refines); usearch_assert_m(top_view.size(), "This would lead to isolated nodes"); for (std::size_t idx = 0; idx != top_view.size(); idx++) close_header.push_back(top_view[idx].slot); @@ -4178,9 +4177,10 @@ class index_gt { // This can substantially grow our priority queue: next.insert({-successor_dist, successor_slot}); if (is_dummy() || - predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) + predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) { top.insert({successor_dist, successor_slot}, top_limit); - radius = top.top().distance; + radius = top.top().distance; + } } } } diff --git 
a/javascript/README.md b/javascript/README.md index cb153a33..7b692b8f 100644 --- a/javascript/README.md +++ b/javascript/README.md @@ -78,6 +78,14 @@ const batchResults = index.search(vectors, 2); const firstMatch = batchResults.get(0); ``` +Multi-threading is supported for batch operations: + +```js +const threads_count = 0; // Zero for auto-detection or pass an unsigned integer +index.add(keys, vectors, threads_count); +const batchResults = index.search(vectors, 2, threads_count); +``` + ## Index Introspection Inspect and interact with the index: diff --git a/javascript/lib.cpp b/javascript/lib.cpp index b259f0b0..0c42efb0 100644 --- a/javascript/lib.cpp +++ b/javascript/lib.cpp @@ -9,7 +9,8 @@ * @see NodeJS docs: https://nodejs.org/api/addons.html#hello-world * */ -#include // `std::bad_alloc` +#include // `std::bad_alloc` +#include // `std::thread::hardware_concurrency()` #define NAPI_CPP_EXCEPTIONS #include @@ -20,7 +21,10 @@ using namespace unum::usearch; using namespace unum; +using index_error_t = usearch::error_t; using add_result_t = typename index_dense_t::add_result_t; +using search_result_t = typename index_dense_t::search_result_t; +using state_result_t = typename index_dense_t::state_result_t; class CompiledIndex : public Napi::ObjectWrap { public: @@ -44,6 +48,7 @@ class CompiledIndex : public Napi::ObjectWrap { Napi::Value Count(Napi::CallbackInfo const& ctx); std::unique_ptr native_; + std::mutex mtx; }; Napi::Object CompiledIndex::Init(Napi::Env env, Napi::Object exports) { @@ -77,7 +82,6 @@ std::size_t napi_argument_to_size(Napi::Value v) { } CompiledIndex::CompiledIndex(Napi::CallbackInfo const& ctx) : Napi::ObjectWrap(ctx) { - // Directly assign the parameters without checks std::size_t dimensions = napi_argument_to_size(ctx[0]); metric_kind_t metric_kind = metric_from_name(ctx[1].As().Utf8Value().c_str()); @@ -95,7 +99,12 @@ CompiledIndex::CompiledIndex(Napi::CallbackInfo const& ctx) : Napi::ObjectWrap(); - std::size_t tasks = 
keys.ElementLength(); + Napi::TypedArray vectors = ctx[1].As(); + + // Optional arguments + std::size_t threads = napi_argument_to_size(ctx[2]); + if (threads == 0) + threads = std::thread::hardware_concurrency(); - // Ensure there is enough capacity + // Ensure there is enough capacity and memory + std::size_t tasks = keys.ElementLength(); if (native_->size() + tasks >= native_->capacity()) - native_->reserve(ceil2(native_->size() + tasks)); + if (!native_->try_reserve({ceil2(native_->size() + tasks), threads})) { + Napi::TypeError::New(env, "Failed to reserve memory").ThrowAsJavaScriptException(); + return; + } - // Create an instance of the executor with the default number of threads + // Run insertions concurrently auto run_parallel = [&](auto vectors) { - executor_stl_t executor; + // Errors can be set only from the main thread, so before spawning workers + // we need temporary space to keep the message + index_error_t first_error{}; + std::atomic failed{false}; + executor_default_t executor{threads}; executor.fixed(tasks, [&](std::size_t /*thread_idx*/, std::size_t task_idx) { - add_result_t result = native_->add(static_cast(keys[task_idx]), - vectors + task_idx * native_->dimensions()); - if (!result) - Napi::Error::New(ctx.Env(), result.error.release()).ThrowAsJavaScriptException(); + if (failed.load()) + return; + auto key = static_cast(keys[task_idx]); + auto vector = vectors + task_idx * native_->dimensions(); + add_result_t result = native_->add(key, vector); + if (!result) { + if (!failed.exchange(true)) { + first_error = std::move(result.error); + } else { + result.error.release(); + } + } }); + if (failed) + Napi::TypeError::New(env, first_error.release()).ThrowAsJavaScriptException(); }; - Napi::TypedArray vectors = ctx[1].As(); - if (vectors.TypedArrayType() == napi_float32_array) { - run_parallel(vectors.As().Data()); - } else if (vectors.TypedArrayType() == napi_float64_array) { - run_parallel(vectors.As().Data()); - } else if 
(vectors.TypedArrayType() == napi_int8_array) { - run_parallel(vectors.As().Data()); - } else { - Napi::TypeError::New(ctx.Env(), - "Unsupported TypedArray. Supported types are Float32Array, Float64Array, and Int8Array.") - .ThrowAsJavaScriptException(); + // Dispatch the parallel tasks based on the `TypedArray` type + try { + if (vectors.TypedArrayType() == napi_float32_array) { + run_parallel(vectors.As().Data()); + } else if (vectors.TypedArrayType() == napi_float64_array) { + run_parallel(vectors.As().Data()); + } else if (vectors.TypedArrayType() == napi_int8_array) { + run_parallel(vectors.As().Data()); + } else { + Napi::TypeError::New( + env, "Unsupported TypedArray. Supported types are Float32Array, Float64Array, and Int8Array.") + .ThrowAsJavaScriptException(); + } + } catch (...) { + Napi::TypeError::New(env, "Insertion failed").ThrowAsJavaScriptException(); } } Napi::Value CompiledIndex::Search(Napi::CallbackInfo const& ctx) { Napi::Env env = ctx.Env(); + + // Check the number of arguments + if (ctx.Length() != 3) { + Napi::TypeError::New(env, "`Search` expects 3 arguments: queries, k[, threads]").ThrowAsJavaScriptException(); + return env.Null(); + } + + // Extract mandatory arguments Napi::TypedArray queries = ctx[0].As(); - std::size_t tasks = queries.ElementLength() / native_->dimensions(); std::size_t wanted = napi_argument_to_size(ctx[1]); + std::size_t threads = napi_argument_to_size(ctx[2]); + if (threads == 0) + threads = std::thread::hardware_concurrency(); + // Run queries concurrently + std::size_t tasks = queries.ElementLength() / native_->dimensions(); auto run_parallel = [&](auto vectors) -> Napi::Value { Napi::Array result_js = Napi::Array::New(env, 3); Napi::BigUint64Array matches_js = Napi::BigUint64Array::New(env, tasks * wanted); @@ -204,29 +260,30 @@ Napi::Value CompiledIndex::Search(Napi::CallbackInfo const& ctx) { auto distances_data = distances_js.Data(); auto counts_data = counts_js.Data(); - try { - bool failed = false; - 
executor_stl_t executor; - executor.fixed(tasks, [&](std::size_t /*thread_idx*/, std::size_t task_idx) { - auto result = native_->search(vectors + task_idx * native_->dimensions(), wanted); - if (!result) { - failed = true; - Napi::TypeError::New(env, result.error.release()).ThrowAsJavaScriptException(); + // Errors can be set only from the main thread, so before spawning workers + // we need temporary space to keep the message + index_error_t first_error{}; + std::atomic failed{false}; + executor_default_t executor{threads}; + executor.fixed(tasks, [&](std::size_t /*thread_idx*/, std::size_t task_idx) { + if (failed.load()) + return; + auto vector = vectors + task_idx * native_->dimensions(); + search_result_t result = native_->search(vector, wanted); + if (!result) { + if (!failed.exchange(true)) { + first_error = std::move(result.error); } else { - counts_data[task_idx] = result.dump_to(matches_data + task_idx * native_->dimensions(), - distances_data + task_idx * native_->dimensions()); + result.error.release(); } - }); - - if (failed) - return env.Null(); - - } catch (std::bad_alloc const&) { - Napi::TypeError::New(env, "Out of memory").ThrowAsJavaScriptException(); - return env.Null(); - - } catch (...) 
{ - Napi::TypeError::New(env, "Search failed").ThrowAsJavaScriptException(); + } else { + auto matches = matches_data + task_idx * wanted; + auto distances = distances_data + task_idx * wanted; + counts_data[task_idx] = result.dump_to(matches, distances); + } + }); + if (failed) { + Napi::TypeError::New(env, first_error.release()).ThrowAsJavaScriptException(); return env.Null(); } @@ -236,16 +293,22 @@ Napi::Value CompiledIndex::Search(Napi::CallbackInfo const& ctx) { return result_js; }; - if (queries.TypedArrayType() == napi_float32_array) { - return run_parallel(queries.As().Data()); - } else if (queries.TypedArrayType() == napi_float64_array) { - return run_parallel(queries.As().Data()); - } else if (queries.TypedArrayType() == napi_int8_array) { - return run_parallel(queries.As().Data()); - } else { - Napi::TypeError::New(env, - "Unsupported TypedArray. Supported types are Float32Array, Float64Array, and Int8Array.") - .ThrowAsJavaScriptException(); + // Dispatch the parallel tasks based on the `TypedArray` type + try { + if (queries.TypedArrayType() == napi_float32_array) { + return run_parallel(queries.As().Data()); + } else if (queries.TypedArrayType() == napi_float64_array) { + return run_parallel(queries.As().Data()); + } else if (queries.TypedArrayType() == napi_int8_array) { + return run_parallel(queries.As().Data()); + } else { + Napi::TypeError::New( + env, "Unsupported TypedArray. Supported types are Float32Array, Float64Array, and Int8Array.") + .ThrowAsJavaScriptException(); + return env.Null(); + } + } catch (...) 
{ + Napi::TypeError::New(env, "Search failed").ThrowAsJavaScriptException(); return env.Null(); } } @@ -287,21 +350,36 @@ Napi::Value CompiledIndex::Count(Napi::CallbackInfo const& ctx) { Napi::Value exactSearch(Napi::CallbackInfo const& ctx) { Napi::Env env = ctx.Env(); + // Check the number of arguments + if (ctx.Length() != 6) { + Napi::TypeError::New(env, + "`exactSearch` expects 6 arguments: dataset, queries, dimensions, k, metric[, threads].") + .ThrowAsJavaScriptException(); + return env.Null(); + } + // Extracting parameters directly without additional type checks. Napi::TypedArray dataset = ctx[0].As(); Napi::ArrayBuffer datasetBuffer = dataset.ArrayBuffer(); Napi::TypedArray queries = ctx[1].As(); Napi::ArrayBuffer queriesBuffer = queries.ArrayBuffer(); - std::uint64_t dimensions = napi_argument_to_size(ctx[2]); - std::uint64_t wanted = napi_argument_to_size(ctx[3]); + std::size_t dimensions = napi_argument_to_size(ctx[2]); + std::size_t wanted = napi_argument_to_size(ctx[3]); metric_kind_t metric_kind = metric_from_name(ctx[4].As().Utf8Value().c_str()); + std::size_t threads = napi_argument_to_size(ctx[5]); + if (threads == 0) + threads = std::thread::hardware_concurrency(); + // Check the types used scalar_kind_t quantization; std::size_t bytes_per_scalar; switch (queries.TypedArrayType()) { case napi_float64_array: quantization = scalar_kind_t::f64_k, bytes_per_scalar = 8; break; + case napi_float32_array: quantization = scalar_kind_t::f32_k, bytes_per_scalar = 4; break; case napi_int8_array: quantization = scalar_kind_t::i8_k, bytes_per_scalar = 1; break; - default: quantization = scalar_kind_t::f32_k, bytes_per_scalar = 4; break; + default: + Napi::TypeError::New(env, "Unsupported TypedArray for queries.").ThrowAsJavaScriptException(); + return env.Null(); } metric_punned_t metric(dimensions, metric_kind, quantization); @@ -310,7 +388,7 @@ Napi::Value exactSearch(Napi::CallbackInfo const& ctx) { return env.Null(); } - executor_default_t executor; + 
executor_default_t executor(threads); exact_search_t search; // Performing the exact search. diff --git a/javascript/usearch.test.js b/javascript/usearch.test.js index 5210a2bd..5304bf34 100644 --- a/javascript/usearch.test.js +++ b/javascript/usearch.test.js @@ -1,4 +1,35 @@ -const test = require('node:test'); +const nodeTest = require('node:test'); +const realTest = nodeTest.test; // the original function + +function loggedTest(name, options, fn) { + // The API has two call signatures: + // test(name, fn) + // test(name, options, fn) + if (typeof options === 'function') { + fn = options; + options = undefined; + } + + // Wrap the body so we can log before / after + const wrapped = async (t) => { + console.log('▶', name); + try { + await fn(t); // run the user’s test + console.log('✓', name); + } catch (err) { + console.log('✖', name); + throw err; // re-throw so the runner records the failure + } + }; + + // Delegate back to the real test() with the same options + return options ? realTest(name, options, wrapped) : realTest(name, wrapped); +} + +// Replace both the export and the global the runner puts on each module +global.test = loggedTest; +module.exports = loggedTest; // for completeness if this file is `require`d + const assert = require('node:assert'); const fs = require('node:fs'); const os = require('node:os'); @@ -14,7 +45,6 @@ function assertAlmostEqual(actual, expected, tolerance = 1e-6) { ); } - test('Single-entry operations', async (t) => { await t.test('index info', () => { const index = new usearch.Index(2, 'l2sq'); @@ -31,12 +61,24 @@ test('Single-entry operations', async (t) => { index.add(16n, new Float32Array([10, 25])); assert.equal(index.size(), 2, 'size after adding elements should be 2'); - assert.equal(index.contains(15), true, 'entry must be present after insertion'); + assert.equal( + index.contains(15), + true, + 'entry must be present after insertion' + ); const results = index.search(new Float32Array([13, 14]), 2); - 
assert.deepEqual(results.keys, new BigUint64Array([15n, 16n]), 'keys should be 15 and 16'); - assert.deepEqual(results.distances, new Float32Array([45, 130]), 'distances should be 45 and 130'); + assert.deepEqual( + results.keys, + new BigUint64Array([15n, 16n]), + 'keys should be 15 and 16' + ); + assert.deepEqual( + results.distances, + new Float32Array([45, 130]), + 'distances should be 45 and 130' + ); }); await t.test('remove', () => { @@ -49,12 +91,24 @@ test('Single-entry operations', async (t) => { assert.equal(index.remove(15n), 1); - assert.equal(index.size(), 3, 'size after remoing elements should be 3'); - assert.equal(index.contains(15), false, 'entry must be absent after insertion'); + assert.equal( + index.size(), + 3, + 'size after removing elements should be 3' + ); + assert.equal( + index.contains(15), + false, + 'entry must be absent after insertion' + ); const results = index.search(new Float32Array([13, 14]), 2); - assert.deepEqual(results.keys, new BigUint64Array([16n, 25n]), 'keys should not include 15'); + assert.deepEqual( + results.keys, + new BigUint64Array([16n, 25n]), + 'keys should not include 15' + ); }); }); @@ -63,15 +117,30 @@ test('Batch operations', async (t) => { const indexBatch = new usearch.Index(2, 'l2sq'); const keys = [15n, 16n]; - const vectors = [new Float32Array([10, 20]), new Float32Array([10, 25])]; + const vectors = [ + new Float32Array([10, 20]), + new Float32Array([10, 25]), + ]; indexBatch.add(keys, vectors); - assert.equal(indexBatch.size(), 2, 'size after adding batch should be 2'); + assert.equal( + indexBatch.size(), + 2, + 'size after adding batch should be 2' + ); const results = indexBatch.search(new Float32Array([13, 14]), 2); - assert.deepEqual(results.keys, new BigUint64Array([15n, 16n]), 'keys should be 15 and 16'); - assert.deepEqual(results.distances, new Float32Array([45, 130]), 'distances should be 45 and 130'); + assert.deepEqual( + results.keys, + new BigUint64Array([15n, 16n]), + 'keys should be 
15 and 16' + ); + assert.deepEqual( + results.distances, + new Float32Array([45, 130]), + 'distances should be 45 and 130' + ); }); await t.test('remove', () => { @@ -82,22 +151,30 @@ test('Batch operations', async (t) => { new Float32Array([10, 20]), new Float32Array([10, 25]), new Float32Array([20, 40]), - new Float32Array([20, 45]) + new Float32Array([20, 45]), ]; indexBatch.add(keys, vectors); - assert.deepEqual(indexBatch.remove([15n, 25n]), [1, 1]) - assert.equal(indexBatch.size(), 2, 'size after removing batch should be 2'); + assert.deepEqual(indexBatch.remove([15n, 25n]), [1, 1]); + assert.equal( + indexBatch.size(), + 2, + 'size after removing batch should be 2' + ); const results = indexBatch.search(new Float32Array([13, 14]), 2); - assert.deepEqual(results.keys, new BigUint64Array([16n, 26n]), 'keys should not include 15 and 25'); + assert.deepEqual( + results.keys, + new BigUint64Array([16n, 26n]), + 'keys should not include 15 and 25' + ); }); }); -test("Expected results", () => { +test('Expected results', () => { const index = new usearch.Index({ - metric: "l2sq", + metric: 'l2sq', connectivity: 16, dimensions: 3, }); @@ -137,48 +214,50 @@ test('Operations with invalid values', () => { const keys = [NaN, 16n]; const vectors = [new Float32Array([10, 30]), new Float32Array([1, 5])]; - assert.throws( - () => indexBatch.add(keys, vectors), - { - name: 'Error', - message: 'All keys must be positive integers or bigints.' - } - ); + // All keys must be positive integers or bigints + assert.throws(() => indexBatch.add(keys, vectors)); - assert.throws( - () => indexBatch.search(NaN, 2), - { - name: 'Error', - message: 'Vectors must be a TypedArray or an array of arrays.' 
- } - ); + // Vectors must be a TypedArray or an array of arrays + assert.throws(() => indexBatch.search(NaN, 2)); }); test('Invalid operations', async (t) => { await t.test('Add the same keys', () => { const index = new usearch.Index({ - metric: "l2sq", + metric: 'l2sq', connectivity: 16, dimensions: 3, }); index.add(42n, new Float32Array([0.2, 0.6, 0.4])); - assert.throws( - () => index.add(42n, new Float32Array([0.2, 0.6, 0.4])), - { - name: 'Error', - message: 'Duplicate keys not allowed in high-level wrappers' - } - ); + assert.throws(() => index.add(42n, new Float32Array([0.2, 0.6, 0.4]))); }); -}); + await t.test('Batch add containing the same key', () => { + const index = new usearch.Index({ + metric: 'l2sq', + connectivity: 16, + dimensions: 3, + }); + index.add(42n, new Float32Array([0.2, 0.6, 0.4])); + assert.throws(() => { + index.add( + [41n, 42n, 43n], + [ + [0.1, 0.6, 0.4], + [0.2, 0.6, 0.4], + [0.3, 0.6, 0.4], + ] + ); + }); + }); +}); test('Serialization', async (t) => { - const indexPath = path.join(os.tmpdir(), 'usearch.test.index') + const indexPath = path.join(os.tmpdir(), 'usearch.test.index'); t.beforeEach(() => { const index = new usearch.Index({ - metric: "l2sq", + metric: 'l2sq', connectivity: 16, dimensions: 3, }); @@ -192,7 +271,7 @@ test('Serialization', async (t) => { await t.test('load', () => { const index = new usearch.Index({ - metric: "l2sq", + metric: 'l2sq', connectivity: 16, dimensions: 3, }); @@ -207,49 +286,55 @@ test('Serialization', async (t) => { // todo: Skip as the test fails only on windows. // The following error in afterEach(). 
// `error: "EBUSY: resource busy or locked, unlink` - await t.test('view: Read data', {skip: process.platform === 'win32'}, () => { - const index = new usearch.Index({ - metric: "l2sq", - connectivity: 16, - dimensions: 3, - }); - index.view(indexPath); - const results = index.search(new Float32Array([0.2, 0.6, 0.4]), 10); - - assert.equal(index.size(), 1); - assert.deepEqual(results.keys, new BigUint64Array([42n])); - assertAlmostEqual(results.distances[0], new Float32Array([0])); - }); + await t.test( + 'view: Read data', + { skip: process.platform === 'win32' }, + () => { + const index = new usearch.Index({ + metric: 'l2sq', + connectivity: 16, + dimensions: 3, + }); + index.view(indexPath); + const results = index.search(new Float32Array([0.2, 0.6, 0.4]), 10); + + assert.equal(index.size(), 1); + assert.deepEqual(results.keys, new BigUint64Array([42n])); + assertAlmostEqual(results.distances[0], new Float32Array([0])); + } + ); - await t.test('view: Invalid operations: add', {skip: process.platform === 'win32'}, () => { - const index = new usearch.Index({ - metric: "l2sq", - connectivity: 16, - dimensions: 3, - }); - index.view(indexPath); - assert.throws( - () => index.add(43n, new Float32Array([0.2, 0.6, 0.4])), - { - name: 'Error', - message: "Can't add to an immutable index" - } - ); - }); + await t.test( + 'view: Invalid operations: add', + { skip: process.platform === 'win32' }, + () => { + const index = new usearch.Index({ + metric: 'l2sq', + connectivity: 16, + dimensions: 3, + }); + index.view(indexPath); + + // Can't add to an immutable index + assert.throws(() => + index.add(43n, new Float32Array([0.2, 0.6, 0.4])) + ); + } + ); - await t.test('view: Invalid operations: remove', {skip: process.platform === 'win32'}, () => { - const index = new usearch.Index({ - metric: "l2sq", - connectivity: 16, - dimensions: 3, - }); - index.view(indexPath); - assert.throws( - () => index.remove(42n), - { - name: 'Error', - message: "Can't remove from an immutable 
index" - } - ); - }); + await t.test( + 'view: Invalid operations: remove', + { skip: process.platform === 'win32' }, + () => { + const index = new usearch.Index({ + metric: 'l2sq', + connectivity: 16, + dimensions: 3, + }); + index.view(indexPath); + + // Can't remove from an immutable index + assert.throws(() => index.remove(42n)); + } + ); }); diff --git a/javascript/usearch.ts b/javascript/usearch.ts index a91eb783..0cf94225 100644 --- a/javascript/usearch.ts +++ b/javascript/usearch.ts @@ -16,8 +16,8 @@ type CompiledSearchResult = [ ]; interface CompiledIndex { - add(keys: BigUint64Array, vectors: Vector): void; - search(vectors: VectorOrMatrix, k: number): CompiledSearchResult; + add(keys: BigUint64Array, vectors: Vector, threads: number): void; + search(vectors: VectorOrMatrix, k: number, threads: number): CompiledSearchResult; contains(keys: BigUint64Array): boolean[]; count(keys: BigUint64Array): number | number[]; remove(keys: BigUint64Array): number[]; @@ -37,7 +37,8 @@ interface Compiled { queries: VectorOrMatrix, dimensions: number, count: number, - metric: MetricKind + metric: MetricKind, + threads: number ): CompiledSearchResult; } @@ -326,10 +327,12 @@ export class Index { * If a single key is provided, it is associated with all provided vectors. * @param {Float32Array|Float64Array|Int8Array} vectors - Input matrix representing vectors, * matrix of size n * d, where n is the number of vectors, and d is their dimensionality. + * @param {number} [threads=0] - Optional, default is 0. Number of threads to use for indexing. + * If set to 0, the number of threads is determined automatically. * @throws Will throw an error if the length of keys doesn't match the number of vectors * or if it's not a single key. 
*/ - add(keys: bigint | bigint[] | BigUint64Array, vectors: Vector) { + add(keys: bigint | bigint[] | BigUint64Array, vectors: Vector, threads: number = 0): void { let normalizedKeys = normalizeKeys(keys); let normalizedVectors = normalizeVectors( vectors, @@ -351,8 +354,14 @@ export class Index { ); } + if ((!Number.isNaN(threads) && typeof threads !== "number") || threads < 0) { + throw new Error( + "`threads` must be a non-negative integer representing the number of threads to use for searching." + ); + } + // Call the compiled method - this.#compiledIndex.add(normalizedKeys, normalizedVectors); + this.#compiledIndex.add(normalizedKeys, normalizedVectors, threads); } /** @@ -372,16 +381,22 @@ export class Index { * * @param {Float32Array|Float64Array|Int8Array|Array>} vectors - Input matrix representing query vectors, can be a TypedArray or an array of TypedArray. * @param {number} k - The number of nearest neighbors to search for each query vector. + * @param {number} [threads=0] - Optional, default is 0. Number of threads to use for searching. If set to 0, the number of threads is determined automatically. * @return {Matches|BatchMatches} - Search results for one or more queries, containing keys, distances, and counts of the matches found. * @throws Will throw an error if `k` is not a positive integer or if the size of the vectors is not a multiple of dimensions. * @throws Will throw an error if `vectors` is not a valid input type (TypedArray or an array of TypedArray) or if its flattened size is not a multiple of dimensions. */ - search(vectors: VectorOrMatrix, k: number): Matches | BatchMatches { + search(vectors: VectorOrMatrix, k: number, threads: number = 0): Matches | BatchMatches { if ((!Number.isNaN(k) && typeof k !== "number") || k <= 0) { throw new Error( "`k` must be a positive integer representing the number of nearest neighbors to search for." 
); } + if ((!Number.isNaN(threads) && typeof threads !== "number") || threads < 0) { + throw new Error( + "`threads` must be a non-negative integer representing the number of threads to use for searching." + ); + } const normalizedVectors = normalizeVectors( vectors, @@ -389,7 +404,7 @@ export class Index { ); // Call the compiled method and create Matches or BatchMatches object with the result - const result = this.#compiledIndex.search(normalizedVectors, k); + const result = this.#compiledIndex.search(normalizedVectors, k, threads); const countInQueries = normalizedVectors.length / Number(this.#compiledIndex.dimensions()); const batchMatches = new BatchMatches(...result, k); @@ -533,6 +548,7 @@ type NumberArrayConstructor = * @param {number} dimensions - The dimensionality of the vectors in both the dataset and the queries. It defines the number of elements in each vector. * @param {number} count - The number of nearest neighbors to return for each query. If the dataset contains fewer vectors than the specified count, the result will contain only the available vectors. * @param {MetricKind} metric - The distance metric to be used for the search. + * @param {number} [threads=0] - Optional, default is 0. The number of threads to use for the search. If set to 0, the number of threads is determined automatically. * @return {Matches|BatchMatches} - Returns a `Matches` or `BatchMatches` object containing the results of the search. * @throws Will throw an error if `dimensions` and `count` are not positive integers. * @throws Will throw an error if `metric` is not a valid MetricKind. 
@@ -559,7 +575,8 @@ function exactSearch( queries: VectorOrMatrix, dimensions: number, count: number, - metric: MetricKind + metric: MetricKind, + threads: number = 0 ): Matches | BatchMatches { // Validate and normalize the dimensions and count dimensions = Number(dimensions); @@ -567,6 +584,11 @@ function exactSearch( if (count <= 0 || dimensions <= 0) { throw new Error("Dimensions and count must be positive integers."); } + if ((!Number.isNaN(threads) && typeof threads !== "number") || threads < 0) { + throw new Error( + "`threads` must be a non-negative integer representing the number of threads to use for searching." + ); + } // Validate metric if (!Object.values(MetricKind).includes(metric)) { @@ -599,7 +621,8 @@ function exactSearch( queries, dimensions, count, - metric + metric, + threads ); // Create and return a Matches or BatchMatches object with the result diff --git a/package-lock.json b/package-lock.json index 92c06521..ab286033 100644 --- a/package-lock.json +++ b/package-lock.json @@ -6,7 +6,7 @@ "packages": { "": { "name": "usearch", - "version": "2.15.1", + "version": "2.17.7", "hasInstallScript": true, "license": "Apache 2.0", "dependencies": { @@ -309,10 +309,11 @@ "dev": true }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dev": true, + "license": "MIT", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", @@ -577,11 +578,19 @@ "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", "dev": true }, - "node_modules/ip": { - "version": "2.0.0", - "resolved": 
"https://registry.npmjs.org/ip/-/ip-2.0.0.tgz", - "integrity": "sha512-WKa+XuLG1A1R0UWhl2+1XQSi+fZWMsYKffMZTTYsiZaUD8k2yDAj5atimTUD2TZkyCkNEeYE5NhFZmupOGtjYQ==", - "dev": true + "node_modules/ip-address": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-9.0.5.tgz", + "integrity": "sha512-zHtQzGojZXTwZTHQqra+ETKd4Sn3vgi7uBmlPoXVWZqYvuKmtI0l/VZTjqGmJY9x88GGOaZ9+G9ES8hC4T4X8g==", + "dev": true, + "license": "MIT", + "dependencies": { + "jsbn": "1.1.0", + "sprintf-js": "^1.1.3" + }, + "engines": { + "node": ">= 12" + } }, "node_modules/is-fullwidth-code-point": { "version": "3.0.0", @@ -625,6 +634,13 @@ "@pkgjs/parseargs": "^0.11.0" } }, + "node_modules/jsbn": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/jsbn/-/jsbn-1.1.0.tgz", + "integrity": "sha512-4bYVV3aAMtDTTu4+xsDYa6sy9GyJ69/amsu9sYF2zqjiEoZA5xJi3BrfX3uY+/IekIu7MwdObdbDWpoZdBv3/A==", + "dev": true, + "license": "MIT" + }, "node_modules/lru-cache": { "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", @@ -1149,16 +1165,17 @@ } }, "node_modules/socks": { - "version": "2.7.1", - "resolved": "https://registry.npmjs.org/socks/-/socks-2.7.1.tgz", - "integrity": "sha512-7maUZy1N7uo6+WVEX6psASxtNlKaNVMlGQKkG/63nEDdLOWNbiUMoLK7X4uYoLhQstau72mLgfEWcXcwsaHbYQ==", + "version": "2.8.4", + "resolved": "https://registry.npmjs.org/socks/-/socks-2.8.4.tgz", + "integrity": "sha512-D3YaD0aRxR3mEcqnidIs7ReYJFVzWdd6fXJYUM8ixcQcJRGTka/b3saV0KflYhyVJXKhb947GndU35SxYNResQ==", "dev": true, + "license": "MIT", "dependencies": { - "ip": "^2.0.0", + "ip-address": "^9.0.5", "smart-buffer": "^4.2.0" }, "engines": { - "node": ">= 10.13.0", + "node": ">= 10.0.0", "npm": ">= 3.0.0" } }, @@ -1176,6 +1193,13 @@ "node": ">= 14" } }, + "node_modules/sprintf-js": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz", + "integrity": 
"sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==", + "dev": true, + "license": "BSD-3-Clause" + }, "node_modules/ssri": { "version": "10.0.5", "resolved": "https://registry.npmjs.org/ssri/-/ssri-10.0.5.tgz", @@ -1294,10 +1318,11 @@ } }, "node_modules/tar": { - "version": "6.2.0", - "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.0.tgz", - "integrity": "sha512-/Wo7DcT0u5HUV486xg675HtjNd3BXZ6xDbzsCUZPt5iw8bTQ63bP0Raut3mvro9u+CUyq7YQd8Cx55fsZXxqLQ==", + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.1.tgz", + "integrity": "sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==", "dev": true, + "license": "ISC", "dependencies": { "chownr": "^2.0.0", "fs-minipass": "^2.0.0", @@ -1311,10 +1336,11 @@ } }, "node_modules/tar-fs": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz", - "integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==", + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.2.tgz", + "integrity": "sha512-EsaAXwxmx8UB7FRKqeozqEPop69DXcmYwTQwXvyAPF352HJsPdkVhvTaDPYqfNgruveJIJy3TA2l+2zj8LJIJA==", "dev": true, + "license": "MIT", "dependencies": { "chownr": "^1.1.1", "mkdirp-classic": "^0.5.2", diff --git a/python/usearch/index.py b/python/usearch/index.py index b45e1ca2..4e884025 100644 --- a/python/usearch/index.py +++ b/python/usearch/index.py @@ -604,7 +604,7 @@ def metadata(path_or_buffer: PathOrBuffer) -> Optional[dict]: raise e @staticmethod - def restore(path_or_buffer: PathOrBuffer, view: bool = False) -> Optional[Index]: + def restore(path_or_buffer: PathOrBuffer, view: bool = False, **kwargs) -> Optional[Index]: meta = Index.metadata(path_or_buffer) if not meta: return None @@ -613,6 +613,7 @@ def restore(path_or_buffer: PathOrBuffer, view: bool = False) -> Optional[Index] ndim=meta["dimensions"], 
dtype=meta["kind_scalar"], metric=meta["kind_metric"], + **kwargs, ) if view: