Skip to content

Commit 43e637c

Browse files
authored
[Offload][AMDGPU] Added support for runtime tuning (#190)
2 parents 0b30020 + 2107a33 commit 43e637c

File tree

3 files changed

+196
-26
lines changed

3 files changed

+196
-26
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,6 +1709,10 @@ struct AMDGPUStreamTy {
17091709
hsa_agent_t Agent;
17101710
AMDGPUSignalTy *Signal;
17111711
double TicksToTime;
1712+
std::string KernelName;
1713+
uint32_t NumTeams;
1714+
uint32_t NumThreads;
1715+
KernelRunRecordTy *KernelRunRecords;
17121716
};
17131717

17141718
using AMDGPUStreamCallbackTy = Error(void *Data);
@@ -2087,9 +2091,20 @@ struct AMDGPUStreamTy {
20872091
PostKernelRunProcessingArgsTy *Args =
20882092
reinterpret_cast<PostKernelRunProcessingArgsTy *>(Data);
20892093

2090-
uint64_t KernelDuration = getKernelDuration(Args);
2091-
fprintf(stderr, "Kernel Duration: %lu ns\n", KernelDuration);
2094+
KernelRunRecordTy *KernelRecord = Args->KernelRunRecords;
2095+
assert(KernelRecord && "KernelRunRecord is null!");
20922096

2097+
uint64_t KernelDuration = getKernelDuration(Args);
2098+
KernelRecord->addEntry(Args->KernelName, Args->NumTeams, Args->NumThreads,
2099+
KernelDuration);
2100+
2101+
if (getInfoLevel() & OMP_INFOTYPE_AMD_KERNEL_TRACE) {
2102+
fprintf(stderr,
2103+
"[Autotuning run] Kernel %s with %u teams and %u threads "
2104+
"completed in %lu ns.\n",
2105+
Args->KernelName.c_str(), Args->NumTeams, Args->NumThreads,
2106+
KernelDuration);
2107+
}
20932108
return Plugin::success();
20942109
}
20952110

@@ -2171,14 +2186,26 @@ struct AMDGPUStreamTy {
21712186

21722187
// If runtime autotuning is enabled, setup the callback functions to process
21732188
// the data after kernel completed.
2174-
if (Device.enableRuntimeAutotuning()) {
2175-
PostKernelRunProcessingArgs.Agent = Agent;
2176-
PostKernelRunProcessingArgs.Signal = OutputSignal;
2177-
PostKernelRunProcessingArgs.TicksToTime = 1.0;
2178-
2179-
if (auto Err = Slots[Curr].schedCallback(postKernelRunProcessingAction,
2180-
&PostKernelRunProcessingArgs))
2181-
return Err;
2189+
if (Device.enableRuntimeAutotuning() && Kernel.isSPMDMode()) {
2190+
std::string KernelName(Kernel.getName());
2191+
KernelRunRecordTy *KernelRecords = Device.getKernelRunRecords();
2192+
assert(KernelRecords && "No KernelRecords!");
2193+
2194+
// If this kernel has reached the run limit,
2195+
// skip registering the callback function.
2196+
if (!KernelRecords->reachedRunLimitForKernel(KernelName)) {
2197+
PostKernelRunProcessingArgs.Agent = Agent;
2198+
PostKernelRunProcessingArgs.Signal = OutputSignal;
2199+
PostKernelRunProcessingArgs.TicksToTime = 1.0;
2200+
PostKernelRunProcessingArgs.KernelName = KernelName;
2201+
PostKernelRunProcessingArgs.NumTeams = NumBlocks[0];
2202+
PostKernelRunProcessingArgs.NumThreads = NumThreads[0];
2203+
PostKernelRunProcessingArgs.KernelRunRecords = KernelRecords;
2204+
2205+
if (auto Err = Slots[Curr].schedCallback(postKernelRunProcessingAction,
2206+
&PostKernelRunProcessingArgs))
2207+
return Err;
2208+
}
21822209
}
21832210

21842211
// Push the kernel with the output signal and an input signal (optional)

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <list>
1818
#include <map>
1919
#include <shared_mutex>
20+
#include <unordered_map>
21+
#include <unordered_set>
2022
#include <vector>
2123

2224
#include "ExclusiveAccess.h"
@@ -58,6 +60,7 @@ struct GenericPluginTy;
5860
struct GenericKernelTy;
5961
struct GenericDeviceTy;
6062
struct RecordReplayTy;
63+
struct KernelRunRecordTy;
6164

6265
/// Class that wraps the __tgt_async_info to simply its usage. In case the
6366
/// object is constructed without a valid __tgt_async_info, the object will use
@@ -1105,6 +1108,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
11051108

11061109
bool getMultiDeviceKernelValue(void *EntryPtr);
11071110

1111+
KernelRunRecordTy *getKernelRunRecords() const { return KernelRunRecords; }
1112+
11081113
/// Return true if a descriptor of size 'Size' should be allocated using
11091114
/// shared memory. Default implementation returns 'false',
11101115
virtual bool useSharedMemForDescriptor(int64_t Size);
@@ -1256,6 +1261,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12561261
/// This is used to run the RPC server during task synchronization.
12571262
RPCServerTy *RPCServer;
12581263

1264+
/// Structs for functions and data used in runtime autotuning.
1265+
KernelRunRecordTy *KernelRunRecords;
1266+
12591267
private:
12601268
#ifdef OMPT_SUPPORT
12611269
/// OMPT callback functions
@@ -1282,6 +1290,109 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12821290
bool IsFastReductionEnabled = false;
12831291
};
12841292

1293+
/// Struct represents the metadata for each kernel run on the device.
1294+
struct KernelRunRecordTy {
1295+
1296+
struct KernelRunEntryTy {
1297+
std::string KernelName;
1298+
uint32_t NumTeams = 0;
1299+
uint32_t NumThreads = 0;
1300+
uint64_t RunDuration = 0;
1301+
};
1302+
1303+
// Metadata used in tuning process.
1304+
struct TuningMetadataTy {
1305+
uint32_t IdxThread = 0;
1306+
uint32_t IdxCUMultiplier = 0;
1307+
// Run counters.
1308+
uint32_t RunCounters = 0;
1309+
// Entry with minimum running time.
1310+
KernelRunEntryTy MinEntry;
1311+
};
1312+
1313+
// Add a new entry
1314+
void addEntry(std::string KernelName, uint32_t NumTeams, uint32_t NumThreads,
1315+
uint64_t RunDuration) {
1316+
TuningData[KernelName].RunCounters++;
1317+
1318+
// Update min entries.
1319+
uint64_t MinDuration = 0;
1320+
auto It = TuningData.find(KernelName);
1321+
if (It != TuningData.end()) {
1322+
MinDuration = It->second.MinEntry.RunDuration;
1323+
}
1324+
if (MinDuration > RunDuration || MinDuration == 0) {
1325+
TuningData[KernelName].MinEntry = {KernelName, NumTeams, NumThreads,
1326+
RunDuration};
1327+
}
1328+
}
1329+
1330+
// Get parameters for next kernel launch.
1331+
std::pair<uint32_t, uint32_t>
1332+
getLaunchParamsForKernel(std::string KernelName,
1333+
GenericDeviceTy &GenericDevice) {
1334+
// If the kernel reaches the run limit,
1335+
// return the current optimal launch parameters.
1336+
if (reachedRunLimitForKernel(KernelName)) {
1337+
auto MinEntry = TuningData[KernelName].MinEntry;
1338+
return {MinEntry.NumTeams, MinEntry.NumThreads};
1339+
}
1340+
1341+
// Pick new launch parameters.
1342+
uint32_t IdxCUMulti = TuningData[KernelName].IdxCUMultiplier;
1343+
uint32_t IdxThread = TuningData[KernelName].IdxThread;
1344+
1345+
if (IdxCUMulti >= CUMultiplierCandidate.size()) {
1346+
// No more element to search.
1347+
// Return current optimal launch parameters.
1348+
return {TuningData[KernelName].MinEntry.NumTeams,
1349+
TuningData[KernelName].MinEntry.NumThreads};
1350+
}
1351+
1352+
// New team/thread pair for launch parameters.
1353+
uint32_t NumCU = GenericDevice.getNumComputeUnits();
1354+
std::pair<uint32_t, uint32_t> NewLaunchParams = {
1355+
CUMultiplierCandidate[IdxCUMulti] * NumCU, ThreadCandidate[IdxThread]};
1356+
1357+
// Update indices.
1358+
IdxThread++;
1359+
TuningData[KernelName].IdxThread = IdxThread;
1360+
1361+
if (IdxThread >= ThreadCandidate.size()) {
1362+
TuningData[KernelName].IdxThread = 0;
1363+
TuningData[KernelName].IdxCUMultiplier++;
1364+
}
1365+
1366+
return NewLaunchParams;
1367+
}
1368+
1369+
bool reachedRunLimitForKernel(std::string KernelName) {
1370+
if (TuningData.find(KernelName) == TuningData.end()) {
1371+
// If no record for this kernel.
1372+
return false;
1373+
}
1374+
1375+
return TuningData[KernelName].RunCounters > RunLimiter;
1376+
}
1377+
1378+
uint32_t getRunCounterForKernel(std::string KernelName) {
1379+
if (TuningData.find(KernelName) == TuningData.end()) {
1380+
return 0;
1381+
}
1382+
1383+
return TuningData[KernelName].RunCounters;
1384+
}
1385+
1386+
private:
1387+
// Candidates for thread and team.
1388+
std::vector<uint32_t> ThreadCandidate = {32, 64, 128, 256, 512, 1024};
1389+
std::vector<uint32_t> CUMultiplierCandidate = {4, 8, 16, 32, 64, 128};
1390+
// The max number of tuning runs for each kernel.
1391+
uint32_t RunLimiter = ThreadCandidate.size() * CUMultiplierCandidate.size();
1392+
// Used for keeping track of the metatdata used in tuning for each kernel.
1393+
std::unordered_map<std::string, TuningMetadataTy> TuningData;
1394+
};
1395+
12851396
/// Class implementing common functionalities of offload plugins. Each plugin
12861397
/// should define the specific plugin class, derive from this generic one, and
12871398
/// implement the necessary virtual function members.

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -726,21 +726,40 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
726726
uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
727727
KernelArgs.NumTeams[2]};
728728

729-
// TODO fix workaround since IsBareKernel is not properly set for legacy
730-
// flang and specialized kernels since they don't use kernel-env. While
731-
// we can check for specialized kernels, we can't for legacy flang. So,
732-
// on amd-staging, all kernels including bare ones use this codepath.
733-
NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
734-
735-
std::pair<bool, uint32_t> AdjustInfo = adjustNumThreadsForLowTripCount(
736-
GenericDevice, NumThreads[0], KernelArgs.Tripcount,
737-
KernelArgs.ThreadLimit);
738-
if (AdjustInfo.first)
739-
NumThreads[0] = AdjustInfo.second;
740-
741-
NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
742-
NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
743-
// }
729+
std::string KernelName = getName();
730+
KernelRunRecordTy *KernelRecord = GenericDevice.getKernelRunRecords();
731+
uint32_t KernelRunCounter = 0;
732+
733+
if (KernelRecord) {
734+
KernelRunCounter = KernelRecord->getRunCounterForKernel(KernelName);
735+
}
736+
// If Autotuning is enabled and the kernel is not launched for the first time.
737+
if (GenericDevice.enableRuntimeAutotuning() && isSPMDMode() &&
738+
KernelRunCounter > 0) {
739+
assert(KernelRecord &&
740+
"Autotuning is enabled, but KernelRunRecord is not initialized!");
741+
742+
auto [Teams, Threads] =
743+
KernelRecord->getLaunchParamsForKernel(KernelName, GenericDevice);
744+
NumBlocks[0] = Teams;
745+
NumThreads[0] = Threads;
746+
} else {
747+
748+
// TODO fix workaround since IsBareKernel is not properly set for legacy
749+
// flang and specialized kernels since they don't use kernel-env. While
750+
// we can check for specialized kernels, we can't for legacy flang. So,
751+
// on amd-staging, all kernels including bare ones use this codepath.
752+
NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
753+
754+
std::pair<bool, uint32_t> AdjustInfo = adjustNumThreadsForLowTripCount(
755+
GenericDevice, NumThreads[0], KernelArgs.Tripcount,
756+
KernelArgs.ThreadLimit);
757+
if (AdjustInfo.first)
758+
NumThreads[0] = AdjustInfo.second;
759+
760+
NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
761+
NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
762+
}
744763

745764
// Record the kernel description after we modified the argument count and num
746765
// blocks/threads.
@@ -930,7 +949,7 @@ GenericDeviceTy::GenericDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId,
930949
OMPX_EnableRuntimeAutotuning("OMPX_ENABLE_RUNTIME_AUTOTUNING", false),
931950
DeviceId(DeviceId), GridValues(OMPGridValues),
932951
PeerAccesses(NumDevices, PeerAccessState::PENDING), PeerAccessesLock(),
933-
PinnedAllocs(*this), RPCServer(nullptr) {
952+
PinnedAllocs(*this), RPCServer(nullptr), KernelRunRecords(nullptr) {
934953
#ifdef OMPT_SUPPORT
935954
OmptInitialized.store(false);
936955
// Bind the callbacks to this device's member functions
@@ -1012,6 +1031,11 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
10121031
MemoryManager = new MemoryManagerTy(*this, ThresholdMM);
10131032
}
10141033

1034+
// Allocate resources for autotuning if enabled.
1035+
if (OMPX_EnableRuntimeAutotuning) {
1036+
KernelRunRecords = new KernelRunRecordTy();
1037+
}
1038+
10151039
return Plugin::success();
10161040
}
10171041

@@ -1084,6 +1108,14 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
10841108
if (auto Err = RPCServer->deinitDevice(*this))
10851109
return Err;
10861110

1111+
// Delete autotuning related resources if the option is on.
1112+
if (OMPX_EnableRuntimeAutotuning) {
1113+
if (KernelRunRecords) {
1114+
delete KernelRunRecords;
1115+
KernelRunRecords = nullptr;
1116+
}
1117+
}
1118+
10871119
#ifdef OMPT_SUPPORT
10881120
if (ompt::Initialized) {
10891121
bool ExpectedStatus = true;

0 commit comments

Comments
 (0)