@@ -1096,6 +1096,108 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
1096
1096
expect_eq (final_search[2 ].member .key , 44 );
1097
1097
}
1098
1098
1099
+ /* *
1100
+ * @brief Tests merging.
1101
+ */
1102
+ void test_merge () {
1103
+ using index_t = index_gt<>;
1104
+ using distance_t = typename index_t ::distance_t ;
1105
+ using key_t = typename index_t ::key_t ;
1106
+ using compressed_slot_t = typename index_t ::compressed_slot_t ;
1107
+ using member_ref_t = typename index_t ::member_ref_t ;
1108
+ using member_cref_t = typename index_t ::member_cref_t ;
1109
+ using member_citerator_t = typename index_t ::member_citerator_t ;
1110
+ using add_result_t = typename index_t ::add_result_t ;
1111
+
1112
+ using value_t = float ;
1113
+
1114
+ auto create_index = []() {
1115
+ auto index_result = index_t::make ();
1116
+ expect (index_result);
1117
+ return std::move (index_result.index );
1118
+ };
1119
+
1120
+ struct metric_t {
1121
+ std::unordered_map<compressed_slot_t , value_t > values;
1122
+
1123
+ metric_t () : values() {}
1124
+ distance_t compute (value_t const & a, value_t const & b) {
1125
+ if (b > a) {
1126
+ return b - a;
1127
+ } else {
1128
+ return a - b;
1129
+ }
1130
+ }
1131
+ distance_t operator ()(value_t const & a, member_cref_t const & b) { return compute (a, values.at (get_slot (b))); }
1132
+ distance_t operator ()(value_t const & a, member_citerator_t const & b) {
1133
+ return compute (a, values.at (get_slot (b)));
1134
+ }
1135
+ distance_t operator ()(member_citerator_t const & a, member_citerator_t const & b) {
1136
+ return compute (values.at (get_slot (a)), values.at (get_slot (b)));
1137
+ }
1138
+ };
1139
+
1140
+ auto add = [](index_t & index , key_t const key, value_t const value, metric_t & metric) {
1141
+ auto on_success = [&](member_ref_t member) { metric.values [member.slot ] = value; };
1142
+ add_result_t result = index .add (key, value, metric, {}, on_success);
1143
+ expect (result);
1144
+ };
1145
+
1146
+ // Prepare index 1
1147
+ auto index1 = create_index ();
1148
+ metric_t metric1;
1149
+ expect (index1.reserve (3 ));
1150
+ add (index1, 11 , 1 .1f , metric1);
1151
+ add (index1, 12 , 2 .1f , metric1);
1152
+ add (index1, 13 , 3 .1f , metric1);
1153
+ expect_eq (index1.size (), 3 );
1154
+
1155
+ // Prepare index 2
1156
+ auto index2 = create_index ();
1157
+ metric_t metric2;
1158
+ expect (index2.reserve (4 ));
1159
+ add (index2, 21 , -1 .1f , metric2);
1160
+ add (index2, 22 , -2 .1f , metric2);
1161
+ add (index2, 23 , -3 .1f , metric2);
1162
+ add (index2, 24 , -4 .1f , metric2);
1163
+ expect_eq (index2.size (), 4 );
1164
+
1165
+ // Merge indexes
1166
+ char const * merge_file_path = " merge.usearch" ;
1167
+ auto merged_index = create_index ();
1168
+ expect (merged_index.save (merge_file_path));
1169
+ memory_mapped_file_t file{merge_file_path, true };
1170
+ expect (merged_index.load (std::move (file)));
1171
+ metric_t merged_metric;
1172
+ auto merge_on_success = [&](member_ref_t member, value_t const & value) {
1173
+ merged_metric.values [member.slot ] = value;
1174
+ };
1175
+ auto get_value1 = [&](member_cref_t member) -> value_t & { return metric1.values [member.slot ]; };
1176
+ expect (merged_index.merge (index1, get_value1, merged_metric, {}, merge_on_success));
1177
+ auto get_value2 = [&](member_cref_t member) -> value_t & { return metric2.values [member.slot ]; };
1178
+ expect (merged_index.merge (index2, get_value2, merged_metric, {}, merge_on_success));
1179
+
1180
+ // Assert
1181
+ expect_eq (merged_index.size (), 7 );
1182
+ auto search = merged_index.search (0 .75f , 3 , merged_metric);
1183
+ expect_eq (search.size (), 3 );
1184
+ expect_eq (static_cast <key_t >(search[0 ].member .key ), 11 );
1185
+ expect_eq (static_cast <key_t >(search[1 ].member .key ), 12 );
1186
+ expect_eq (static_cast <key_t >(search[2 ].member .key ), 21 );
1187
+
1188
+ // Re-load merged indexes
1189
+ merged_index.reset ();
1190
+ merged_index.load (merge_file_path);
1191
+
1192
+ // Assert
1193
+ expect_eq (merged_index.size (), 7 );
1194
+ search = merged_index.search (0 .75f , 3 , merged_metric);
1195
+ expect_eq (search.size (), 3 );
1196
+ expect_eq (static_cast <key_t >(search[0 ].member .key ), 11 );
1197
+ expect_eq (static_cast <key_t >(search[1 ].member .key ), 12 );
1198
+ expect_eq (static_cast <key_t >(search[2 ].member .key ), 21 );
1199
+ }
1200
+
1099
1201
int main (int , char **) {
1100
1202
test_uint40 ();
1101
1203
test_cosine<float , std::int64_t , uint40_t >(10 , 10 );
@@ -1163,5 +1265,9 @@ int main(int, char**) {
1163
1265
test_sets<std::int64_t , slot32_t >(set_size, 20 , 30 );
1164
1266
test_strings<std::int64_t , slot32_t >();
1165
1267
1268
+ // Test merge
1269
+ std::printf (" Testing merge\n " );
1270
+ test_merge ();
1271
+
1166
1272
return 0 ;
1167
1273
}
0 commit comments