Skip to content

Commit 3c7e312

Browse files
fix(search_family): Support boolean and nullable types in indexes (#4314)
* fix(search_family): Support boolean and nullable types in indexes

  fixes #4107, #4129

  Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments

  Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
1 parent 01f24da commit 3c7e312

File tree

9 files changed

+191
-76
lines changed

9 files changed

+191
-76
lines changed

src/core/search/base.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ struct DocumentAccessor {
8080

8181
/* Return nullopt if the specified field is not a list of doubles */
8282
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const = 0;
83+
84+
/* Same as GetStrings, but also supports boolean values */
85+
virtual std::optional<StringList> GetTags(std::string_view active_field) const = 0;
8386
};
8487

8588
// Base class for type-specific indices.

src/core/search/indices.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_v
143143

144144
template <typename C>
145145
bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view field) {
146-
auto strings_list = doc.GetStrings(field);
146+
auto strings_list = GetStrings(doc, field);
147147
if (!strings_list) {
148148
return false;
149149
}
@@ -159,7 +159,7 @@ bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view
159159

160160
template <typename C>
161161
void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
162-
auto strings_list = doc.GetStrings(field).value();
162+
auto strings_list = GetStrings(doc, field).value();
163163

164164
absl::flat_hash_set<std::string> tokens;
165165
for (string_view str : strings_list)
@@ -188,10 +188,20 @@ template <typename C> vector<string> BaseStringIndex<C>::GetTerms() const {
188188
template struct BaseStringIndex<CompressedSortedSet>;
189189
template struct BaseStringIndex<SortedVector>;
190190

191+
std::optional<DocumentAccessor::StringList> TextIndex::GetStrings(const DocumentAccessor& doc,
192+
std::string_view field) const {
193+
return doc.GetStrings(field);
194+
}
195+
191196
absl::flat_hash_set<std::string> TextIndex::Tokenize(std::string_view value) const {
192197
return TokenizeWords(value, *stopwords_);
193198
}
194199

200+
std::optional<DocumentAccessor::StringList> TagIndex::GetStrings(const DocumentAccessor& doc,
201+
std::string_view field) const {
202+
return doc.GetTags(field);
203+
}
204+
195205
absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) const {
196206
return NormalizeTags(value, case_sensitive_, separator_);
197207
}

src/core/search/indices.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
4747
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
4848
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
4949

50-
// Used by Add & Remove to tokenize text value
51-
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
52-
5350
// Pointer is valid as long as index is not mutated. Nullptr if not found
5451
const Container* Matching(std::string_view str) const;
5552

@@ -60,6 +57,15 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
6057
std::vector<std::string> GetTerms() const;
6158

6259
protected:
60+
using StringList = DocumentAccessor::StringList;
61+
62+
// Used by Add & Remove to get strings from document
63+
virtual std::optional<StringList> GetStrings(const DocumentAccessor& doc,
64+
std::string_view field) const = 0;
65+
66+
// Used by Add & Remove to tokenize text value
67+
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
68+
6369
Container* GetOrCreate(std::string_view word);
6470

6571
bool case_sensitive_ = false;
@@ -75,6 +81,9 @@ struct TextIndex : public BaseStringIndex<CompressedSortedSet> {
7581
: BaseStringIndex(mr, false), stopwords_{stopwords} {
7682
}
7783

84+
protected:
85+
std::optional<StringList> GetStrings(const DocumentAccessor& doc,
86+
std::string_view field) const override;
7887
absl::flat_hash_set<std::string> Tokenize(std::string_view value) const override;
7988

8089
private:
@@ -88,6 +97,9 @@ struct TagIndex : public BaseStringIndex<SortedVector> {
8897
: BaseStringIndex(mr, params.case_sensitive), separator_{params.separator} {
8998
}
9099

100+
protected:
101+
std::optional<StringList> GetStrings(const DocumentAccessor& doc,
102+
std::string_view field) const override;
91103
absl::flat_hash_set<std::string> Tokenize(std::string_view value) const override;
92104

93105
private:

src/core/search/search_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ struct MockedDocument : public DocumentAccessor {
5252
return StringList{string_view{it->second}};
5353
}
5454

55+
std::optional<StringList> GetTags(string_view field) const override {
56+
return GetStrings(field);
57+
}
58+
5559
std::optional<VectorInfo> GetVector(string_view field) const override {
5660
auto strings_list = GetStrings(field);
5761
if (!strings_list)

src/core/search/sort_indices.cc

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,33 @@
1010
#include <absl/strings/str_split.h>
1111

1212
#include <algorithm>
13+
#include <optional>
1314
#include <type_traits>
15+
#include <variant>
1416

1517
namespace dfly::search {
1618

1719
using namespace std;
1820

1921
namespace {} // namespace
2022

23+
template <typename T> bool SimpleValueSortIndex<T>::ParsedSortValue::HasValue() const {
24+
return !std::holds_alternative<std::monostate>(value);
25+
}
26+
27+
template <typename T> bool SimpleValueSortIndex<T>::ParsedSortValue::IsNullValue() const {
28+
return std::holds_alternative<std::nullopt_t>(value);
29+
}
30+
2131
template <typename T>
2232
SimpleValueSortIndex<T>::SimpleValueSortIndex(PMR_NS::memory_resource* mr) : values_{mr} {
2333
}
2434

2535
template <typename T> SortableValue SimpleValueSortIndex<T>::Lookup(DocId doc) const {
36+
if (null_values_.contains(doc)) {
37+
return std::monostate{};
38+
}
39+
2640
DCHECK_LT(doc, values_.size());
2741
if constexpr (is_same_v<T, PMR_NS::string>) {
2842
return std::string(values_[doc]);
@@ -48,21 +62,30 @@ std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids,
4862
template <typename T>
4963
bool SimpleValueSortIndex<T>::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
5064
auto field_value = Get(doc, field);
51-
if (!field_value) {
65+
if (!field_value.HasValue()) {
5266
return false;
5367
}
5468

55-
DCHECK_LE(id, values_.size()); // Doc ids grow at most by one
69+
if (field_value.IsNullValue()) {
70+
null_values_.insert(id);
71+
return true;
72+
}
73+
5674
if (id >= values_.size())
5775
values_.resize(id + 1);
5876

59-
values_[id] = field_value.value();
77+
values_[id] = std::move(std::get<T>(field_value.value));
6078
return true;
6179
}
6280

6381
template <typename T>
6482
void SimpleValueSortIndex<T>::Remove(DocId id, const DocumentAccessor& doc,
6583
std::string_view field) {
84+
if (auto it = null_values_.find(id); it != null_values_.end()) {
85+
null_values_.erase(it);
86+
return;
87+
}
88+
6689
DCHECK_LT(id, values_.size());
6790
values_[id] = T{};
6891
}
@@ -74,22 +97,28 @@ template <typename T> PMR_NS::memory_resource* SimpleValueSortIndex<T>::GetMemRe
7497
template struct SimpleValueSortIndex<double>;
7598
template struct SimpleValueSortIndex<PMR_NS::string>;
7699

77-
std::optional<double> NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) {
100+
SimpleValueSortIndex<double>::ParsedSortValue NumericSortIndex::Get(const DocumentAccessor& doc,
101+
std::string_view field) {
78102
auto numbers_list = doc.GetNumbers(field);
79103
if (!numbers_list) {
80-
return std::nullopt;
104+
return {};
105+
}
106+
if (numbers_list->empty()) {
107+
return ParsedSortValue{std::nullopt};
81108
}
82-
return !numbers_list->empty() ? numbers_list->front() : 0.0;
109+
return ParsedSortValue{numbers_list->front()};
83110
}
84111

85-
std::optional<PMR_NS::string> StringSortIndex::Get(const DocumentAccessor& doc,
86-
std::string_view field) {
87-
auto strings_list = doc.GetStrings(field);
112+
SimpleValueSortIndex<PMR_NS::string>::ParsedSortValue StringSortIndex::Get(
113+
const DocumentAccessor& doc, std::string_view field) {
114+
auto strings_list = doc.GetTags(field);
88115
if (!strings_list) {
89-
return std::nullopt;
116+
return {};
117+
}
118+
if (strings_list->empty()) {
119+
return ParsedSortValue{std::nullopt};
90120
}
91-
return !strings_list->empty() ? PMR_NS::string{strings_list->front(), GetMemRes()}
92-
: PMR_NS::string{GetMemRes()};
121+
return ParsedSortValue{PMR_NS::string{strings_list->front(), GetMemRes()}};
93122
}
94123

95124
} // namespace dfly::search

src/core/search/sort_indices.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,19 @@
1818

1919
namespace dfly::search {
2020

21-
template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
21+
template <typename T> struct SimpleValueSortIndex : public BaseSortIndex {
22+
protected:
23+
struct ParsedSortValue {
24+
bool HasValue() const;
25+
bool IsNullValue() const;
26+
27+
// std::monostate - no value was found.
28+
// std::nullopt - found value is null.
29+
// T - found value.
30+
std::variant<std::monostate, std::nullopt_t, T> value;
31+
};
32+
33+
public:
2234
SimpleValueSortIndex(PMR_NS::memory_resource* mr);
2335

2436
SortableValue Lookup(DocId doc) const override;
@@ -28,25 +40,26 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
2840
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
2941

3042
protected:
31-
virtual std::optional<T> Get(const DocumentAccessor& doc, std::string_view field_value) = 0;
43+
virtual ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field_value) = 0;
3244

3345
PMR_NS::memory_resource* GetMemRes() const;
3446

3547
private:
3648
PMR_NS::vector<T> values_;
49+
absl::flat_hash_set<DocId> null_values_;
3750
};
3851

3952
struct NumericSortIndex : public SimpleValueSortIndex<double> {
4053
NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
4154

42-
std::optional<double> Get(const DocumentAccessor& doc, std::string_view field) override;
55+
ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override;
4356
};
4457

4558
// TODO: Map tags to integers for fast sort
4659
struct StringSortIndex : public SimpleValueSortIndex<PMR_NS::string> {
4760
StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
4861

49-
std::optional<PMR_NS::string> Get(const DocumentAccessor& doc, std::string_view field) override;
62+
ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override;
5063
};
5164

5265
} // namespace dfly::search

0 commit comments

Comments
 (0)