Skip to content

Commit 99d9b07

Browse files
committed
[Offload][AMDGPU] Added support for runtime tuning
This PR implemented the necessary structs and functions for runtime tuning. The initial tuning logic is fairly straightforward with hard-coded candidates and exhaustive iterations. We will contiune to improve it in following patches through further discussions and experiments.
1 parent c8f7889 commit 99d9b07

File tree

3 files changed

+182
-25
lines changed

3 files changed

+182
-25
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,6 +1709,10 @@ struct AMDGPUStreamTy {
17091709
hsa_agent_t Agent;
17101710
AMDGPUSignalTy *Signal;
17111711
double TicksToTime;
1712+
std::string KernelName;
1713+
uint32_t NumTeams;
1714+
uint32_t NumThreads;
1715+
KernelRunRecord *KernelRunRecords;
17121716
};
17131717

17141718
using AMDGPUStreamCallbackTy = Error(void *Data);
@@ -2087,9 +2091,20 @@ struct AMDGPUStreamTy {
20872091
PostKernelRunProcessingArgsTy *Args =
20882092
reinterpret_cast<PostKernelRunProcessingArgsTy *>(Data);
20892093

2090-
uint64_t KernelDuration = getKernelDuration(Args);
2091-
fprintf(stderr, "Kernel Duration: %lu ns\n", KernelDuration);
2094+
KernelRunRecord *KernelRecord = Args->KernelRunRecords;
2095+
assert(KernelRecord && "KernelRunRecord is null!");
20922096

2097+
uint64_t KernelDuration = getKernelDuration(Args);
2098+
KernelRecord->addEntry(Args->KernelName, Args->NumTeams, Args->NumThreads,
2099+
KernelDuration);
2100+
2101+
if (getInfoLevel() & OMP_INFOTYPE_AMD_KERNEL_TRACE) {
2102+
fprintf(stderr,
2103+
"[Autotuning run] Kernel %s with %u teams and %u threads "
2104+
"completed in %lu ns.\n",
2105+
Args->KernelName.c_str(), Args->NumTeams, Args->NumThreads,
2106+
KernelDuration);
2107+
}
20932108
return Plugin::success();
20942109
}
20952110

@@ -2172,13 +2187,24 @@ struct AMDGPUStreamTy {
21722187
// If runtime autotuning is enabled, setup the callback functions to process
21732188
// the data after kernel completed.
21742189
if (Device.enableRuntimeAutotuning()) {
2175-
PostKernelRunProcessingArgs.Agent = Agent;
2176-
PostKernelRunProcessingArgs.Signal = OutputSignal;
2177-
PostKernelRunProcessingArgs.TicksToTime = 1.0;
2178-
2179-
if (auto Err = Slots[Curr].schedCallback(postKernelRunProcessingAction,
2180-
&PostKernelRunProcessingArgs))
2181-
return Err;
2190+
std::string KernelName(Kernel.getName());
2191+
KernelRunRecord *KernelRecords = Device.getKernelRunRecords();
2192+
2193+
// If this kernel has reached the run limit,
2194+
// skip registering the callback function.
2195+
if (!KernelRecords->reachedRunLimitForKernel(KernelName)) {
2196+
PostKernelRunProcessingArgs.Agent = Agent;
2197+
PostKernelRunProcessingArgs.Signal = OutputSignal;
2198+
PostKernelRunProcessingArgs.TicksToTime = 1.0;
2199+
PostKernelRunProcessingArgs.KernelName = KernelName;
2200+
PostKernelRunProcessingArgs.NumTeams = NumBlocks[0];
2201+
PostKernelRunProcessingArgs.NumThreads = NumThreads[0];
2202+
PostKernelRunProcessingArgs.KernelRunRecords = KernelRecords;
2203+
2204+
if (auto Err = Slots[Curr].schedCallback(postKernelRunProcessingAction,
2205+
&PostKernelRunProcessingArgs))
2206+
return Err;
2207+
}
21822208
}
21832209

21842210
// Push the kernel with the output signal and an input signal (optional)

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <list>
1818
#include <map>
1919
#include <shared_mutex>
20+
#include <unordered_map>
21+
#include <unordered_set>
2022
#include <vector>
2123

2224
#include "ExclusiveAccess.h"
@@ -58,6 +60,7 @@ struct GenericPluginTy;
5860
struct GenericKernelTy;
5961
struct GenericDeviceTy;
6062
struct RecordReplayTy;
63+
struct KernelRunRecord;
6164

6265
/// Class that wraps the __tgt_async_info to simply its usage. In case the
6366
/// object is constructed without a valid __tgt_async_info, the object will use
@@ -1105,6 +1108,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
11051108

11061109
bool getMultiDeviceKernelValue(void *EntryPtr);
11071110

1111+
KernelRunRecord *getKernelRunRecords() const { return KernelRunRecords; }
1112+
11081113
/// Return true if a descriptor of size 'Size' should be allocated using
11091114
/// shared memory. Default implementation returns 'false',
11101115
virtual bool useSharedMemForDescriptor(int64_t Size);
@@ -1256,6 +1261,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12561261
/// This is used to run the RPC server during task synchronization.
12571262
RPCServerTy *RPCServer;
12581263

1264+
/// Structs for functions and data used in runtime autotuning.
1265+
KernelRunRecord *KernelRunRecords;
1266+
12591267
private:
12601268
#ifdef OMPT_SUPPORT
12611269
/// OMPT callback functions
@@ -1282,6 +1290,99 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12821290
bool IsFastReductionEnabled = false;
12831291
};
12841292

