Skip to content

Commit c320ff3

Browse files
committed
ListSample/ListSampleN/ListShuffle implementation
commit_hash:987b10b398caa89eee8b94b33f9ea1dc74197223
1 parent 00bc077 commit c320ff3

File tree

21 files changed

+705
-0
lines changed

21 files changed

+705
-0
lines changed

yql/essentials/core/common_opt/yql_co_simple1.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3675,6 +3675,28 @@ bool IsEarlyExpandOfSkipNullAllowed(const TOptimizeContext& optCtx) {
36753675
return optCtx.Types->OptimizerFlags.contains(skipNullFlags);
36763676
}
36773677

3678+
// Replaces a callable with an application of its library implementation.
//
// The implementation is the symbol named "<callable name>Impl" exported by the
// /lib/yql/core.yql module. It is deep-copied into `ctx`, and the i-th child of
// `node` is handed to the resulting Apply via With(i, ...).
// Both the module and the symbol are required to exist (YQL_ENSURE).
TExprNode::TPtr ReplaceFuncWithImpl(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
    const auto exportsPtr = optCtx.Types->Modules->GetModule("/lib/yql/core.yql");
    YQL_ENSURE(exportsPtr);

    const TString implName = TString(node->Content()) + "Impl";
    const auto& exports = exportsPtr->Symbols();
    const auto ex = exports.find(implName);
    YQL_ENSURE(exports.cend() != ex);

    // The exported lambda lives in the module's own expression context, so it is copied into ours.
    TNodeOnNodeOwnedMap clonedNodes;
    auto implLambda = ctx.DeepCopy(*ex->second, exportsPtr->ExprCtx(), clonedNodes, true, false);

    YQL_CLOG(DEBUG, Core) << "Replace " << node->Content() << " with implementation";

    const auto passChildren = [&node](TExprNodeReplaceBuilder& builder) -> TExprNodeReplaceBuilder& {
        const size_t childCount = node->ChildrenSize();
        for (size_t i = 0; i < childCount; ++i) {
            builder.With(i, node->ChildPtr(i));
        }
        return builder;
    };

    return ctx.Builder(node->Pos())
        .Apply(implLambda)
            .Do(passChildren)
        .Seal()
        .Build();
}
3699+
36783700
} // namespace
36793701

36803702
void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) {
@@ -4897,6 +4919,65 @@ void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) {
48974919
return node;
48984920
};
48994921

4922+
// Optimizer shared by ListSample and ListSampleN (children: list, probability/count, dependsOn).
//  1. Optional list      -> Map over the list argument, re-issuing the same callable on the lambda parameter.
//  2. Optional 2nd arg   -> IfPresent over it; the untouched list is passed as the IfPresent fallback.
//  3. Otherwise          -> the callable is replaced by its "<Name>Impl" library lambda.
map["ListSample"] = map["ListSampleN"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
    const auto isOptional = [](const TExprNode& arg) {
        return ETypeAnnotationKind::Optional == arg.GetTypeAnn()->GetKind();
    };

    if (isOptional(node->Head())) {
        YQL_CLOG(DEBUG, Core) << "Handle optional list in " << node->Content();
        return ctx.Builder(node->Pos())
            .Callable("Map")
                .Add(0, node->HeadPtr())
                .Lambda(1)
                    .Param("list")
                    .Callable(node->Content())
                        .Arg(0, "list")
                        .Add(1, node->ChildPtr(1))
                        .Add(2, node->ChildPtr(2))
                    .Seal()
                .Seal()
            .Seal()
            .Build();
    }

    if (isOptional(*node->Child(1))) {
        YQL_CLOG(DEBUG, Core) << "Handle optional prob arg in " << node->Content();
        return ctx.Builder(node->Pos())
            .Callable("IfPresent")
                .Add(0, node->ChildPtr(1))
                .Lambda(1)
                    .Param("probArg")
                    .Callable(node->Content())
                        .Add(0, node->HeadPtr())
                        .Arg(1, "probArg")
                        .Add(2, node->ChildPtr(2))
                    .Seal()
                .Seal()
                .Add(2, node->HeadPtr())
            .Seal()
            .Build();
    }

    return ReplaceFuncWithImpl(node, ctx, optCtx);
};
4960+
4961+
// Optimizer for ListShuffle (children: list, dependsOn).
// An optional list is wrapped in Map over the first argument, re-issuing ListShuffle on the
// lambda parameter; otherwise the callable is replaced by its "<Name>Impl" library lambda
// through ReplaceFuncWithImpl.
map["ListShuffle"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
    if (node->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Optional) {
        // Only the list argument is inspected here, so the message names it
        // (same wording as the ListSample/ListSampleN optimizer).
        YQL_CLOG(DEBUG, Core) << "Handle optional list in " << node->Content();
        return ctx.Builder(node->Pos())
            .Callable("Map")
                .Add(0, node->HeadPtr())
                .Lambda(1)
                    .Param("list")
                    .Callable(node->Content())
                        .Arg(0, "list")
                        .Add(1, node->ChildPtr(1))
                    .Seal()
                .Seal()
            .Seal()
            .Build();
    }

    return ReplaceFuncWithImpl(node, ctx, optCtx);
};
4980+
49004981
map["OptionalReduce"] = std::bind(&RemoveOptionalReduceOverData, _1, _2);
49014982

