Skip to content

Commit 134401d

Browse files
authored
[Offload] Move RPC server handling to a dedicated thread (llvm#112988)
Summary: Handling the RPC server requires running through list of jobs that the device has requested to be done. Currently this is handled by the thread that does the waiting for the kernel to finish. However, this is not sound on NVIDIA architectures and only works for async launches in the OpenMP model that uses helper threads. However, we also don't want to have this thread doing work unnnecessarily. For this reason we track the execution of kernels and cause the thread to sleep via a condition variable (usually backed by some kind of futex or other intelligent sleeping mechanism) so that the thread will be idle while no kernels are running.
1 parent c025b96 commit 134401d

File tree

8 files changed

+281
-88
lines changed

8 files changed

+281
-88
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -621,9 +621,9 @@ struct AMDGPUSignalTy {
621621
}
622622

623623
/// Wait until the signal gets a zero value.
624-
Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
624+
Error wait(const uint64_t ActiveTimeout = 0,
625625
GenericDeviceTy *Device = nullptr) const {
626-
if (ActiveTimeout && !RPCServer) {
626+
if (ActiveTimeout) {
627627
hsa_signal_value_t Got = 1;
628628
Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
629629
ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -632,14 +632,11 @@ struct AMDGPUSignalTy {
632632
}
633633

634634
// If there is an RPC device attached to this stream we run it as a server.
635-
uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
636-
auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
635+
uint64_t Timeout = UINT64_MAX;
636+
auto WaitState = HSA_WAIT_STATE_BLOCKED;
637637
while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
638-
Timeout, WaitState) != 0) {
639-
if (RPCServer && Device)
640-
if (auto Err = RPCServer->runServer(*Device))
641-
return Err;
642-
}
638+
Timeout, WaitState) != 0)
639+
;
643640
return Plugin::success();
644641
}
645642

@@ -1052,11 +1049,6 @@ struct AMDGPUStreamTy {
10521049
/// operation that was already finalized in a previous stream sycnhronize.
10531050
uint32_t SyncCycle;
10541051

1055-
/// A pointer associated with an RPC server running on the given device. If
1056-
/// RPC is not being used this will be a null pointer. Otherwise, this
1057-
/// indicates that an RPC server is expected to be run on this stream.
1058-
RPCServerTy *RPCServer;
1059-
10601052
/// Mutex to protect stream's management.
10611053
mutable std::mutex Mutex;
10621054

@@ -1236,9 +1228,6 @@ struct AMDGPUStreamTy {
12361228
/// Deinitialize the stream's signals.
12371229
Error deinit() { return Plugin::success(); }
12381230

1239-
/// Attach an RPC server to this stream.
1240-
void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
1241-
12421231
/// Push a asynchronous kernel to the stream. The kernel arguments must be
12431232
/// placed in a special allocation for kernel args and must keep alive until
12441233
/// the kernel finalizes. Once the kernel is finished, the stream will release
@@ -1266,10 +1255,30 @@ struct AMDGPUStreamTy {
12661255
if (auto Err = Slots[Curr].schedReleaseBuffer(KernelArgs, MemoryManager))
12671256
return Err;
12681257

1258+
// If we are running an RPC server we want to wake up the server thread
1259+
// whenever there is a kernel running and let it sleep otherwise.
1260+
if (Device.getRPCServer())
1261+
Device.Plugin.getRPCServer().Thread->notify();
1262+
12691263
// Push the kernel with the output signal and an input signal (optional)
1270-
return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
1271-
GroupSize, StackSize, OutputSignal,
1272-
InputSignal);
1264+
if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
1265+
NumBlocks, GroupSize, StackSize,
1266+
OutputSignal, InputSignal))
1267+
return Err;
1268+
1269+
// Register a callback to indicate when the kernel is complete.
1270+
if (Device.getRPCServer()) {
1271+
if (auto Err = Slots[Curr].schedCallback(
1272+
[](void *Data) -> llvm::Error {
1273+
GenericPluginTy &Plugin =
1274+
*reinterpret_cast<GenericPluginTy *>(Data);
1275+
Plugin.getRPCServer().Thread->finish();
1276+
return Error::success();
1277+
},
1278+
&Device.Plugin))
1279+
return Err;
1280+
}
1281+
return Plugin::success();
12731282
}
12741283

12751284
/// Push an asynchronous memory copy between pinned memory buffers.
@@ -1479,8 +1488,8 @@ struct AMDGPUStreamTy {
14791488
return Plugin::success();
14801489

14811490
// Wait until all previous operations on the stream have completed.
1482-
if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
1483-
RPCServer, &Device))
1491+
if (auto Err =
1492+
Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
14841493
return Err;
14851494

14861495
// Reset the stream and perform all pending post actions.
@@ -3027,7 +3036,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
30273036
: Agent(Device.getAgent()), Queue(nullptr),
30283037
SignalManager(Device.getSignalManager()), Device(Device),
30293038
// Initialize the std::deque with some empty positions.
3030-
Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
3039+
Slots(32), NextSlot(0), SyncCycle(0),
30313040
StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
30323041
UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {}
30333042

@@ -3383,10 +3392,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
33833392
if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream))
33843393
return Err;
33853394