1293+
/// Struct represents the metadata for each kernel run on the device.
1294+
struct KernelRunRecord {
1295+
1296+
struct KernelRunEntry {
1297+
std::string KernelName;
1298+
uint32_t NumTeams;
1299+
uint32_t NumThreads;
1300+
uint64_t RunDuration;
1301+
};
1302+
1303+
// Metadata used in tuning process.
1304+
struct TuningMetadata {
1305+
uint32_t IdxThread = 0;
1306+
uint32_t IdxCUMultiplier = 0;
1307+
// Tuning history.
1308+
std::vector<KernelRunEntry> RunEntries;
1309+
// Run counters.
1310+
uint32_t RunCounters;
1311+
// Entry with minimum running time.
1312+
KernelRunEntry MinEntries;
1313+
};
1314+
1315+
// Add a new entry
1316+
void addEntry(std::string KernelName, uint32_t NumTeams, uint32_t NumThreads,
1317+
uint64_t RunDuration) {
1318+
KernelRunEntry NewRunEnry = {KernelName, NumTeams, NumThreads, RunDuration};
1319+
TuningData[KernelName].RunEntries.push_back(NewRunEnry);
1320+
TuningData[KernelName].RunCounters++;
1321+
1322+
// Update min entries.
1323+
auto MinDuration = TuningData[KernelName].MinEntries.RunDuration;
1324+
if (MinDuration > RunDuration || MinDuration == 0) {
1325+
TuningData[KernelName].MinEntries = NewRunEnry;
1326+
}
1327+
}
1328+
1329+
// Get parameters for next kernel launch.
1330+
std::pair<uint32_t, uint32_t>
1331+
getLaunchParamsForKernel(std::string KernelName,
1332+
GenericDeviceTy &GenericDevice) {
1333+
// If the kernel reaches the run limit,
1334+
// return the current optimal launch parameters.
1335+
if (reachedRunLimitForKernel(KernelName)) {
1336+
auto MinEntry = TuningData[KernelName].MinEntries;
1337+
return {MinEntry.NumTeams, MinEntry.NumThreads};
1338+
}
1339+
1340+
// Pick new launch parameters.
1341+
uint32_t IdxCUMulti = TuningData[KernelName].IdxCUMultiplier;
1342+
uint32_t IdxThread = TuningData[KernelName].IdxThread;
1343+
1344+
if (IdxCUMulti >= CUMultiplierCandidate.size()) {
1345+
// No more element to search.
1346+
// Return current optimal launch parameters.
1347+
return {TuningData[KernelName].MinEntries.NumTeams,
1348+
TuningData[KernelName].MinEntries.NumThreads};
1349+
}
1350+
1351+
// New team/thread pair for launch parameters.
1352+
uint32_t NumCU = GenericDevice.getNumComputeUnits();
1353+
std::pair<uint32_t, uint32_t> NewLaunchParams = {
1354+
CUMultiplierCandidate[IdxCUMulti] * NumCU, ThreadCandidate[IdxThread]};
1355+
1356+
// Update indices.
1357+
IdxThread++;
1358+
TuningData[KernelName].IdxThread = IdxThread;
1359+
1360+
if (IdxThread >= ThreadCandidate.size()) {
1361+
TuningData[KernelName].IdxThread = 0;
1362+
TuningData[KernelName].IdxCUMultiplier++;
1363+
}
1364+
1365+
return NewLaunchParams;
1366+
}
1367+
1368+
bool reachedRunLimitForKernel(std::string KernelName) {
1369+
return TuningData[KernelName].RunCounters > RunLimiter;
1370+
}
1371+
1372+
uint32_t getRunCounterForKernel(std::string KernelName) {
1373+
return TuningData[KernelName].RunCounters;
1374+
}
1375+
1376+
private:
1377+
// Candidates for thread and team.
1378+
std::vector<uint32_t> ThreadCandidate = {32, 64, 128, 256, 512, 1024};
1379+
std::vector<uint32_t> CUMultiplierCandidate = {4, 8, 16, 32, 64, 128};
1380+
// The max number of tuning runs for each kernel.
1381+
uint32_t RunLimiter = ThreadCandidate.size() * CUMultiplierCandidate.size();
1382+
// Used for keeping track of the metatdata used in tuning for each kernel.
1383+
std::unordered_map<std::string, TuningMetadata> TuningData;
1384+
};
1385+
12851386
/// Class implementing common functionalities of offload plugins. Each plugin
12861387
/// should define the specific plugin class, derive from this generic one, and
12871388
/// implement the necessary virtual function members.

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -726,21 +726,39 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
726726
uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
727727
KernelArgs.NumTeams[2]};
728728

