Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 68b01ed

Browse files
nicolasvasilacheTheodoros Theodoridis
authored andcommitted
Extract out RetrievalResult in standalone class
1 parent 41e5cb2 commit 68b01ed

File tree

4 files changed

+38
-32
lines changed

4 files changed

+38
-32
lines changed

tc/autotuner/utils/utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ std::vector<OptionsWithMedianTime> getOptionsAndMedianRuntimes(
6565
candidates.begin(),
6666
candidates.end(),
6767
std::back_inserter(c),
68-
[](const OptionsCache::RetrievalResult& rr) -> OptionsWithMedianTime {
68+
[](const OptionsCacheRetrievalResult& rr) -> OptionsWithMedianTime {
6969
return {std::move(rr.options), median(rr.recordedRuntimes)};
7070
});
7171
return c;

tc/core/cuda/cuda_compilation_cache.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ const CudaCachedEntry* CudaCache::searchKernel(
170170
return searchKernelImpl(*this, id, options, inputs, outputs);
171171
}
172172

173-
std::unique_ptr<CudaCache::RetrievalResult> CudaCache::retrieveKernel(
173+
std::unique_ptr<CudaCacheRetrievalResult> CudaCache::retrieveKernel(
174174
const std::string& id,
175175
const CudaMappingOptions& options,
176176
const std::vector<const DLTensor*>& inputs,
@@ -182,12 +182,12 @@ std::unique_ptr<CudaCache::RetrievalResult> CudaCache::retrieveKernel(
182182
return nullptr;
183183
}
184184
++numberSuccessfulRetrievals;
185-
return std::unique_ptr<CudaCache::RetrievalResult>(
186-
new CudaCache::RetrievalResult{entry->values.cudaSource,
187-
entry->values.kernelSpecializedName,
188-
entry->values.kernelParameters,
189-
entry->values.grid,
190-
entry->values.block});
185+
return std::unique_ptr<CudaCacheRetrievalResult>(
186+
new CudaCacheRetrievalResult{entry->values.cudaSource,
187+
entry->values.kernelSpecializedName,
188+
entry->values.kernelParameters,
189+
entry->values.grid,
190+
entry->values.block});
191191
}
192192

193193
void CudaCache::removeEntriesNotInOptionsCache(const OptionsCache& oc) {
@@ -228,7 +228,7 @@ std::unique_ptr<CudaMappingOptions> OptionsCache::retrieveBestOptions(
228228
new CudaMappingOptions(ret.front()));
229229
}
230230

231-
std::vector<OptionsCache::RetrievalResult>
231+
std::vector<OptionsCacheRetrievalResult>
232232
OptionsCache::retrieveOptionsAndRuntimes(
233233
const std::string& id,
234234
const std::vector<const DLTensor*>& inputs,
@@ -533,7 +533,7 @@ void removeFromCudaCacheEntriesNotInOptionsCache(
533533
cc.removeEntriesNotInOptionsCache(oc);
534534
}
535535

536-
std::unique_ptr<CudaCache::RetrievalResult> ManualCudaCache::retrieveKernel(
536+
std::unique_ptr<CudaCacheRetrievalResult> ManualCudaCache::retrieveKernel(
537537
const std::string& id,
538538
const std::vector<const DLTensor*>& inputs,
539539
const std::vector<const DLTensor*>& outputs) const {
@@ -544,12 +544,12 @@ std::unique_ptr<CudaCache::RetrievalResult> ManualCudaCache::retrieveKernel(
544544
return nullptr;
545545
}
546546
++numberSuccessfulRetrievals;
547-
return std::unique_ptr<CudaCache::RetrievalResult>(
548-
new CudaCache::RetrievalResult{entry->values.cudaSource,
549-
entry->values.kernelSpecializedName,
550-
entry->values.kernelParameters,
551-
entry->values.grid,
552-
entry->values.block});
547+
return std::unique_ptr<CudaCacheRetrievalResult>(
548+
new CudaCacheRetrievalResult{entry->values.cudaSource,
549+
entry->values.kernelSpecializedName,
550+
entry->values.kernelParameters,
551+
entry->values.grid,
552+
entry->values.block});
553553
}
554554

555555
ManualCudaCachedEntry* ManualCudaCache::searchKernel(

tc/core/cuda/cuda_compilation_cache.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ struct CudaCachedEntry {
7272
Values values;
7373
};
7474

75+
struct CudaCacheRetrievalResult {
76+
std::string source;
77+
std::string specializedName;
78+
std::vector<int> parameters;
79+
Grid grid;
80+
Block block;
81+
};
82+
7583
/**
7684
* CudaCache stores the Cuda source of optimized kernels
7785
* A CudaCache holds multiple CudaCachedEntry's.
@@ -92,16 +100,9 @@ class CudaCache : public Cache<CudaCache, CudaCachedEntry> {
92100
public:
93101
using ProtobufType = CudaCacheProto;
94102
using CachedEntry = CudaCachedEntry;
103+
using RetrievalResult = CudaCacheRetrievalResult;
95104
static std::shared_ptr<CudaCache>& getGlobalSharedCache();
96105

97-
struct RetrievalResult {
98-
std::string source;
99-
std::string specializedName;
100-
std::vector<int> parameters;
101-
Grid grid;
102-
Block block;
103-
};
104-
105106
private:
106107
/**
107108
* SearchKernel (through SearchKernelImpl) searches op in the cache
@@ -161,7 +162,7 @@ class CudaCache : public Cache<CudaCache, CudaCachedEntry> {
161162
* Returns the cache entry that matches op (id, isl options, target device)
162163
* and inputs' shapes.
163164
*/
164-
std::unique_ptr<RetrievalResult> retrieveKernel(
165+
std::unique_ptr<CudaCacheRetrievalResult> retrieveKernel(
165166
const std::string& id,
166167
const CudaMappingOptions& options,
167168
const std::vector<const DLTensor*>& inputs,
@@ -219,10 +220,16 @@ struct OptionsCachedEntry {
219220
std::vector<Values> values;
220221
};
221222

223+
struct OptionsCacheRetrievalResult {
224+
CudaMappingOptions options;
225+
std::vector<Duration> recordedRuntimes;
226+
};
227+
222228
class OptionsCache : public Cache<OptionsCache, OptionsCachedEntry> {
223229
public:
224230
using ProtobufType = OptionsCacheProto;
225231
using CachedEntry = OptionsCachedEntry;
232+
using RetrievalResult = OptionsCacheRetrievalResult;
226233
static std::shared_ptr<OptionsCache>& getGlobalSharedCache();
227234

228235
private:
@@ -256,10 +263,6 @@ class OptionsCache : public Cache<OptionsCache, OptionsCachedEntry> {
256263
OptionsCache(const OptionsCacheProto& buf);
257264

258265
OptionsCacheProto toProtobuf() const;
259-
struct RetrievalResult {
260-
CudaMappingOptions options;
261-
std::vector<Duration> recordedRuntimes;
262-
};
263266

264267
// returns the sum of cache entry sizes (that is a single cache entry can have
265268
// multiple options and profiling information associated with it)
@@ -272,7 +275,7 @@ class OptionsCache : public Cache<OptionsCache, OptionsCachedEntry> {
272275
const std::vector<const DLTensor*>& outputs,
273276
Duration runtime);
274277

275-
std::vector<RetrievalResult> retrieveOptionsAndRuntimes(
278+
std::vector<OptionsCacheRetrievalResult> retrieveOptionsAndRuntimes(
276279
const std::string& id,
277280
const std::vector<const DLTensor*>& inputs,
278281
const std::vector<const DLTensor*>& outputs) const;
@@ -339,13 +342,16 @@ struct ManualCudaCachedEntry {
339342
Values values;
340343
};
341344

345+
typedef CudaCacheRetrievalResult ManualCudaCacheRetrievalResult;
346+
342347
/*
343348
* ManualCudaCache stores the manually injected source of Cuda kernels
344349
*/
345350
class ManualCudaCache : public Cache<ManualCudaCache, ManualCudaCachedEntry> {
346351
public:
347352
using ProtobufType = ManualCudaCacheProto;
348353
using CachedEntry = ManualCudaCachedEntry;
354+
using RetrievalResult = ManualCudaCacheRetrievalResult;
349355
static std::shared_ptr<ManualCudaCache>& getGlobalSharedCache();
350356

351357
private:
@@ -401,7 +407,7 @@ class ManualCudaCache : public Cache<ManualCudaCache, ManualCudaCachedEntry> {
401407
*Returns the cache entry that matches
402408
*op(id, target device) and inputs' shapes.
403409
*/
404-
std::unique_ptr<CudaCache::RetrievalResult> retrieveKernel(
410+
std::unique_ptr<ManualCudaCacheRetrievalResult> retrieveKernel(
405411
const std::string& id,
406412
const std::vector<const DLTensor*>& inputs,
407413
const std::vector<const DLTensor*>& outputs) const;

tc/core/cuda/cuda_tc_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
5353
}
5454
executionInfo_.options = options.toProtobufSerializedString();
5555

56-
auto cachedOp = [&]() -> std::unique_ptr<CudaCache::RetrievalResult> {
56+
auto cachedOp = [&]() -> std::unique_ptr<CudaCacheRetrievalResult> {
5757
if (ManualCudaCache::cacheEnabled()) {
5858
auto rr = ManualCudaCache::getCache()->retrieveKernel(
5959
cacheKeyId_,

0 commit comments

Comments
 (0)