Skip to content

Commit a506795

Browse files
BagritsevichStepan authored and romange committed
fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands (#4070)
* fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands fixes #3986 --------- Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
1 parent 2d49a28 commit a506795

16 files changed

+682
-215
lines changed

src/core/search/base.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "core/search/base.h"
66

7+
#include <absl/strings/numbers.h>
8+
79
namespace dfly::search {
810

911
std::string_view QueryParams::operator[](std::string_view name) const {
@@ -37,4 +39,11 @@ WrappedStrPtr::operator std::string_view() const {
3739
return std::string_view{ptr.get(), std::strlen(ptr.get())};
3840
}
3941

42+
std::optional<double> ParseNumericField(std::string_view value) {
43+
double value_as_double;
44+
if (absl::SimpleAtod(value, &value_as_double))
45+
return value_as_double;
46+
return std::nullopt;
47+
}
48+
4049
} // namespace dfly::search

src/core/search/base.h

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,18 @@ using SortableValue = std::variant<std::monostate, double, std::string>;
6868
struct DocumentAccessor {
6969
using VectorInfo = search::OwnedFtVector;
7070
using StringList = absl::InlinedVector<std::string_view, 1>;
71+
using NumsList = absl::InlinedVector<double, 1>;
7172

7273
virtual ~DocumentAccessor() = default;
7374

74-
virtual StringList GetStrings(std::string_view active_field) const = 0;
75-
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
75+
/* Returns nullopt if the specified field is not a list of strings */
76+
virtual std::optional<StringList> GetStrings(std::string_view active_field) const = 0;
77+
78+
/* Returns nullopt if the specified field is not a vector */
79+
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const = 0;
80+
81+
/* Return nullopt if the specified field is not a list of doubles */
82+
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const = 0;
7683
};
7784

7885
// Base class for type-specific indices.
@@ -81,8 +88,10 @@ struct DocumentAccessor {
8188
// query functions. All results for all index types should be sorted.
8289
struct BaseIndex {
8390
virtual ~BaseIndex() = default;
84-
virtual void Add(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
85-
virtual void Remove(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
91+
92+
// Returns true if the document was added / indexed
93+
virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
94+
virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
8695
};
8796

8897
// Base class for type-specific sorting indices.
@@ -91,4 +100,20 @@ struct BaseSortIndex : BaseIndex {
91100
virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0;
92101
};
93102

103+
/* Used for converting field values to double. Returns std::nullopt if the conversion fails */
104+
std::optional<double> ParseNumericField(std::string_view value);
105+
106+
/* Temporary helper that creates an engaged std::optional holding an empty InlinedVector, for use
   in the DocumentAccessor::GetStrings and DocumentAccessor::GetNumbers methods. Due to internal
   implementation details of absl::InlinedVector, GCC reports a spurious -Wmaybe-uninitialized
   warning for this code. To suppress this false warning, it is temporarily disabled around the
   return statement using GCC diagnostic directives. The directives are restricted to GCC proper:
   Clang does not know this warning group and would complain about the pragma itself. */
template <typename InlinedVector> std::optional<InlinedVector> EmptyAccessResult() {
#if defined(__GNUC__) && !defined(__clang__)
// GCC 13.1 throws spurious warnings around this code.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
  return InlinedVector{};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
}
118+
94119
} // namespace dfly::search

src/core/search/indices.cc

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,19 +71,22 @@ absl::flat_hash_set<string> NormalizeTags(string_view taglist, bool case_sensiti
7171
// Creates an empty numeric index; `mr` is passed on to the entries_ container.
NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {}
7373

74-
void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
75-
for (auto str : doc->GetStrings(field)) {
76-
double num;
77-
if (absl::SimpleAtod(str, &num))
78-
entries_.emplace(num, id);
74+
// Indexes every numeric value stored in `field` of `doc` under `id`.
// Returns false, leaving the index untouched, when the field is not a list of numbers.
bool NumericIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) {
  const auto values = doc.GetNumbers(field);
  if (!values.has_value())
    return false;

  for (const auto value : *values)
    entries_.emplace(value, id);
  return true;
}
8185

82-
void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
83-
for (auto str : doc->GetStrings(field)) {
84-
double num;
85-
if (absl::SimpleAtod(str, &num))
86-
entries_.erase({num, id});
86+
// Erases every (value, id) entry that corresponds to the numbers stored in `field` of `doc`.
// Add stores nothing for a field that is not a list of numbers, so in that case there is
// nothing to erase and the call is a no-op instead of throwing std::bad_optional_access.
void NumericIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
  const auto numbers = doc.GetNumbers(field);
  if (!numbers)
    return;

  for (const auto num : *numbers) {
    entries_.erase({num, id});
  }
}
8992

@@ -139,19 +142,27 @@ typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_v
139142
}
140143

141144
template <typename C>
142-
void BaseStringIndex<C>::Add(DocId id, DocumentAccessor* doc, string_view field) {
145+
bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view field) {
146+
auto strings_list = doc.GetStrings(field);
147+
if (!strings_list) {
148+
return false;
149+
}
150+
143151
absl::flat_hash_set<std::string> tokens;
144-
for (string_view str : doc->GetStrings(field))
152+
for (string_view str : strings_list.value())
145153
tokens.merge(Tokenize(str));
146154

147155
for (string_view token : tokens)
148156
GetOrCreate(token)->Insert(id);
157+
return true;
149158
}
150159

151160
template <typename C>
152-
void BaseStringIndex<C>::Remove(DocId id, DocumentAccessor* doc, string_view field) {
161+
void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
162+
auto strings_list = doc.GetStrings(field).value();
163+
153164
absl::flat_hash_set<std::string> tokens;
154-
for (string_view str : doc->GetStrings(field))
165+
for (string_view str : strings_list)
155166
tokens.merge(Tokenize(str));
156167

157168
for (const auto& token : tokens) {
@@ -192,26 +203,39 @@ std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
192203
return {dim_, sim_};
193204
}
194205

206+
// Fetches the vector stored in `field` of `doc` and forwards it to AddVector.
// Returns false when the field is not a vector, or when a non-null vector has a size other
// than dim_. A null vector pointer is still forwarded to AddVector and reported as added.
bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
  auto maybe_vector = doc.GetVector(field);
  if (!maybe_vector.has_value()) {
    return false;
  }

  auto& [data, length] = *maybe_vector;
  const bool dimension_mismatch = data && length != dim_;
  if (dimension_mismatch)
    return false;

  AddVector(id, data);
  return true;
}
219+
195220
// Sets up a flat (non-HNSW) vector index and reserves room in entries_ for
// params.capacity * params.dim floats up front.
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
                                 PMR_NS::memory_resource* mr)
    : BaseVectorIndex{params.dim, params.sim}, entries_{mr} {
  DCHECK(!params.use_hnsw);

  const auto reserved_floats = params.capacity * params.dim;
  entries_.reserve(reserved_floats);
}
201226

202-
void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
227+
// Stores `vector` (dim_ floats) in the slot of document `id`.
// Slots are laid out back to back: document `id` occupies entries_[id * dim_, (id + 1) * dim_).
// A null `vector` leaves the slot contents as they are.
void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
  // The slot is expected to either exist already or to be the next one to append.
  DCHECK_LE(id * dim_, entries_.size());

  // Grow whenever the slot is missing. Comparing with >= rather than == keeps builds in which
  // the DCHECK above is compiled out from copying past the end of entries_ if an id was skipped.
  if (id * dim_ >= entries_.size())
    entries_.resize((id + 1) * dim_);

  // TODO: Let get vector write to buf itself
  if (vector) {
    memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
  }
}
213237

214-
void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
238+
// Removal is a no-op for the flat index: nothing is erased from entries_.
void FlatVectorIndex::Remove(DocId /*id*/, const DocumentAccessor& /*doc*/,
                             string_view /*field*/) {
}
217241

@@ -229,7 +253,7 @@ struct HnswlibAdapter {
229253
100 /* seed*/} {
230254
}
231255

232-
void Add(float* data, DocId id) {
256+
void Add(const float* data, DocId id) {
233257
if (world_.cur_element_count + 1 >= world_.max_elements_)
234258
world_.resizeIndex(world_.cur_element_count * 2);
235259
world_.addPoint(data, id);
@@ -298,10 +322,10 @@ HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS
298322
// Defined out of line in this translation unit, where HnswlibAdapter is defined, so that the
// std::unique_ptr<HnswlibAdapter> member can be destroyed.
HnswVectorIndex::~HnswVectorIndex() = default;
300324

301-
void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
302-
auto [ptr, size] = doc->GetVector(field);
303-
if (size == dim_)
304-
adapter_->Add(ptr.get(), id);
325+
// Hands the vector data of document `id` to the HNSW adapter; a null `vector` is skipped.
void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
  if (!vector)
    return;

  adapter_->Add(vector.get(), id);
}
306330

307331
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
@@ -314,7 +338,7 @@ std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t
314338
return adapter_->Knn(target, k, ef, allowed);
315339
}
316340

317-
void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
341+
// Delegates removal of `id` to the HNSW adapter; the document and field are not consulted.
void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& /*doc*/, string_view /*field*/) {
  adapter_->Remove(id);
}
320344

src/core/search/indices.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ namespace dfly::search {
2828
struct NumericIndex : public BaseIndex {
2929
explicit NumericIndex(PMR_NS::memory_resource* mr);
3030

31-
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
32-
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
31+
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
32+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
3333

3434
std::vector<DocId> Range(double l, double r) const;
3535

@@ -44,16 +44,16 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
4444

4545
BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive);
4646

47-
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
48-
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
47+
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
48+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
4949

5050
// Used by Add & Remove to tokenize text value
5151
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
5252

5353
// Pointer is valid as long as index is not mutated. Nullptr if not found
5454
const Container* Matching(std::string_view str) const;
5555

56-
// Iterate over all Machting on prefix.
56+
// Iterate over all Matching on prefix.
5757
void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;
5858

5959
// Returns all the terms that appear as keys in the reverse index.
@@ -97,9 +97,14 @@ struct TagIndex : public BaseStringIndex<SortedVector> {
9797
struct BaseVectorIndex : public BaseIndex {
9898
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
9999

100+
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final;
101+
100102
protected:
101103
BaseVectorIndex(size_t dim, VectorSimilarity sim);
102104

105+
using VectorPtr = decltype(std::declval<OwnedFtVector>().first);
106+
virtual void AddVector(DocId id, const VectorPtr& vector) = 0;
107+
103108
size_t dim_;
104109
VectorSimilarity sim_;
105110
};
@@ -109,11 +114,13 @@ struct BaseVectorIndex : public BaseIndex {
109114
struct FlatVectorIndex : public BaseVectorIndex {
110115
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
111116

112-
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
113-
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
117+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
114118

115119
const float* Get(DocId doc) const;
116120

121+
protected:
122+
void AddVector(DocId id, const VectorPtr& vector) override;
123+
117124
private:
118125
PMR_NS::vector<float> entries_;
119126
};
@@ -124,13 +131,15 @@ struct HnswVectorIndex : public BaseVectorIndex {
124131
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
125132
~HnswVectorIndex();
126133

127-
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
128-
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
134+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
129135

130136
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
131137
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
132138
const std::vector<DocId>& allowed) const;
133139

140+
protected:
141+
void AddVector(DocId id, const VectorPtr& vector) override;
142+
134143
private:
135144
std::unique_ptr<HnswlibAdapter> adapter_;
136145
};

src/core/search/search.cc

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -571,23 +571,48 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) {
571571
}
572572
}
573573

574-
void FieldIndices::Add(DocId doc, DocumentAccessor* access) {
575-
for (auto& [field, index] : indices_)
576-
index->Add(doc, access, field);
577-
for (auto& [field, sort_index] : sort_indices_)
578-
sort_index->Add(doc, access, field);
574+
// Adds `doc` to every field index and then to every sort index.
// If any index rejects the document, it is removed again from the indices that had already
// accepted it and false is returned; all_ids_ is only updated on full success.
bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) {
  using AddedEntry = std::pair<std::string_view, BaseIndex*>;

  // Indices that accepted the document so far, kept for a possible rollback.
  std::vector<AddedEntry> accepted;
  accepted.reserve(indices_.size() + sort_indices_.size());

  // Adds the document to each index of the container, stopping at the first rejection.
  auto add_to_all = [&](const auto& container) -> bool {
    for (const auto& [name, idx] : container) {
      if (!idx->Add(doc, access, name))
        return false;
      accepted.emplace_back(name, idx.get());
    }
    return true;
  };

  // Sort indices are only attempted when all regular indices succeeded.
  const bool indexed = add_to_all(indices_) && add_to_all(sort_indices_);
  if (indexed) {
    all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc);
    return true;
  }

  for (auto& [name, idx] : accepted)
    idx->Remove(doc, access, name);
  return false;
}
582607

