Skip to content

Commit 52092f0

Browse files
author
babenko
committed
Fix unbounded recursion in T[Cancelable]BoundedConcurrencyRunner
commit_hash:59473e0eed6e5a362e144c261c9db74a7bc0d2cc
1 parent 8dff5c1 commit 52092f0

File tree

2 files changed

+86
-35
lines changed

2 files changed

+86
-35
lines changed

yt/yt/core/actions/future-inl.h

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2461,7 +2461,7 @@ class TCancelableBoundedConcurrencyRunner
24612461
, ConcurrencyLimit_(concurrencyLimit)
24622462
, Futures_(Callbacks_.size(), VoidFuture)
24632463
, Results_(Callbacks_.size())
2464-
, CurrentIndex_(std::min<int>(ConcurrencyLimit_, ssize(Callbacks_)))
2464+
, CurrentIndex_(std::min<int>(ConcurrencyLimit_, std::ssize(Callbacks_)))
24652465
, FailOnFirstError_(failOnError)
24662466
{ }
24672467

@@ -2499,56 +2499,70 @@ class TCancelableBoundedConcurrencyRunner
24992499

25002500
void RunCallback(int index)
25012501
{
2502-
auto future = Callbacks_[index]();
2502+
// ORD-2002: Avoid calling OnResult directly to prevent unbounded
2503+
// recursive chains like RunCallback -> OnResult -> RunCallback...
2504+
while (true) {
2505+
auto future = Callbacks_[index]();
2506+
if (!future.IsSet()) {
2507+
{
2508+
auto guard = Guard(SpinLock_);
2509+
if (Error_) {
2510+
guard.Release();
2511+
future.Cancel(*Error_);
2512+
return;
2513+
}
25032514

2504-
if (future.IsSet()) {
2505-
OnResult(index, std::move(future.Get()));
2506-
return;
2507-
}
2515+
Futures_[index] = future.template As<void>();
2516+
}
25082517

2509-
{
2510-
auto guard = Guard(SpinLock_);
2511-
if (Error_) {
2512-
guard.Release();
2513-
future.Cancel(*Error_);
2514-
return;
2518+
future.Subscribe(
2519+
BIND_NO_PROPAGATE(&TCancelableBoundedConcurrencyRunner::OnResult, MakeStrong(this), index)
2520+
// NB: Sync invoker protects from unbounded recursion.
2521+
.Via(GetSyncInvoker()));
2522+
break;
25152523
}
25162524

2517-
Futures_[index] = future.template As<void>();
2518-
}
2525+
auto suggestedIndex = HandleResultAndSuggestNextIndex(index, std::move(future.Get()));
2526+
if (!suggestedIndex) {
2527+
break;
2528+
}
25192529

2520-
future.Subscribe(
2521-
BIND_NO_PROPAGATE(&TCancelableBoundedConcurrencyRunner::OnResult, MakeStrong(this), index)
2522-
// NB: Sync invoker protects from unbounded recursion.
2523-
.Via(GetSyncInvoker()));
2530+
index = *suggestedIndex;
2531+
}
25242532
}
25252533

2526-
void OnResult(int index, const NYT::TErrorOr<T>& result)
2534+
[[nodiscard]]
2535+
std::optional<int> HandleResultAndSuggestNextIndex(int index, const TErrorOr<T>& result)
25272536
{
25282537
if (FailOnFirstError_ && !result.IsOK()) {
25292538
OnError(result);
2530-
return;
2539+
return std::nullopt;
25312540
}
25322541

25332542
int newIndex;
25342543
int finishedCount;
25352544
{
25362545
auto guard = Guard(SpinLock_);
25372546
if (Error_) {
2538-
return;
2547+
return std::nullopt;
25392548
}
25402549

25412550
newIndex = CurrentIndex_++;
25422551
finishedCount = ++FinishedCount_;
25432552
Results_[index] = result;
25442553
}
25452554

2546-
if (finishedCount == ssize(Callbacks_)) {
2555+
if (finishedCount == std::ssize(Callbacks_)) {
25472556
Promise_.TrySet(Results_);
25482557
}
25492558

2550-
if (newIndex < ssize(Callbacks_)) {
2551-
RunCallback(newIndex);
2559+
return newIndex < std::ssize(Callbacks_) ? std::optional(newIndex) : std::nullopt;
2560+
}
2561+
2562+
void OnResult(int index, const TErrorOr<T>& result)
2563+
{
2564+
if (auto suggestedIndex = HandleResultAndSuggestNextIndex(index, result)) {
2565+
RunCallback(*suggestedIndex);
25522566
}
25532567
}
25542568

@@ -2616,26 +2630,44 @@ class TBoundedConcurrencyRunner
26162630

26172631
void RunCallback(int index)
26182632
{
2619-
auto future = Callbacks_[index]();
2620-
if (future.IsSet()) {
2621-
OnResult(index, future.Get());
2622-
} else {
2623-
future.Subscribe(
2624-
BIND_NO_PROPAGATE(&TBoundedConcurrencyRunner::OnResult, MakeStrong(this), index));
2633+
// ORD-2002: Avoid calling OnResult directly to prevent unbounded
2634+
// recursive chains like RunCallback -> OnResult -> RunCallback...
2635+
while (true) {
2636+
auto future = Callbacks_[index]();
2637+
if (!future.IsSet()) {
2638+
future.Subscribe(
2639+
BIND_NO_PROPAGATE(&TBoundedConcurrencyRunner::OnResult, MakeStrong(this), index)
2640+
// NB: Sync invoker protects from unbounded recursion.
2641+
.Via(GetSyncInvoker()));
2642+
break;
2643+
}
2644+
2645+
auto suggestedIndex = HandleResultAndSuggestNextIndex(index, future.Get());
2646+
if (!suggestedIndex) {
2647+
break;
2648+
}
2649+
2650+
index = *suggestedIndex;
26252651
}
26262652
}
26272653

2628-
void OnResult(int index, const NYT::TErrorOr<T>& result)
2654+
[[nodiscard]]
2655+
std::optional<int> HandleResultAndSuggestNextIndex(int index, const TErrorOr<T>& result)
26292656
{
26302657
Results_[index] = result;
26312658

26322659
int newIndex = CurrentIndex_++;
2633-
if (newIndex < static_cast<ssize_t>(Callbacks_.size())) {
2634-
RunCallback(newIndex);
2660+
if (++FinishedCount_ == std::ssize(Callbacks_)) {
2661+
Promise_.Set(Results_);
26352662
}
26362663

2637-
if (++FinishedCount_ == static_cast<ssize_t>(Callbacks_.size())) {
2638-
Promise_.Set(Results_);
2664+
return newIndex < std::ssize(Callbacks_) ? std::optional(newIndex) : std::nullopt;
2665+
}
2666+
2667+
void OnResult(int index, const TErrorOr<T>& result)
2668+
{
2669+
if (auto suggestedIndex = HandleResultAndSuggestNextIndex(index, result)) {
2670+
RunCallback(*suggestedIndex);
26392671
}
26402672
}
26412673
};

yt/yt/core/actions/unittests/actions_ut.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,25 @@ TEST(TCancelableRunWithBoundedConcurrencyTest, Cancelation)
9494
EXPECT_EQ(canceledCount, 4);
9595
}
9696

97+
TEST(TCancelableRunWithBoundedConcurrencyTest, RecurseRunner)
98+
{
99+
auto threadPool = CreateThreadPool(4, "ThreadPool");
100+
101+
std::vector<TCallback<TFuture<void>()>> callbacks;
102+
for (int i = 0; i < 50'000; ++i) {
103+
callbacks.push_back(BIND([] {
104+
return VoidFuture;
105+
}));
106+
}
107+
108+
auto future = CancelableRunWithBoundedConcurrency<void>(
109+
std::move(callbacks),
110+
/*concurrencyLimit*/ 5);
111+
112+
WaitFor(future)
113+
.ThrowOnError();
114+
}
115+
97116
TEST(TAllSucceededBoundedConcurrencyTest, CancelOthers)
98117
{
99118
using TCounter = std::atomic<int>;

0 commit comments

Comments
 (0)