Skip to content

Commit 1fa9a47

Browse files
refactor(search_family): Add Aggregator class (#4290)
* refactor(search_family): Add Aggregator class Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * fix(aggregator_test): Fix tests failing Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: address comments Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: Restore the previous comment Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: address comments 2 Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: address comments 3 Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * fix(aggregator): Simplify comparator for the case when one of the values is not present Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> --------- Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
1 parent 8d66c25 commit 1fa9a47

File tree

5 files changed

+149
-104
lines changed

5 files changed

+149
-104
lines changed

src/server/search/aggregator.cc

Lines changed: 95 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -10,62 +10,99 @@ namespace dfly::aggregate {
1010

1111
namespace {
1212

13-
struct GroupStep {
14-
PipelineResult operator()(PipelineResult result) {
15-
// Separate items into groups
16-
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
17-
for (auto& value : result.values) {
18-
groups[Extract(value)].push_back(std::move(value));
19-
}
13+
using ValuesList = absl::FixedArray<Value>;
2014

21-
// Restore DocValues and apply reducers
22-
std::vector<DocValues> out;
23-
while (!groups.empty()) {
24-
auto node = groups.extract(groups.begin());
25-
DocValues doc = Unpack(std::move(node.key()));
26-
for (auto& reducer : reducers_) {
27-
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
28-
}
29-
out.push_back(std::move(doc));
30-
}
15+
ValuesList ExtractFieldsValues(const DocValues& dv, absl::Span<const std::string> fields) {
16+
ValuesList out(fields.size());
17+
for (size_t i = 0; i < fields.size(); i++) {
18+
auto it = dv.find(fields[i]);
19+
out[i] = (it != dv.end()) ? it->second : Value{};
20+
}
21+
return out;
22+
}
3123

32-
absl::flat_hash_set<std::string> fields_to_print;
33-
fields_to_print.reserve(fields_.size() + reducers_.size());
24+
DocValues PackFields(ValuesList values, absl::Span<const std::string> fields) {
25+
DCHECK_EQ(values.size(), fields.size());
26+
DocValues out;
27+
for (size_t i = 0; i < fields.size(); i++)
28+
out[fields[i]] = std::move(values[i]);
29+
return out;
30+
}
3431

35-
for (auto& field : fields_) {
36-
fields_to_print.insert(std::move(field));
37-
}
38-
for (auto& reducer : reducers_) {
39-
fields_to_print.insert(std::move(reducer.result_field));
40-
}
32+
const Value kEmptyValue = Value{};
33+
34+
} // namespace
4135

42-
return {std::move(out), std::move(fields_to_print)};
36+
void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers) {
37+
// Separate items into groups
38+
absl::flat_hash_map<ValuesList, std::vector<DocValues>> groups;
39+
for (auto& value : result.values) {
40+
groups[ExtractFieldsValues(value, fields)].push_back(std::move(value));
4341
}
4442

45-
absl::FixedArray<Value> Extract(const DocValues& dv) {
46-
absl::FixedArray<Value> out(fields_.size());
47-
for (size_t i = 0; i < fields_.size(); i++) {
48-
auto it = dv.find(fields_[i]);
49-
out[i] = (it != dv.end()) ? it->second : Value{};
43+
// Restore DocValues and apply reducers
44+
auto& values = result.values;
45+
values.clear();
46+
values.reserve(groups.size());
47+
while (!groups.empty()) {
48+
auto node = groups.extract(groups.begin());
49+
DocValues doc = PackFields(std::move(node.key()), fields);
50+
for (auto& reducer : reducers) {
51+
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
5052
}
51-
return out;
53+
values.push_back(std::move(doc));
5254
}
5355

54-
DocValues Unpack(absl::FixedArray<Value>&& values) {
55-
DCHECK_EQ(values.size(), fields_.size());
56-
DocValues out;
57-
for (size_t i = 0; i < fields_.size(); i++)
58-
out[fields_[i]] = std::move(values[i]);
59-
return out;
56+
auto& fields_to_print = result.fields_to_print;
57+
fields_to_print.clear();
58+
fields_to_print.reserve(fields.size() + reducers.size());
59+
60+
for (auto& field : fields) {
61+
fields_to_print.insert(field);
6062
}
63+
for (auto& reducer : reducers) {
64+
fields_to_print.insert(reducer.result_field);
65+
}
66+
}
6167

62-
std::vector<std::string> fields_;
63-
std::vector<Reducer> reducers_;
64-
};
68+
void Aggregator::DoSort(std::string_view field, bool descending) {
69+
/*
70+
Comparator for sorting DocValues by field.
71+
If some of the fields is not present in the DocValues, comparator returns:
72+
1. l_it == l.end() && r_it != r.end()
73+
asc -> false
74+
desc -> false
75+
2. l_it != l.end() && r_it == r.end()
76+
asc -> true
77+
desc -> true
78+
3. l_it == l.end() && r_it == r.end()
79+
asc -> false
80+
desc -> false
81+
*/
82+
auto comparator = [&](const DocValues& l, const DocValues& r) {
83+
auto l_it = l.find(field);
84+
auto r_it = r.find(field);
85+
86+
// If some of the values is not present
87+
if (l_it == l.end() || r_it == r.end()) {
88+
return l_it != l.end();
89+
}
6590

66-
const Value kEmptyValue = Value{};
91+
auto& lv = l_it->second;
92+
auto& rv = r_it->second;
93+
return !descending ? lv < rv : lv > rv;
94+
};
6795

68-
} // namespace
96+
std::sort(result.values.begin(), result.values.end(), std::move(comparator));
97+
98+
result.fields_to_print.insert(field);
99+
}
100+
101+
void Aggregator::DoLimit(size_t offset, size_t num) {
102+
auto& values = result.values;
103+
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
104+
values.resize(std::min(num, values.size()));
105+
}
69106

70107
const Value& ValueIterator::operator*() const {
71108
auto it = values_.front().find(field_);
@@ -109,48 +146,30 @@ Reducer::Func FindReducerFunc(ReducerFunc name) {
109146
return nullptr;
110147
}
111148

112-
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
113-
std::vector<Reducer> reducers) {
114-
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
149+
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers) {
150+
return [fields = std::move(fields), reducers = std::move(reducers)](Aggregator* aggregator) {
151+
aggregator->DoGroup(fields, reducers);
152+
};
115153
}
116154

117-
PipelineStep MakeSortStep(std::string_view field, bool descending) {
118-
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
119-
auto& values = result.values;
120-
121-
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
122-
auto it1 = l.find(field);
123-
auto it2 = r.find(field);
124-
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
125-
});
126-
127-
if (descending) {
128-
std::reverse(values.begin(), values.end());
129-
}
130-
131-
result.fields_to_print.insert(field);
132-
return result;
155+
AggregationStep MakeSortStep(std::string field, bool descending) {
156+
return [field = std::move(field), descending](Aggregator* aggregator) {
157+
aggregator->DoSort(field, descending);
133158
};
134159
}
135160

136-
PipelineStep MakeLimitStep(size_t offset, size_t num) {
137-
return [offset, num](PipelineResult result) {
138-
auto& values = result.values;
139-
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
140-
values.resize(std::min(num, values.size()));
141-
return result;
142-
};
161+
AggregationStep MakeLimitStep(size_t offset, size_t num) {
162+
return [=](Aggregator* aggregator) { aggregator->DoLimit(offset, num); };
143163
}
144164

145-
PipelineResult Process(std::vector<DocValues> values,
146-
absl::Span<const std::string_view> fields_to_print,
147-
absl::Span<const PipelineStep> steps) {
148-
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
165+
AggregationResult Process(std::vector<DocValues> values,
166+
absl::Span<const std::string_view> fields_to_print,
167+
absl::Span<const AggregationStep> steps) {
168+
Aggregator aggregator{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
149169
for (auto& step : steps) {
150-
PipelineResult step_result = step(std::move(result));
151-
result = std::move(step_result);
170+
step(&aggregator);
152171
}
153-
return result;
172+
return aggregator.result;
154173
}
155174

156175
} // namespace dfly::aggregate

src/server/search/aggregator.h

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,31 @@
1717

1818
namespace dfly::aggregate {
1919

20+
struct Reducer;
21+
2022
using Value = ::dfly::search::SortableValue;
21-
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline
2223

23-
struct PipelineResult {
24+
// DocValues sent through the pipeline
25+
// TODO: Replace DocValues with compact linear search map instead of hash map
26+
using DocValues = absl::flat_hash_map<std::string_view, Value>;
27+
28+
struct AggregationResult {
2429
// Values to be passed to the next step
25-
// TODO: Replace DocValues with compact linear search map instead of hash map
2630
std::vector<DocValues> values;
2731

2832
// Fields from values to be printed
29-
absl::flat_hash_set<std::string> fields_to_print;
33+
absl::flat_hash_set<std::string_view> fields_to_print;
34+
};
35+
36+
struct Aggregator {
37+
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
38+
void DoSort(std::string_view field, bool descending = false);
39+
void DoLimit(size_t offset, size_t num);
40+
41+
AggregationResult result;
3042
};
3143

32-
using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
44+
using AggregationStep = std::function<void(Aggregator*)>; // Group, Sort, etc.
3345

3446
// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
3547
// Extra clumsy for STL compatibility!
@@ -79,18 +91,17 @@ enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };
7991
Reducer::Func FindReducerFunc(ReducerFunc name);
8092

8193
// Make `GROUPBY [fields...]` with REDUCE step
82-
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
83-
std::vector<Reducer> reducers);
94+
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);
8495

8596
// Make `SORTBY field [DESC]` step
86-
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
97+
AggregationStep MakeSortStep(std::string field, bool descending = false);
8798

8899
// Make `LIMIT offset num` step
89-
PipelineStep MakeLimitStep(size_t offset, size_t num);
100+
AggregationStep MakeLimitStep(size_t offset, size_t num);
90101

91102
// Process values with given steps
92-
PipelineResult Process(std::vector<DocValues> values,
93-
absl::Span<const std::string_view> fields_to_print,
94-
absl::Span<const PipelineStep> steps);
103+
AggregationResult Process(std::vector<DocValues> values,
104+
absl::Span<const std::string_view> fields_to_print,
105+
absl::Span<const AggregationStep> steps);
95106

96107
} // namespace dfly::aggregate

src/server/search/aggregator_test.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ namespace dfly::aggregate {
1010

1111
using namespace std::string_literals;
1212

13+
using StepsList = std::vector<AggregationStep>;
14+
1315
TEST(AggregatorTest, Sort) {
1416
std::vector<DocValues> values = {
1517
DocValues{{"a", 1.0}},
1618
DocValues{{"a", 0.5}},
1719
DocValues{{"a", 1.5}},
1820
};
19-
PipelineStep steps[] = {MakeSortStep("a", false)};
21+
StepsList steps = {MakeSortStep("a", false)};
2022

2123
auto result = Process(values, {"a"}, steps);
2224

@@ -32,7 +34,8 @@ TEST(AggregatorTest, Limit) {
3234
DocValues{{"i", 3.0}},
3335
DocValues{{"i", 4.0}},
3436
};
35-
PipelineStep steps[] = {MakeLimitStep(1, 2)};
37+
38+
StepsList steps = {MakeLimitStep(1, 2)};
3639

3740
auto result = Process(values, {"i"}, steps);
3841

@@ -49,8 +52,8 @@ TEST(AggregatorTest, SimpleGroup) {
4952
DocValues{{"i", 4.0}, {"tag", "even"}},
5053
};
5154

52-
std::string_view fields[] = {"tag"};
53-
PipelineStep steps[] = {MakeGroupStep(fields, {})};
55+
std::vector<std::string> fields = {"tag"};
56+
StepsList steps = {MakeGroupStep(std::move(fields), {})};
5457

5558
auto result = Process(values, {"i", "tag"}, steps);
5659
EXPECT_EQ(result.values.size(), 2);
@@ -72,13 +75,14 @@ TEST(AggregatorTest, GroupWithReduce) {
7275
});
7376
}
7477

75-
std::string_view fields[] = {"tag"};
78+
std::vector<std::string> fields = {"tag"};
7679
std::vector<Reducer> reducers = {
7780
Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)},
7881
Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)},
7982
Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)},
8083
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
81-
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
84+
85+
StepsList steps = {MakeGroupStep(std::move(fields), std::move(reducers))};
8286

8387
auto result = Process(values, {"i", "half-i", "tag"}, steps);
8488
EXPECT_EQ(result.values.size(), 2);

src/server/search/doc_index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ struct AggregateParams {
168168
search::QueryParams params;
169169

170170
std::optional<SearchFieldsList> load_fields;
171-
std::vector<aggregate::PipelineStep> steps;
171+
std::vector<aggregate::AggregationStep> steps;
172172
};
173173

174174
// Stores basic info about a document index.

0 commit comments

Comments
 (0)