Skip to content

Commit a5de48c

Browse files
authored
Don't skip null values on COUNT(*) in column shards (#8024)
1 parent bb1c803 commit a5de48c

File tree

12 files changed

+317
-12
lines changed

12 files changed

+317
-12
lines changed

ydb/core/formats/arrow/custom_registry.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <AggregateFunctions/AggregateFunctionMinMaxAny.h>
1414
#include <AggregateFunctions/AggregateFunctionSum.h>
1515
#include <AggregateFunctions/AggregateFunctionAvg.h>
16+
#include <AggregateFunctions/AggregateFunctionNumRows.h>
1617
#endif
1718

1819
namespace cp = ::arrow::compute;
@@ -62,6 +63,10 @@ static void RegisterYdbCast(cp::FunctionRegistry* registry) {
6263
Y_ABORT_UNLESS(registry->AddFunction(std::make_shared<YdbCastMetaFunction>()).ok());
6364
}
6465

66+
static void RegisterCustomAggregates(cp::FunctionRegistry* registry) {
67+
Y_ABORT_UNLESS(registry->AddFunction(std::make_shared<TNumRows>(GetFunctionName(EAggregate::NumRows))).ok());
68+
}
69+
6570
static void RegisterHouseAggregates(cp::FunctionRegistry* registry) {
6671
#ifndef WIN32
6772
try {
@@ -71,6 +76,7 @@ static void RegisterHouseAggregates(cp::FunctionRegistry* registry) {
7176
Y_ABORT_UNLESS(registry->AddFunction(std::make_shared<CH::WrappedMax>(GetHouseFunctionName(EAggregate::Max))).ok());
7277
Y_ABORT_UNLESS(registry->AddFunction(std::make_shared<CH::WrappedSum>(GetHouseFunctionName(EAggregate::Sum))).ok());
7378
//Y_ABORT_UNLESS(registry->AddFunction(std::make_shared<CH::WrappedAvg>(GetHouseFunctionName(EAggregate::Avg))).ok());
79+
Y_ABORT_UNLESS(registry->AddFunction(std::make_shared<CH::WrappedNumRows>(GetHouseFunctionName(EAggregate::NumRows))).ok());
7480

7581
Y_ABORT_UNLESS(registry->AddFunction(std::make_shared<CH::ArrowGroupBy>(GetHouseGroupByName())).ok());
7682
} catch (const std::exception& /*ex*/) {
@@ -88,6 +94,7 @@ static std::unique_ptr<cp::FunctionRegistry> CreateCustomRegistry() {
8894
RegisterRound(registry.get());
8995
RegisterArithmetic(registry.get());
9096
RegisterYdbCast(registry.get());
97+
RegisterCustomAggregates(registry.get());
9198
RegisterHouseAggregates(registry.get());
9299
return registry;
93100
}

ydb/core/formats/arrow/program.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ const char * GetFunctionName(EAggregate op) {
398398
return "min_max";
399399
case EAggregate::Sum:
400400
return "sum";
401+
case EAggregate::NumRows:
402+
return "num_rows";
401403
#if 0 // TODO
402404
case EAggregate::Avg:
403405
return "mean";
@@ -424,6 +426,8 @@ const char * GetHouseFunctionName(EAggregate op) {
424426
case EAggregate::Avg:
425427
return "ch.avg";
426428
#endif
429+
case EAggregate::NumRows:
430+
return "ch.num_rows";
427431
default:
428432
break;
429433
}
@@ -448,6 +452,8 @@ CH::AggFunctionId GetHouseFunction(EAggregate op) {
448452
case EAggregate::Avg:
449453
return CH::AggFunctionId::AGG_AVG;
450454
#endif
455+
case EAggregate::NumRows:
456+
return CH::AggFunctionId::AGG_NUM_ROWS;
451457
default:
452458
break;
453459
}
@@ -678,6 +684,27 @@ IStepFunction<TAggregateAssign>::TPtr TAggregateAssign::GetFunction(arrow::compu
678684
return std::make_shared<TAggregateFunction>(ctx);
679685
}
680686

687+
TString TAggregateAssign::DebugString() const {
688+
TStringBuilder sb;
689+
sb << "{";
690+
if (Operation != EAggregate::Unspecified) {
691+
sb << "op=" << GetFunctionName(Operation) << ";";
692+
}
693+
if (Arguments.size()) {
694+
sb << "arguments=[";
695+
for (auto&& i : Arguments) {
696+
sb << i.DebugString() << ";";
697+
}
698+
sb << "];";
699+
}
700+
sb << "options=" << ScalarOpts.ToString() << ";";
701+
if (KernelFunction) {
702+
sb << "kernel=" << KernelFunction->name() << ";";
703+
}
704+
sb << "column=" << Column.DebugString() << ";";
705+
sb << "}";
706+
return sb;
707+
}
681708

682709
arrow::Status TProgramStep::ApplyAssignes(TDatumBatch& batch, arrow::compute::ExecContext* ctx) const {
683710
if (Assignes.empty()) {

ydb/core/formats/arrow/program.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ enum class EAggregate {
2121
Max = 4,
2222
Sum = 5,
2323
//Avg = 6,
24+
NumRows = 7,
2425
};
2526

2627
}
@@ -323,6 +324,7 @@ class TAggregateAssign {
323324
const arrow::compute::ScalarAggregateOptions* GetOptions() const { return &ScalarOpts; }
324325

325326
IStepFunction<TAggregateAssign>::TPtr GetFunction(arrow::compute::ExecContext* ctx) const;
327+
TString DebugString() const;
326328

327329
private:
328330
TColumnInfo Column;
@@ -372,10 +374,18 @@ class TProgramStep {
372374
sb << "];";
373375
}
374376
if (GroupBy.size()) {
375-
sb << "group_by_count=" << GroupBy.size() << "; ";
377+
sb << "group_by_assignes=[";
378+
for (auto&& i : GroupBy) {
379+
sb << i.DebugString() << ";";
380+
}
381+
sb << "];";
376382
}
377383
if (GroupByKeys.size()) {
378-
sb << "group_by_keys_count=" << GroupByKeys.size() << ";";
384+
sb << "group_by_keys=[";
385+
for (auto&& i : GroupByKeys) {
386+
sb << i.DebugString() << ";";
387+
}
388+
sb << "];";
379389
}
380390

381391
sb << "projections=[";

ydb/core/formats/arrow/ssa_program_optimizer.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "ssa_program_optimizer.h"
22

3+
#include <ydb/library/actors/core/log.h>
4+
35
namespace NKikimr::NSsa {
46

57
namespace {
@@ -11,7 +13,8 @@ void ReplaceCountAll(TProgram& program) {
1113
Y_ABORT_UNLESS(step);
1214

1315
for (auto& groupBy : step->MutableGroupBy()) {
14-
if (groupBy.GetOperation() == EAggregate::Count && groupBy.GetArguments().empty()) {
16+
if (groupBy.GetOperation() == EAggregate::NumRows) {
17+
AFL_VERIFY(groupBy.GetArguments().empty());
1518
if (step->GetGroupByKeys().size()) {
1619
groupBy.MutableArguments().push_back(step->GetGroupByKeys()[0]);
1720
} else {

ydb/core/kqp/ut/olap/kqp_olap_ut.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2761,6 +2761,96 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
27612761
}
27622762
}
27632763

2764+
Y_UNIT_TEST(CountWhereColumnIsNull) {
2765+
auto settings = TKikimrSettings()
2766+
.SetWithSampleTables(false);
2767+
TKikimrRunner kikimr(settings);
2768+
kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::TX_COLUMNSHARD_SCAN, NActors::NLog::PRI_DEBUG);
2769+
2770+
TLocalHelper(kikimr).CreateTestOlapTable();
2771+
2772+
WriteTestData(kikimr, "/Root/olapStore/olapTable", 0, 1000000, 300, true);
2773+
2774+
auto client = kikimr.GetTableClient();
2775+
2776+
Tests::NCommon::TLoggerInit(kikimr).Initialize();
2777+
2778+
{
2779+
auto it = client.StreamExecuteScanQuery(R"(
2780+
--!syntax_v1
2781+
2782+
SELECT COUNT(*), COUNT(level)
2783+
FROM `/Root/olapStore/olapTable`
2784+
WHERE level IS NULL
2785+
)").GetValueSync();
2786+
2787+
UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString());
2788+
TString result = StreamResultToYson(it);
2789+
Cout << result << Endl;
2790+
CompareYson("[[100u;0u]]", result);
2791+
}
2792+
2793+
{
2794+
auto it = client.StreamExecuteScanQuery(R"(
2795+
--!syntax_v1
2796+
2797+
SELECT COUNT(*), COUNT(level)
2798+
FROM `/Root/olapStore/olapTable`
2799+
WHERE level IS NULL AND uid IS NOT NULL
2800+
)").GetValueSync();
2801+
2802+
UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString());
2803+
TString result = StreamResultToYson(it);
2804+
Cout << result << Endl;
2805+
CompareYson("[[100u;0u]]", result);
2806+
}
2807+
2808+
{
2809+
auto it = client.StreamExecuteScanQuery(R"(
2810+
--!syntax_v1
2811+
2812+
SELECT COUNT(*), COUNT(level)
2813+
FROM `/Root/olapStore/olapTable`
2814+
WHERE level IS NULL
2815+
GROUP BY level
2816+
)").GetValueSync();
2817+
2818+
UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString());
2819+
TString result = StreamResultToYson(it);
2820+
Cout << result << Endl;
2821+
CompareYson("[[100u;0u]]", result);
2822+
}
2823+
}
2824+
2825+
Y_UNIT_TEST(SimpleCount) {
2826+
auto settings = TKikimrSettings()
2827+
.SetWithSampleTables(false);
2828+
TKikimrRunner kikimr(settings);
2829+
kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::TX_COLUMNSHARD_SCAN, NActors::NLog::PRI_DEBUG);
2830+
2831+
TLocalHelper(kikimr).CreateTestOlapTable();
2832+
2833+
WriteTestData(kikimr, "/Root/olapStore/olapTable", 0, 1000000, 300, true);
2834+
2835+
auto client = kikimr.GetTableClient();
2836+
2837+
Tests::NCommon::TLoggerInit(kikimr).Initialize();
2838+
2839+
{
2840+
auto it = client.StreamExecuteScanQuery(R"(
2841+
--!syntax_v1
2842+
2843+
SELECT COUNT(level)
2844+
FROM `/Root/olapStore/olapTable`
2845+
WHERE StartsWith(uid, "uid_")
2846+
)").GetValueSync();
2847+
2848+
UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString());
2849+
TString result = StreamResultToYson(it);
2850+
Cout << result << Endl;
2851+
CompareYson("[[200u]]", result);
2852+
}
2853+
}
27642854
}
27652855

27662856
}

ydb/core/tx/columnshard/engines/ut/ut_program.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,4 +852,64 @@ Y_UNIT_TEST_SUITE(TestProgram) {
852852
auto expected = result.BuildArrow();
853853
UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString());
854854
}
855+
856+
Y_UNIT_TEST(CountWithNulls) {
857+
TIndexInfo indexInfo = BuildTableInfo(testColumns, testKey);
858+
;
859+
NReader::NPlain::TIndexColumnResolver columnResolver(indexInfo);
860+
861+
NKikimrSSA::TProgram programProto;
862+
{
863+
auto* command = programProto.AddCommand();
864+
auto* functionProto = command->MutableAssign()->MutableFunction();
865+
auto* column = command->MutableAssign()->MutableColumn();
866+
column->SetName("0");
867+
auto* funcArg = functionProto->AddArguments();
868+
funcArg->SetName("uid");
869+
functionProto->SetId(NKikimrSSA::TProgram::TAssignment::EFunction::TProgram_TAssignment_EFunction_FUNC_IS_NULL);
870+
}
871+
{
872+
auto* command = programProto.AddCommand();
873+
auto* filter = command->MutableFilter();
874+
auto* predicate = filter->MutablePredicate();
875+
predicate->SetName("0");
876+
}
877+
{
878+
auto* command = programProto.AddCommand();
879+
auto* groupBy = command->MutableGroupBy();
880+
auto* aggregate = groupBy->AddAggregates();
881+
aggregate->MutableFunction()->SetId(static_cast<ui32>(NArrow::EAggregate::Count));
882+
aggregate->MutableColumn()->SetName("1");
883+
}
884+
{
885+
auto* command = programProto.AddCommand();
886+
auto* projectionProto = command->MutableProjection();
887+
auto* column = projectionProto->AddColumns();
888+
column->SetName("1");
889+
}
890+
const auto programSerialized = SerializeProgram(programProto);
891+
892+
TProgramContainer program;
893+
TString errors;
894+
UNIT_ASSERT_C(
895+
program.Init(columnResolver, NKikimrSchemeOp::EOlapProgramType::OLAP_PROGRAM_SSA_PROGRAM_WITH_PARAMETERS, programSerialized, errors),
896+
errors);
897+
898+
TTableUpdatesBuilder updates(NArrow::MakeArrowSchema({ std::make_pair("uid", TTypeInfo(NTypeIds::Utf8)) }));
899+
updates.AddRow().Add("a");
900+
updates.AddRow().AddNull();
901+
updates.AddRow().Add("bbb");
902+
updates.AddRow().AddNull();
903+
updates.AddRow().AddNull();
904+
905+
auto batch = updates.BuildArrow();
906+
auto res = program.ApplyProgram(batch);
907+
UNIT_ASSERT_C(res.ok(), res.ToString());
908+
909+
TTableUpdatesBuilder result(NArrow::MakeArrowSchema({ std::make_pair("1", TTypeInfo(NTypeIds::Uint64)) }));
910+
result.AddRow().Add<uint64_t>(3);
911+
912+
auto expected = result.BuildArrow();
913+
UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected->ToString());
914+
}
855915
}

