Skip to content

Commit aeeb625

Browse files
fix(search_family): Support multiple fields in SORTBY option in the FT.AGGREGATE command. SECOND PR (#4232)
fix(search_family): Support multiple fields in SORTBY option in the FT.AGGREGATE command fixes dragonfly#3631 Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
1 parent 3c7e312 commit aeeb625

File tree

5 files changed

+246
-25
lines changed

5 files changed

+246
-25
lines changed

src/server/search/aggregator.cc

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const
6565
}
6666
}
6767

68-
void Aggregator::DoSort(std::string_view field, bool descending) {
68+
void Aggregator::DoSort(const SortParams& sort_params) {
6969
/*
70-
Comparator for sorting DocValues by field.
70+
Comparator for sorting DocValues by fields.
7171
If some of the fields is not present in the DocValues, comparator returns:
7272
1. l_it == l.end() && r_it != r.end()
7373
asc -> false
@@ -80,22 +80,41 @@ void Aggregator::DoSort(std::string_view field, bool descending) {
8080
desc -> false
8181
*/
8282
auto comparator = [&](const DocValues& l, const DocValues& r) {
83-
auto l_it = l.find(field);
84-
auto r_it = r.find(field);
85-
86-
// If some of the values is not present
87-
if (l_it == l.end() || r_it == r.end()) {
88-
return l_it != l.end();
83+
for (const auto& [field, order] : sort_params.fields) {
84+
auto l_it = l.find(field);
85+
auto r_it = r.find(field);
86+
87+
// If some of the values is not present
88+
if (l_it == l.end() || r_it == r.end()) {
89+
if (l_it == l.end() && r_it == r.end()) {
90+
continue;
91+
}
92+
return l_it != l.end();
93+
}
94+
95+
const auto& lv = l_it->second;
96+
const auto& rv = r_it->second;
97+
if (lv == rv) {
98+
continue;
99+
}
100+
return order == SortParams::SortOrder::ASC ? lv < rv : lv > rv;
89101
}
90-
91-
auto& lv = l_it->second;
92-
auto& rv = r_it->second;
93-
return !descending ? lv < rv : lv > rv;
102+
return false;
94103
};
95104

96-
std::sort(result.values.begin(), result.values.end(), std::move(comparator));
105+
auto& values = result.values;
106+
if (sort_params.SortAll()) {
107+
std::sort(values.begin(), values.end(), comparator);
108+
} else {
109+
DCHECK_GE(sort_params.max, 0);
110+
const size_t limit = std::min(values.size(), size_t(sort_params.max));
111+
std::partial_sort(values.begin(), values.begin() + limit, values.end(), comparator);
112+
values.resize(limit);
113+
}
97114

98-
result.fields_to_print.insert(field);
115+
for (auto& field : sort_params.fields) {
116+
result.fields_to_print.insert(field.first);
117+
}
99118
}
100119

101120
void Aggregator::DoLimit(size_t offset, size_t num) {
@@ -152,10 +171,8 @@ AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reduc
152171
};
153172
}
154173

155-
AggregationStep MakeSortStep(std::string field, bool descending) {
156-
return [field = std::move(field), descending](Aggregator* aggregator) {
157-
aggregator->DoSort(field, descending);
158-
};
174+
AggregationStep MakeSortStep(SortParams sort_params) {
175+
return [params = std::move(sort_params)](Aggregator* aggregator) { aggregator->DoSort(params); };
159176
}
160177

161178
AggregationStep MakeLimitStep(size_t offset, size_t num) {

src/server/search/aggregator.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,27 @@ struct AggregationResult {
3333
absl::flat_hash_set<std::string_view> fields_to_print;
3434
};
3535

36+
struct SortParams {
37+
enum class SortOrder { ASC, DESC };
38+
39+
constexpr static int64_t kSortAll = -1;
40+
41+
bool SortAll() const {
42+
return max == kSortAll;
43+
}
44+
45+
/* Fields to sort by. If multiple fields are provided, sorting works hierarchically:
46+
- First, the i-th field is compared.
47+
- If the i-th field values are equal, the (i + 1)-th field is compared, and so on. */
48+
absl::InlinedVector<std::pair<std::string, SortOrder>, 2> fields;
49+
/* Max number of elements to include in the sorted result.
50+
If set, only the first [max] elements are fully sorted using partial_sort. */
51+
int64_t max = kSortAll;
52+
};
53+
3654
struct Aggregator {
3755
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
38-
void DoSort(std::string_view field, bool descending = false);
56+
void DoSort(const SortParams& sort_params);
3957
void DoLimit(size_t offset, size_t num);
4058

4159
AggregationResult result;
@@ -94,7 +112,7 @@ Reducer::Func FindReducerFunc(ReducerFunc name);
94112
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);
95113

96114
// Make `SORTBY field [DESC]` step
97-
AggregationStep MakeSortStep(std::string field, bool descending = false);
115+
AggregationStep MakeSortStep(SortParams sort_params);
98116

99117
// Make `LIMIT offset num` step
100118
AggregationStep MakeLimitStep(size_t offset, size_t num);

src/server/search/aggregator_test.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ TEST(AggregatorTest, Sort) {
1818
DocValues{{"a", 0.5}},
1919
DocValues{{"a", 1.5}},
2020
};
21-
StepsList steps = {MakeSortStep("a", false)};
21+
22+
SortParams params;
23+
params.fields.emplace_back("a", SortParams::SortOrder::ASC);
24+
StepsList steps = {MakeSortStep(std::move(params))};
2225

2326
auto result = Process(values, {"a"}, steps);
2427

src/server/search/search_family.cc

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,42 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
306306
return params;
307307
}
308308

309+
std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* parser) {
310+
using SordOrder = aggregate::SortParams::SortOrder;
311+
312+
size_t strings_num = parser->Next<size_t>();
313+
314+
aggregate::SortParams sort_params;
315+
sort_params.fields.reserve(strings_num / 2);
316+
317+
while (parser->HasNext() && strings_num > 0) {
318+
// TODO: Throw an error if the field has no '@' sign at the beginning
319+
std::string_view parsed_field = ParseFieldWithAtSign(parser);
320+
strings_num--;
321+
322+
SordOrder sord_order = SordOrder::ASC;
323+
if (strings_num > 0) {
324+
auto order = parser->TryMapNext("ASC", SordOrder::ASC, "DESC", SordOrder::DESC);
325+
if (order) {
326+
sord_order = order.value();
327+
strings_num--;
328+
}
329+
}
330+
331+
sort_params.fields.emplace_back(parsed_field, sord_order);
332+
}
333+
334+
if (strings_num) {
335+
return std::nullopt;
336+
}
337+
338+
if (parser->Check("MAX")) {
339+
sort_params.max = parser->Next<size_t>();
340+
}
341+
342+
return sort_params;
343+
}
344+
309345
optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
310346
SinkReplyBuilder* builder) {
311347
AggregateParams params;
@@ -372,11 +408,13 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
372408

373409
// SORTBY nargs
374410
if (parser.Check("SORTBY")) {
375-
parser.ExpectTag("1");
376-
string_view field = parser.Next();
377-
bool desc = bool(parser.Check("DESC"));
411+
auto sort_params = ParseAggregatorSortParams(&parser);
412+
if (!sort_params) {
413+
builder->SendError("bad arguments for SORTBY: specified invalid number of strings");
414+
return nullopt;
415+
}
378416

379-
params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc));
417+
params.steps.push_back(aggregate::MakeSortStep(std::move(sort_params).value()));
380418
continue;
381419
}
382420