49024983
map["Fold"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& /*optCtx*/) {

yql/essentials/core/type_ann/type_ann_core.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12594,6 +12594,9 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
1259412594
Functions["ListTopSort"] = &ListTopSortWrapper;
1259512595
Functions["ListTopSortAsc"] = &ListTopSortWrapper;
1259612596
Functions["ListTopSortDesc"] = &ListTopSortWrapper;
12597+
Functions["ListSample"] = &ListSampleWrapper;
12598+
Functions["ListSampleN"] = &ListSampleNWrapper;
12599+
Functions["ListShuffle"] = &ListShuffleWrapper;
1259712600

1259812601
Functions["ExpandMap"] = &ExpandMapWrapper;
1259912602
Functions["WideMap"] = &WideMapWrapper;

yql/essentials/core/type_ann/type_ann_list.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,115 @@ namespace {
15241524
return OptListWrapperImpl<1U>(input, output, ctx, "Collect");
15251525
}
15261526

1527+
// Shared type annotation for ListSample / ListSampleN.
//
// Arguments: (list, probArg[, dependsOn]).
//   * list            - a list, an empty list, or an optional of either; a Null or empty list
//                       makes the whole call collapse to the first argument (Repeat).
//   * probArg         - data of slot `probArgDataType` (optional allowed); a Null value also
//                       collapses the call to the first argument.
//   * dependsOn       - when omitted, a Null callable is appended and annotation is repeated.
// On success the node gets the type of its first argument.
IGraphTransformer::TStatus ListSampleWrapperCommon(const TExprNode::TPtr& input, TExprNode::TPtr& output, NUdf::EDataSlot probArgDataType, TContext& ctx) {
    using TStatus = IGraphTransformer::TStatus;

    if (!EnsureMinMaxArgsCount(*input, 2, 3, ctx.Expr)) {
        return TStatus::Error;
    }

    const auto& listArg = input->Head();
    if (IsNull(listArg)) {
        output = input->HeadPtr();
        return TStatus::Repeat;
    }

    if (!EnsureComputable(listArg, ctx.Expr)) {
        return TStatus::Error;
    }

    // Look through one level of Optional to find the actual list type.
    const auto* const declaredType = listArg.GetTypeAnn();
    const auto* const listType = ETypeAnnotationKind::Optional == declaredType->GetKind()
        ? declaredType->Cast<TOptionalExprType>()->GetItemType()
        : declaredType;

    switch (listType->GetKind()) {
        case ETypeAnnotationKind::List:
            break;
        case ETypeAnnotationKind::EmptyList:
            // Nothing to sample from: the result is the input itself.
            output = input->HeadPtr();
            return TStatus::Repeat;
        default:
            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(listArg.Pos()), TStringBuilder()
                << "Expected (empty) list or optional of (empty) list, but got: " << *declaredType));
            return TStatus::Error;
    }

    const auto& probArg = *input->Child(1);
    if (IsNull(probArg)) {
        output = input->HeadPtr();
        return TStatus::Repeat;
    }

    if (!EnsureSpecificDataType(probArg, probArgDataType, ctx.Expr, true)) {
        return TStatus::Error;
    }

    if (input->ChildrenSize() == 2) {
        // Normalize to the three-argument form by appending a Null dependsOn.
        auto extendedChildren = input->ChildrenList();
        extendedChildren.push_back(ctx.Expr.NewCallable(input->Pos(), "Null", {}));
        output = ctx.Expr.ChangeChildren(*input, std::move(extendedChildren));
        return TStatus::Repeat;
    }
    YQL_ENSURE(input->ChildrenSize() == 3);

    if (!EnsureComputable(*input->Child(2), ctx.Expr)) {
        return TStatus::Error;
    }

    input->SetTypeAnn(declaredType);
    return TStatus::Ok;
}
1581+
1582+
// Type annotation for ListSample: the common sampling checks with a Double second argument.
IGraphTransformer::TStatus ListSampleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
    constexpr auto probabilitySlot = NUdf::EDataSlot::Double;
    return ListSampleWrapperCommon(input, output, probabilitySlot, ctx);
}
1585+
1586+
// Type annotation for ListSampleN: the common sampling checks with a Uint64 second argument.
IGraphTransformer::TStatus ListSampleNWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
    constexpr auto countSlot = NUdf::EDataSlot::Uint64;
    return ListSampleWrapperCommon(input, output, countSlot, ctx);
}
1589+
1590+
// Type annotation for ListShuffle.
//
// Arguments: (list[, dependsOn]).
//   * list      - a list, an empty list, or an optional of either; a Null or empty list makes
//                 the whole call collapse to the first argument (Repeat).
//   * dependsOn - when omitted, a Null callable is appended and annotation is repeated.
// On success the node gets the type of its first argument.
IGraphTransformer::TStatus ListShuffleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
    using TStatus = IGraphTransformer::TStatus;

    if (!EnsureMinMaxArgsCount(*input, 1, 2, ctx.Expr)) {
        return TStatus::Error;
    }

    const auto& listArg = input->Head();
    if (IsNull(listArg)) {
        output = input->HeadPtr();
        return TStatus::Repeat;
    }

    if (!EnsureComputable(listArg, ctx.Expr)) {
        return TStatus::Error;
    }

    // Look through one level of Optional to find the actual list type.
    const auto* const declaredType = listArg.GetTypeAnn();
    const auto* const listType = ETypeAnnotationKind::Optional == declaredType->GetKind()
        ? declaredType->Cast<TOptionalExprType>()->GetItemType()
        : declaredType;

    switch (listType->GetKind()) {
        case ETypeAnnotationKind::List:
            break;
        case ETypeAnnotationKind::EmptyList:
            // Nothing to shuffle: the result is the input itself.
            output = input->HeadPtr();
            return TStatus::Repeat;
        default:
            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(listArg.Pos()), TStringBuilder()
                << "Expected (empty) list or optional of (empty) list, but got: " << *declaredType));
            return TStatus::Error;
    }

    if (input->ChildrenSize() == 1) {
        // Normalize to the two-argument form by appending a Null dependsOn.
        auto extendedChildren = input->ChildrenList();
        extendedChildren.push_back(ctx.Expr.NewCallable(input->Pos(), "Null", {}));
        output = ctx.Expr.ChangeChildren(*input, std::move(extendedChildren));
        return TStatus::Repeat;
    }
    YQL_ENSURE(input->ChildrenSize() == 2);

    if (!EnsureComputable(*input->Child(1), ctx.Expr)) {
        return TStatus::Error;
    }

    input->SetTypeAnn(declaredType);
    return TStatus::Ok;
}
1635+
15271636
IGraphTransformer::TStatus OptListFold1WrapperImpl(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx, TExprNode::TPtr&& updateLambda) {
15281637
if (IsNull(input->Head())) {
15291638
output = input->HeadPtr();

yql/essentials/core/type_ann/type_ann_list.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ namespace NTypeAnnImpl {
4141
IGraphTransformer::TStatus ListTopSortWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
4242
IGraphTransformer::TStatus ListExtractWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
4343
IGraphTransformer::TStatus ListCollectWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
44+
// Type annotation for ListSample: (list, probability[, dependsOn]); the second argument is checked as Double.
IGraphTransformer::TStatus ListSampleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
45+
// Type annotation for ListSampleN: same checks as ListSample, but the second argument is checked as Uint64.
IGraphTransformer::TStatus ListSampleNWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
46+
// Type annotation for ListShuffle: (list[, dependsOn]); the result type is the type of the first argument.
IGraphTransformer::TStatus ListShuffleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
4447
IGraphTransformer::TStatus FoldMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
4548
IGraphTransformer::TStatus Fold1MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
4649
IGraphTransformer::TStatus Chain1MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);

yql/essentials/mount/lib/yql/core.yql

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,56 @@ def signature(script, name):
479479
(lambda '() (Apply ListToTupleImpl list n)))
480480
))
481481

482+
# ListSample implementation: keeps every item of `list` for which a Random value is less
# than `probability`. The Random call depends on the item itself, on `probability` and on the
# caller-supplied `dependsOn`.
(let ListSampleImpl (lambda '(list probability dependsOn)
    (Filter list (lambda '(item)
        (< (Random (DependsOn '(item probability dependsOn))) probability)
    ))
))
485+
486+
# ListSampleN implementation: selects at most `count` items of `list`.
#
# The inner Fold puts the first `count` items into a Vector UDF resource. The outer Fold then
# walks the remaining items together with their positions (Enumerate + Skip): for each one it
# draws pos = RandomNumber % (position + 1) and, when pos < count, stores the item into slot
# `pos`; otherwise the vector is left as is. This is the reservoir sampling scheme.
# Coalesce supplies 0 when the modulo yields no value.
# NOTE(review): the inner Fold calls Vector.Emplace with index `count`, which presumably appends
# while the vector holds fewer than `count` items -- confirm against the Vector UDF.
(let ListSampleNImpl (lambda '(list count dependsOn) (block '(
    (let value_type (ListItemType (TypeOf list)))

    (let UdfVectorCreate (Udf 'Vector.Create (Void) (TupleType (TupleType value_type (DataType 'Uint64)) (StructType) value_type)))

    (let resource_type (TypeOf (Apply UdfVectorCreate (Uint32 '0))))

    (let UdfVectorEmplace (Udf 'Vector.Emplace (Void) (TupleType (TupleType resource_type (DataType 'Uint64) value_type) (StructType) value_type)))
    (let UdfVectorGetResult (Udf 'Vector.GetResult (Void) (TupleType (TupleType resource_type) (StructType) value_type)))

    (return (Apply UdfVectorGetResult (Fold
        (Skip (Enumerate list) count)
        (Fold
            (Take list count)
            (NamedApply UdfVectorCreate '(count) (AsStruct) (DependsOn '(list dependsOn)))
            (lambda '(x y) (Apply UdfVectorEmplace y count x))
        )
        (lambda '(x y) (block '(
            (let pos (Coalesce (% (RandomNumber (DependsOn '(x count dependsOn))) (+ (Nth x '0) (Uint64 '1))) (Uint64 '0)))
            (return (If (< pos count) (Apply UdfVectorEmplace y pos (Nth x '1)) y))
        )))
    )))
))))
510+
511+
# ListShuffle implementation: returns the items of `list` in a random order.
#
# Items are folded, together with their positions (Enumerate), into a Vector UDF resource.
# Each item is first stored at its own position and then swapped with the slot
# pos = RandomNumber % (position + 1). Coalesce supplies 0 when the modulo yields no value.
# NOTE(review): relies on Vector.Emplace appending when the index equals the current size --
# confirm against the Vector UDF.
(let ListShuffleImpl (lambda '(list dependsOn) (block '(
    (let value_type (ListItemType (TypeOf list)))

    (let UdfVectorCreate (Udf 'Vector.Create (Void) (TupleType (TupleType value_type (DataType 'Uint64)) (StructType) value_type)))

    (let resource_type (TypeOf (Apply UdfVectorCreate (Uint32 '0))))

    (let UdfVectorEmplace (Udf 'Vector.Emplace (Void) (TupleType (TupleType resource_type (DataType 'Uint64) value_type) (StructType) value_type)))
    (let UdfVectorSwap (Udf 'Vector.Swap (Void) (TupleType (TupleType resource_type (DataType 'Uint64) (DataType 'Uint64)) (StructType) value_type)))
    (let UdfVectorGetResult (Udf 'Vector.GetResult (Void) (TupleType (TupleType resource_type) (StructType) value_type)))

    (return (Apply UdfVectorGetResult (Fold
        (Enumerate list)
        (NamedApply UdfVectorCreate '((Uint32 '1)) (AsStruct) (DependsOn '(list dependsOn)))
        (lambda '(entry state) (block '(
            (let idx (Nth entry '0))
            (let item (Nth entry '1))
            (let pos (Coalesce (% (RandomNumber (DependsOn '(entry dependsOn))) (+ idx (Uint64 '1))) (Uint64 '0)))
            (return (Apply UdfVectorSwap (Apply UdfVectorEmplace state idx item) pos idx))
        )))
    )))
))))
531+
482532
(export Equals)
483533
(export Unequals)
484534
(export FindIndex)
@@ -516,4 +566,7 @@ def signature(script, name):
516566
(export ForceSpreadMembers)
517567
(export ListFromTuple)
518568
(export ListToTuple)
569+
(export ListSampleImpl)
570+
(export ListSampleNImpl)
571+
(export ListShuffleImpl)
519572
)

yql/essentials/sql/v1/builtin.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2916,6 +2916,9 @@ struct TBuiltinFuncData {
29162916
{"listtopsort", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListTopSort", 2, 3)},
29172917
{"listtopsortasc", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListTopSortAsc", 2, 3)},
29182918
{"listtopsortdesc", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListTopSortDesc", 2, 3)},
2919+
{"listsample", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListSample", 2, 3)},
2920+
{"listsamplen", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListSampleN", 2, 3)},
2921+
{"listshuffle", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListShuffle", 1, 2)},
29192922

29202923
// Dict builtins
29212924
{"dictlength", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("Length", 1, 1)},

yql/essentials/tests/common/test_framework/udfs_deps/ya.make

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ SET(
1919
yql/essentials/udfs/common/url_base
2020
yql/essentials/udfs/common/unicode_base
2121
yql/essentials/udfs/common/streaming
22+
yql/essentials/udfs/common/vector
2223
yql/essentials/udfs/examples/callables
2324
yql/essentials/udfs/examples/dicts
2425
yql/essentials/udfs/examples/dummylog

yql/essentials/tests/sql/sql2yql/canondata/result.json

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6453,6 +6453,27 @@
64536453
"uri": "https://{canondata_backend}/1784117/d56ae82ad9d30397a41490647be1bd2124718f98/resource.tar.gz#test_sql2yql.test_expr-list_replicate_fail_/sql.yql"
64546454
}
64556455
],
6456+
"test_sql2yql.test[expr-list_sample]": [
6457+
{
6458+
"checksum": "922f4c9c5a2fe848f40272dd15cfde42",
6459+
"size": 10843,
6460+
"uri": "https://{canondata_backend}/1924537/278b77accb7596bd976e3e218425469d4b97dcf9/resource.tar.gz#test_sql2yql.test_expr-list_sample_/sql.yql"
6461+
}
6462+
],
6463+
"test_sql2yql.test[expr-list_sample_n]": [
6464+
{
6465+
"checksum": "5ce08b8b61ef8b2863f931bc1b986679",
6466+
"size": 7573,
6467+
"uri": "https://{canondata_backend}/1924537/278b77accb7596bd976e3e218425469d4b97dcf9/resource.tar.gz#test_sql2yql.test_expr-list_sample_n_/sql.yql"
6468+
}
6469+
],
6470+
"test_sql2yql.test[expr-list_shuffle]": [
6471+
{
6472+
"checksum": "3cd4f632706daf9ac8962369e7d0eac3",
6473+
"size": 4413,
6474+
"uri": "https://{canondata_backend}/1777230/f0ec95d2b2a3a38fc99b00afc1f2d60d2b3e8548/resource.tar.gz#test_sql2yql.test_expr-list_shuffle_/sql.yql"
6475+
}
6476+
],
64566477
"test_sql2yql.test[expr-list_takeskipwhile]": [
64576478
{
64586479
"checksum": "827d6c45ccb33ccc641531600fa839ce",
@@ -26319,6 +26340,27 @@
2631926340
"uri": "https://{canondata_backend}/1880306/64654158d6bfb1289c66c626a8162239289559d0/resource.tar.gz#test_sql_format.test_expr-list_replicate_fail_/formatted.sql"
2632026341
}
2632126342
],
26343+
"test_sql_format.test[expr-list_sample]": [
26344+
{
26345+
"checksum": "a642f47aa5488ecfa6450c114a85903d",
26346+
"size": 1235,
26347+
"uri": "https://{canondata_backend}/1942525/0302d8428323e9211161c4db74348074ea0aab49/resource.tar.gz#test_sql_format.test_expr-list_sample_/formatted.sql"
26348+
}
26349+
],
26350+
"test_sql_format.test[expr-list_sample_n]": [
26351+
{
26352+
"checksum": "4b04a240db2a66eab919da4fbbf3cdea",
26353+
"size": 1128,
26354+
"uri": "https://{canondata_backend}/1942525/0302d8428323e9211161c4db74348074ea0aab49/resource.tar.gz#test_sql_format.test_expr-list_sample_n_/formatted.sql"
26355+
}
26356+
],
26357+
"test_sql_format.test[expr-list_shuffle]": [
26358+
{
26359+
"checksum": "73822288846e1fc180736baa4a9548c7",
26360+
"size": 612,
26361+
"uri": "https://{canondata_backend}/1942525/0302d8428323e9211161c4db74348074ea0aab49/resource.tar.gz#test_sql_format.test_expr-list_shuffle_/formatted.sql"
26362+
}
26363+
],
2632226364
"test_sql_format.test[expr-list_takeskipwhile]": [
2632326365
{
2632426366
"checksum": "fe413941b62655034d49cd2674f2c947",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
providers yt

0 commit comments

Comments
 (0)