Skip to content

Commit 28a80e6

Browse files
Fix TTbbLocalExecutor::GetWorkerThreadId. Enable test_fit_on_scipy_sparse_spmatrix on Windows.
commit_hash:c52cdc5529aff5cda1cbd11be2852647736ceb49
1 parent f7d173e commit 28a80e6

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

library/cpp/threading/local_executor/tbb_local_executor.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,21 @@ int NPar::TTbbLocalExecutor<RespectTls>::GetThreadCount() const noexcept {
1414

1515
template <bool RespectTls>
1616
int NPar::TTbbLocalExecutor<RespectTls>::GetWorkerThreadId() const noexcept {
17-
return TbbArena.execute([] {
18-
return tbb::this_task_arena::current_thread_index();
19-
});
17+
static thread_local int WorkerThreadId = -1;
18+
if (WorkerThreadId == -1) {
19+
// Can't rely on return value except checking that it is 'not_initialized' because of
20+
// "Since a thread may exit the arena at any time if it does not execute a task, the index of
21+
// a thread may change between any two tasks"
22+
// (https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onetbb/source/task_scheduler/task_arena/this_task_arena_ns#_CPPv4N3tbb15this_task_arena20current_thread_indexEv)
23+
const auto tbbThreadIndex = tbb::this_task_arena::current_thread_index();
24+
if (tbbThreadIndex == tbb::task_arena::not_initialized) {
25+
// This thread does not belong to TBB worker threads
26+
WorkerThreadId = 0;
27+
} else {
28+
WorkerThreadId = ++RegisteredThreadCounter;
29+
}
30+
}
31+
return WorkerThreadId;
2032
}
2133

2234
template <bool RespectTls>

library/cpp/threading/local_executor/tbb_local_executor.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,19 @@
99
#include <contrib/libs/tbb/include/tbb/task_arena.h>
1010
#include <contrib/libs/tbb/include/tbb/task_group.h>
1111

12+
#include <atomic>
13+
14+
1215
namespace NPar {
1316
template <bool RespectTls = false>
1417
class TTbbLocalExecutor final: public ILocalExecutor {
1518
public:
1619
TTbbLocalExecutor(int nThreads)
1720
: ILocalExecutor()
1821
, TbbArena(nThreads)
19-
, NumberOfTbbThreads(nThreads) {}
22+
, NumberOfTbbThreads(nThreads)
23+
, RegisteredThreadCounter(0)
24+
{}
2025
~TTbbLocalExecutor() noexcept override {}
2126

2227
virtual int GetWorkerThreadId() const noexcept override;
@@ -44,5 +49,7 @@ namespace NPar {
4449
mutable tbb::task_arena TbbArena;
4550
tbb::task_group Group;
4651
int NumberOfTbbThreads;
52+
53+
mutable std::atomic_int RegisteredThreadCounter;
4754
};
4855
}

0 commit comments

Comments
 (0)