729-
// TODO fix workaround since IsBareKernel is not properly set for legacy
730-
// flang and specialized kernels since they don't use kernel-env. While
731-
// we can check for specialized kernels, we can't for legacy flang. So,
732-
// on amd-staging, all kernels including bare ones use this codepath.
733-
NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
734-
735-
std::pair<bool, uint32_t> AdjustInfo = adjustNumThreadsForLowTripCount(
736-
GenericDevice, NumThreads[0], KernelArgs.Tripcount,
737-
KernelArgs.ThreadLimit);
738-
if (AdjustInfo.first)
739-
NumThreads[0] = AdjustInfo.second;
740-
741-
NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
742-
NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
743-
// }
729+
std::string KernelName = getName();
730+
KernelRunRecord *KernelRecord = GenericDevice.getKernelRunRecords();
731+
uint32_t KernelRunCounter = 0;
732+
733+
if (KernelRecord) {
734+
KernelRunCounter = KernelRecord->getRunCounterForKernel(KernelName);
735+
}
736+
// If Autotuning is enabled and the kernel is not launched for the first time.
737+
if (GenericDevice.enableRuntimeAutotuning() && KernelRunCounter > 0) {
738+
assert(KernelRecord &&
739+
"Autotuning is enabled, but KernelRunRecord is not initialized!");
740+
741+
auto [Teams, Threads] =
742+
KernelRecord->getLaunchParamsForKernel(KernelName, GenericDevice);
743+
NumBlocks[0] = Teams;
744+
NumThreads[0] = Threads;
745+
} else {
746+
747+
// TODO fix workaround since IsBareKernel is not properly set for legacy
748+
// flang and specialized kernels since they don't use kernel-env. While
749+
// we can check for specialized kernels, we can't for legacy flang. So,
750+
// on amd-staging, all kernels including bare ones use this codepath.
751+
NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
752+
753+
std::pair<bool, uint32_t> AdjustInfo = adjustNumThreadsForLowTripCount(
754+
GenericDevice, NumThreads[0], KernelArgs.Tripcount,
755+
KernelArgs.ThreadLimit);
756+
if (AdjustInfo.first)
757+
NumThreads[0] = AdjustInfo.second;
758+
759+
NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
760+
NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
761+
}
744762

745763
// Record the kernel description after we modified the argument count and num
746764
// blocks/threads.
@@ -930,7 +948,7 @@ GenericDeviceTy::GenericDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId,
930948
OMPX_EnableRuntimeAutotuning("OMPX_ENABLE_RUNTIME_AUTOTUNING", false),
931949
DeviceId(DeviceId), GridValues(OMPGridValues),
932950
PeerAccesses(NumDevices, PeerAccessState::PENDING), PeerAccessesLock(),
933-
PinnedAllocs(*this), RPCServer(nullptr) {
951+
PinnedAllocs(*this), RPCServer(nullptr), KernelRunRecords(nullptr) {
934952
#ifdef OMPT_SUPPORT
935953
OmptInitialized.store(false);
936954
// Bind the callbacks to this device's member functions
@@ -1012,6 +1030,11 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
10121030
MemoryManager = new MemoryManagerTy(*this, ThresholdMM);
10131031
}
10141032

1033+
// Allocate resources for autotuning if enabled.
1034+
if (OMPX_EnableRuntimeAutotuning) {
1035+
KernelRunRecords = new KernelRunRecord();
1036+
}
1037+
10151038
return Plugin::success();
10161039
}
10171040

@@ -1084,6 +1107,13 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
10841107
if (auto Err = RPCServer->deinitDevice(*this))
10851108
return Err;
10861109

1110+
// Delete autotuning related resources if the option is on.
1111+
if (OMPX_EnableRuntimeAutotuning) {
1112+
if (KernelRunRecords)
1113+
delete KernelRunRecords;
1114+
KernelRunRecords = nullptr;
1115+
}
1116+
10871117
#ifdef OMPT_SUPPORT
10881118
if (ompt::Initialized) {
10891119
bool ExpectedStatus = true;

0 commit comments

Comments
 (0)