ydb/core/tx/program/program.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ NSsa::TAggregateAssign TProgramBuilder::MakeAggregate(const NSsa::TColumnInfo& n
330330
}
331331
} else if (func.ArgumentsSize() == 0 && func.GetId() == TId::AGG_COUNT) {
332332
// COUNT(*) case
333-
return TAggregateAssign(name, EAggregate::Count);
333+
return TAggregateAssign(name, EAggregate::NumRows);
334334
}
335335
return TAggregateAssign(name); // !ok()
336336
}
@@ -483,7 +483,7 @@ bool TProgramContainer::Init(const IColumnResolver& columnResolver, const NKikim
483483
if (IS_DEBUG_LOG_ENABLED(NKikimrServices::TX_COLUMNSHARD)) {
484484
TString out;
485485
::google::protobuf::TextFormat::PrintToString(programProto, &out);
486-
AFL_DEBUG(NKikimrServices::TX_COLUMNSHARD)("program", out);
486+
AFL_DEBUG(NKikimrServices::TX_COLUMNSHARD)("event", "parse_program")("program", out);
487487
}
488488

489489
if (programProto.HasKernels()) {
@@ -496,6 +496,7 @@ bool TProgramContainer::Init(const IColumnResolver& columnResolver, const NKikim
496496
}
497497
return false;
498498
}
499+
AFL_DEBUG(NKikimrServices::TX_COLUMNSHARD)("event", "program_parsed")("result", DebugString());
499500

500501
return true;
501502
}

0 commit comments

Comments
 (0)