Skip to content

Commit d12c620

Browse files
committed
Optimize merge
1 parent d5b5c1b commit d12c620

File tree

2 files changed

+264
-61
lines changed

2 files changed

+264
-61
lines changed

cpp/test.cpp

+67-26
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
10991099
/**
11001100
* @brief Tests merging.
11011101
*/
1102-
void test_merge() {
1102+
void test_merge(std::size_t base_n) {
11031103
using index_t = index_gt<>;
11041104
using distance_t = typename index_t::distance_t;
11051105
using key_t = typename index_t::key_t;
@@ -1143,24 +1143,49 @@ void test_merge() {
11431143
expect(result);
11441144
};
11451145

1146+
std::size_t n_nodes1 = base_n;
1147+
std::size_t n_nodes2 = base_n * 2;
1148+
1149+
// Prepare expected index
1150+
auto expected_index = create_index();
1151+
metric_t expected_metric;
1152+
expect(expected_index.reserve(n_nodes1 + n_nodes2));
1153+
11461154
// Prepare index 1
11471155
auto index1 = create_index();
11481156
metric_t metric1;
1149-
expect(index1.reserve(3));
1150-
add(index1, 11, 1.1f, metric1);
1151-
add(index1, 12, 2.1f, metric1);
1152-
add(index1, 13, 3.1f, metric1);
1153-
expect_eq(index1.size(), 3);
1157+
expect(index1.reserve(n_nodes1));
1158+
{
1159+
// Use static seed for easy to reproduce
1160+
std::default_random_engine engine(n_nodes1);
1161+
std::uniform_real_distribution<float> distribution(-1.0, 1.0);
1162+
for (std::size_t i = 0; i < n_nodes1; ++i) {
1163+
std::size_t key = 10000 + i;
1164+
value_t value = distribution(engine);
1165+
add(index1, key, value, metric1);
1166+
add(expected_index, key, value, expected_metric);
1167+
}
1168+
}
1169+
expect_eq(index1.size(), n_nodes1);
1170+
expect_eq(expected_index.size(), n_nodes1);
11541171

11551172
// Prepare index 2
11561173
auto index2 = create_index();
11571174
metric_t metric2;
1158-
expect(index2.reserve(4));
1159-
add(index2, 21, -1.1f, metric2);
1160-
add(index2, 22, -2.1f, metric2);
1161-
add(index2, 23, -3.1f, metric2);
1162-
add(index2, 24, -4.1f, metric2);
1163-
expect_eq(index2.size(), 4);
1175+
expect(index2.reserve(n_nodes2));
1176+
{
1177+
// Use static seed for easy to reproduce
1178+
std::default_random_engine engine(n_nodes2);
1179+
std::uniform_real_distribution<float> distribution(-1.0, 1.0);
1180+
for (std::size_t i = 0; i < n_nodes2; ++i) {
1181+
std::size_t key = 20000 + i;
1182+
value_t value = distribution(engine);
1183+
add(index2, key, value, metric2);
1184+
add(expected_index, key, value, expected_metric);
1185+
}
1186+
}
1187+
expect_eq(index2.size(), n_nodes2);
1188+
expect_eq(expected_index.size(), n_nodes1 + n_nodes2);
11641189

11651190
// Merge indexes
11661191
char const* merge_file_path = "merge.usearch";
@@ -1172,30 +1197,45 @@ void test_merge() {
11721197
auto merge_on_success = [&](member_ref_t member, value_t const& value) {
11731198
merged_metric.values[member.slot] = value;
11741199
};
1200+
1201+
// Merge index1
11751202
auto get_value1 = [&](member_cref_t member) -> value_t& { return metric1.values[member.slot]; };
11761203
expect(merged_index.merge(index1, get_value1, merged_metric, {}, merge_on_success));
1204+
expect_eq(merged_index.size(), n_nodes1);
1205+
// Assert after we merge index1
1206+
auto search = merged_index.search(0.75f, 3, merged_metric);
1207+
auto expected_search = index1.search(0.75f, 3, expected_metric);
1208+
expect_eq(search.size(), 3);
1209+
expect(search[0].distance <= expected_search[0].distance);
1210+
expect(search[1].distance <= expected_search[1].distance);
1211+
expect(search[2].distance <= expected_search[2].distance);
1212+
auto loaded_index = create_index();
1213+
loaded_index.view(merge_file_path);
1214+
search = merged_index.search(0.75f, 3, merged_metric);
1215+
1216+
// Merge index2
11771217
auto get_value2 = [&](member_cref_t member) -> value_t& { return metric2.values[member.slot]; };
11781218
expect(merged_index.merge(index2, get_value2, merged_metric, {}, merge_on_success));
1179-
1180-
// Assert
1181-
expect_eq(merged_index.size(), 7);
1182-
auto search = merged_index.search(0.75f, 3, merged_metric);
1219+
// Assert after we merge index1 and index2
1220+
expect_eq(merged_index.size(), n_nodes1 + n_nodes2);
1221+
search = merged_index.search(0.75f, 3, merged_metric);
1222+
expected_search = expected_index.search(0.75f, 3, expected_metric);
11831223
expect_eq(search.size(), 3);
1184-
expect_eq(static_cast<key_t>(search[0].member.key), 11);
1185-
expect_eq(static_cast<key_t>(search[1].member.key), 12);
1186-
expect_eq(static_cast<key_t>(search[2].member.key), 21);
1224+
expect(search[0].distance <= expected_search[0].distance);
1225+
expect(search[1].distance <= expected_search[1].distance);
1226+
expect(search[2].distance <= expected_search[2].distance);
11871227

1188-
// Re-load merged indexes
1228+
// Re-load the merged index
11891229
merged_index.reset();
11901230
merged_index.load(merge_file_path);
11911231

1192-
// Assert
1193-
expect_eq(merged_index.size(), 7);
1232+
// Assert after we reload the merged index
1233+
expect_eq(merged_index.size(), n_nodes1 + n_nodes2);
11941234
search = merged_index.search(0.75f, 3, merged_metric);
11951235
expect_eq(search.size(), 3);
1196-
expect_eq(static_cast<key_t>(search[0].member.key), 11);
1197-
expect_eq(static_cast<key_t>(search[1].member.key), 12);
1198-
expect_eq(static_cast<key_t>(search[2].member.key), 21);
1236+
expect(search[0].distance <= expected_search[0].distance);
1237+
expect(search[1].distance <= expected_search[1].distance);
1238+
expect(search[2].distance <= expected_search[2].distance);
11991239
}
12001240

12011241
int main(int, char**) {
@@ -1267,7 +1307,8 @@ int main(int, char**) {
12671307

12681308
// Test merge
12691309
std::printf("Testing merge\n");
1270-
test_merge();
1310+
test_merge(10); // Use only the 0-level layer
1311+
test_merge(1000); // Use multiple layers
12711312

12721313
return 0;
12731314
}

0 commit comments

Comments
 (0)