src/server/search/search_family_test.cc

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,4 +1680,149 @@ TEST_F(SearchFamilyTest, AggregateResultFields) {
16801680
IsMap(), IsMap()));
16811681
}
16821682

1683+
TEST_F(SearchFamilyTest, AggregateSortByJson) {
1684+
Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"});
1685+
Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"});
1686+
Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"});
1687+
Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"});
1688+
Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"});
1689+
Run({"JSON.SET", "j6", "$", R"({"name": "sixth", "number": 300, "group": "first"})"});
1690+
Run({"JSON.SET", "j7", "$", R"({"name": "seventh", "number": 400, "group": "second"})"});
1691+
Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"});
1692+
Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"});
1693+
1694+
Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
1695+
"AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});
1696+
1697+
// Test sorting by name (DESC) and number (ASC)
1698+
auto resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "DESC", "@number", "ASC"});
1699+
EXPECT_THAT(resp, IsUnordArrayWithSize(
1700+
IsMap("name", "\"third\"", "number", "300"),
1701+
IsMap("name", "\"sixth\"", "number", "300"),
1702+
IsMap("name", "\"seventh\"", "number", "400"),
1703+
IsMap("name", "\"second\"", "number", "800"), IsMap("name", "\"ninth\""),
1704+
IsMap("name", "\"fourth\"", "number", "400"),
1705+
IsMap("name", "\"first\"", "number", "1200"),
1706+
IsMap("name", "\"fifth\"", "number", "900"), IsMap("name", "\"eighth\"")));
1707+
1708+
// Test sorting by name (ASC) and number (DESC)
1709+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "ASC", "@number", "DESC"});
1710+
EXPECT_THAT(resp, IsUnordArrayWithSize(
1711+
IsMap("name", "\"eighth\""), IsMap("name", "\"fifth\"", "number", "900"),
1712+
IsMap("name", "\"first\"", "number", "1200"),
1713+
IsMap("name", "\"fourth\"", "number", "400"), IsMap("name", "\"ninth\""),
1714+
IsMap("name", "\"second\"", "number", "800"),
1715+
IsMap("name", "\"seventh\"", "number", "400"),
1716+
IsMap("name", "\"sixth\"", "number", "300"),
1717+
IsMap("name", "\"third\"", "number", "300")));
1718+
1719+
// Test sorting by group (ASC), number (DESC), and name
1720+
resp = Run(
1721+
{"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@group", "ASC", "@number", "DESC", "@name"});
1722+
EXPECT_THAT(resp, IsUnordArrayWithSize(
1723+
IsMap("group", "\"first\"", "number", "1200", "name", "\"first\""),
1724+
IsMap("group", "\"first\"", "number", "800", "name", "\"second\""),
1725+
IsMap("group", "\"first\"", "number", "300", "name", "\"sixth\""),
1726+
IsMap("group", "\"first\"", "number", "300", "name", "\"third\""),
1727+
IsMap("group", "\"first\"", "name", "\"eighth\""),
1728+
IsMap("group", "\"second\"", "number", "900", "name", "\"fifth\""),
1729+
IsMap("group", "\"second\"", "number", "400", "name", "\"fourth\""),
1730+
IsMap("group", "\"second\"", "number", "400", "name", "\"seventh\""),
1731+
IsMap("group", "\"second\"", "name", "\"ninth\"")));
1732+
1733+
// Test sorting by number (ASC), group (DESC), and name
1734+
resp = Run(
1735+
{"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@number", "ASC", "@group", "DESC", "@name"});
1736+
EXPECT_THAT(resp, IsUnordArrayWithSize(
1737+
IsMap("number", "300", "group", "\"first\"", "name", "\"sixth\""),
1738+
IsMap("number", "300", "group", "\"first\"", "name", "\"third\""),
1739+
IsMap("number", "400", "group", "\"second\"", "name", "\"fourth\""),
1740+
IsMap("number", "400", "group", "\"second\"", "name", "\"seventh\""),
1741+
IsMap("number", "800", "group", "\"first\"", "name", "\"second\""),
1742+
IsMap("number", "900", "group", "\"second\"", "name", "\"fifth\""),
1743+
IsMap("number", "1200", "group", "\"first\"", "name", "\"first\""),
1744+
IsMap("group", "\"second\"", "name", "\"ninth\""),
1745+
IsMap("group", "\"first\"", "name", "\"eighth\"")));
1746+
1747+
// Test sorting with MAX 3
1748+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "3"});
1749+
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
1750+
IsMap("number", "400")));
1751+
1752+
// Test sorting with MAX 3
1753+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "2", "@number", "DESC", "MAX", "3"});
1754+
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "1200"), IsMap("number", "900"),
1755+
IsMap("number", "800")));
1756+
1757+
// Test sorting by number (ASC) with MAX 999
1758+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "999"});
1759+
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
1760+
IsMap("number", "400"), IsMap("number", "400"),
1761+
IsMap("number", "800"), IsMap("number", "900"),
1762+
IsMap("number", "1200"), IsMap(), IsMap()));
1763+
1764+
// Test sorting by name and number (DESC)
1765+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "3", "@name", "@number", "DESC"});
1766+
EXPECT_THAT(resp, IsUnordArrayWithSize(
1767+
IsMap("name", "\"eighth\""), IsMap("name", "\"fifth\"", "number", "900"),
1768+
IsMap("name", "\"first\"", "number", "1200"),
1769+
IsMap("name", "\"fourth\"", "number", "400"), IsMap("name", "\"ninth\""),
1770+
IsMap("name", "\"second\"", "number", "800"),
1771+
IsMap("name", "\"seventh\"", "number", "400"),
1772+
IsMap("name", "\"sixth\"", "number", "300"),
1773+
IsMap("name", "\"third\"", "number", "300")));
1774+
1775+
// Test SORTBY with MAX, GROUPBY, and REDUCE COUNT
1776+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "1",
1777+
"@number", "REDUCE", "COUNT", "0", "AS", "count"});
1778+
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "900", "count", "1"),
1779+
IsMap("number", ArgType(RespExpr::NIL), "count", "1"),
1780+
IsMap("number", "1200", "count", "1")));
1781+
1782+
// Test SORTBY with MAX, GROUPBY (0 fields), and REDUCE COUNT
1783+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "0",
1784+
"REDUCE", "COUNT", "0", "AS", "count"});
1785+
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("count", "3")));
1786+
}
1787+
1788+
TEST_F(SearchFamilyTest, AggregateSortByParsingErrors) {
1789+
Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"});
1790+
Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"});
1791+
Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"});
1792+
Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"});
1793+
Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"});
1794+
Run({"JSON.SET", "j6", "$", R"({"name": "sixth", "number": 300, "group": "first"})"});
1795+
Run({"JSON.SET", "j7", "$", R"({"name": "seventh", "number": 400, "group": "second"})"});
1796+
Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"});
1797+
Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"});
1798+
1799+
Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
1800+
"AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});
1801+
1802+
// Test SORTBY with invalid argument count
1803+
auto resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "999", "@name", "@number", "DESC"});
1804+
EXPECT_THAT(resp, ErrArg("bad arguments for SORTBY: specified invalid number of strings"));
1805+
1806+
// Test SORTBY with negative argument count
1807+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "-3", "@name", "@number", "DESC"});
1808+
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
1809+
1810+
// Test MAX with invalid value
1811+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "-10"});
1812+
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
1813+
1814+
// Test MAX without a value
1815+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX"});
1816+
EXPECT_THAT(resp, ErrArg("syntax error"));
1817+
1818+
// Test SORTBY with a non-existing field
1819+
/* Temporary unsupported
1820+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@nonexistingfield"});
1821+
EXPECT_THAT(resp, ErrArg("Property `nonexistingfield` not loaded nor in schema")); */
1822+
1823+
// Test SORTBY with an invalid value
1824+
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "notvalue", "@name"});
1825+
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
1826+
}
1827+
16831828
} // namespace dfly

0 commit comments

Comments
 (0)