3386-
// If this kernel requires an RPC server we attach its pointer to the stream.
3387-
if (GenericDevice.getRPCServer())
3388-
Stream->setRPCServer(GenericDevice.getRPCServer());
3389-
33903395
// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
33913396
if (ImplArgs &&
33923397
getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
#include "llvm/ADT/DenseMap.h"
2020
#include "llvm/Support/Error.h"
2121

22+
#include <atomic>
23+
#include <condition_variable>
2224
#include <cstdint>
25+
#include <mutex>
26+
#include <thread>
2327

2428
namespace llvm::omp::target {
2529
namespace plugin {
@@ -37,6 +41,12 @@ struct RPCServerTy {
3741
/// Initializes the handles to the number of devices we may need to service.
3842
RPCServerTy(plugin::GenericPluginTy &Plugin);
3943

44+
/// Deinitialize the associated memory and resources.
45+
llvm::Error shutDown();
46+
47+
/// Initialize the worker thread.
48+
llvm::Error startThread();
49+
4050
/// Check if this device image is using an RPC server. This checks for the
4151
/// precense of an externally visible symbol in the device image that will
4252
/// be present whenever RPC code is called.
@@ -51,17 +61,77 @@ struct RPCServerTy {
5161
plugin::GenericGlobalHandlerTy &Handler,
5262
plugin::DeviceImageTy &Image);
5363

54-
/// Runs the RPC server associated with the \p Device until the pending work
55-
/// is cleared.
56-
llvm::Error runServer(plugin::GenericDeviceTy &Device);
57-
5864
/// Deinitialize the RPC server for the given device. This will free the
5965
/// memory associated with the k
6066
llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
6167

6268
private:
6369
/// Array from this device's identifier to its attached devices.
64-
llvm::SmallVector<void *> Buffers;
70+
std::unique_ptr<void *[]> Buffers;
71+
72+
/// Array of associated devices. These must be alive as long as the server is.
73+
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
74+
75+
/// A helper class for running the user thread that handles the RPC interface.
76+
/// Because we only need to check the RPC server while any kernels are
77+
/// working, we track submission / completion events to allow the thread to
78+
/// sleep when it is not needed.
79+
struct ServerThread {
80+
std::thread Worker;
81+
82+
/// A boolean indicating whether or not the worker thread should continue.
83+
std::atomic<bool> Running;
84+
85+
/// The number of currently executing kernels across all devices that need
86+
/// the server thread to be running.
87+
std::atomic<uint32_t> NumUsers;
88+
89+
/// The condition variable used to suspend the thread if no work is needed.
90+
std::condition_variable CV;
91+
std::mutex Mutex;
92+
93+
/// A reference to all the RPC interfaces that the server is handling.
94+
llvm::ArrayRef<void *> Buffers;
95+
96+
/// A reference to the associated generic device for the buffer.
97+
llvm::ArrayRef<plugin::GenericDeviceTy *> Devices;
98+
99+
/// Initialize the worker thread to run in the background.
100+
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
101+
size_t Length)
102+
: Running(true), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
103+
Devices(Devices, Length) {}
104+
105+
~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
106+
107+
/// Notify the worker thread that there is a user that needs it.
108+
void notify() {
109+
std::lock_guard<decltype(Mutex)> Lock(Mutex);
110+
NumUsers.fetch_add(1, std::memory_order_relaxed);
111+
CV.notify_all();
112+
}
113+
114+
/// Indicate that one of the dependent users has finished.
115+
void finish() {
116+
[[maybe_unused]] uint32_t Old =
117+
NumUsers.fetch_sub(1, std::memory_order_relaxed);
118+
assert(Old > 0 && "Attempt to signal finish with no pending work");
119+
}
120+
121+
/// Destroy the worker thread and wait.
122+
void shutDown();
123+
124+
/// Initialize the worker thread.
125+
void startThread();
126+
127+
/// Run the server thread to continuously check the RPC interface for work
128+
/// to be done for the device.
129+
void run();
130+
};
131+
132+
public:
133+
/// Pointer to the server thread instance.
134+
std::unique_ptr<ServerThread> Thread;
65135
};
66136

67137
} // namespace llvm::omp::target

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,9 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
10571057
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
10581058
return Err;
10591059

1060+
if (auto Err = Server.startThread())
1061+
return Err;
1062+
10601063
RPCServer = &Server;
10611064
DP("Running an RPC server on device %d\n", getDeviceId());
10621065
return Plugin::success();
@@ -1630,8 +1633,11 @@ Error GenericPluginTy::deinit() {
16301633
if (GlobalHandler)
16311634
delete GlobalHandler;
16321635

1633-
if (RPCServer)
1636+
if (RPCServer) {
1637+
if (Error Err = RPCServer->shutDown())
1638+
return Err;
16341639
delete RPCServer;
1640+
}
16351641

16361642
if (RecordReplay)
16371643
delete RecordReplay;

0 commit comments

Comments
 (0)