583-
void FieldIndices::Remove(DocId doc, DocumentAccessor* access) {
608+
// Removes `doc` from every field index and sort index, then drops its id from all_ids_.
void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) {
  for (auto& [field, index] : indices_)
    index->Remove(doc, access, field);
  for (auto& [field, sort_index] : sort_indices_)
    sort_index->Remove(doc, access, field);

  auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc);
  const bool found = it != all_ids_.end() && *it == doc;
  DCHECK(found);

  // When the DCHECK is compiled out, an unknown id must not reach erase(): `it` would then be
  // end() or point at the next larger id, which belongs to a different document.
  if (found)
    all_ids_.erase(it);
}
593618

src/core/search/search.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ class FieldIndices {
7777
// Create indices based on schema and options. Both must outlive the indices
7878
FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr);
7979

80-
void Add(DocId doc, DocumentAccessor* access);
81-
void Remove(DocId doc, DocumentAccessor* access);
80+
// Returns true if document was added
81+
bool Add(DocId doc, const DocumentAccessor& access);
82+
void Remove(DocId doc, const DocumentAccessor& access);
8283

8384
BaseIndex* GetIndex(std::string_view field) const;
8485
BaseSortIndex* GetSortIndex(std::string_view field) const;

0 commit comments

Comments
 (0)