@@ -72,6 +72,14 @@ struct CudaCachedEntry {
72
72
Values values;
73
73
};
74
74
75
+ struct CudaCacheRetrievalResult { // Record handed back (inside a std::unique_ptr) by CudaCache::retrieveKernel; also exposed as CudaCache::RetrievalResult.
76
+ std::string source; // presumably the Cuda source text of the cached kernel -- TODO confirm against the cache implementation
77
+ std::string specializedName; // presumably the name of the specialized kernel within `source` -- TODO confirm
78
+ std::vector<int > parameters; // integer parameters stored with the kernel; exact meaning is not visible here -- TODO confirm
79
+ Grid grid; // NOTE(review): looks like the launch grid configuration -- confirm against Grid's definition
80
+ Block block; // NOTE(review): looks like the launch block configuration -- confirm against Block's definition
81
+ };
82
+
75
83
/* *
76
84
* CudaCache stores the Cuda source of optimized kernels
77
85
* A CudaCache holds multiple CudaCachedEntry's.
@@ -92,16 +100,9 @@ class CudaCache : public Cache<CudaCache, CudaCachedEntry> {
92
100
public:
93
101
using ProtobufType = CudaCacheProto;
94
102
using CachedEntry = CudaCachedEntry;
103
+ using RetrievalResult = CudaCacheRetrievalResult;
95
104
static std::shared_ptr<CudaCache>& getGlobalSharedCache ();
96
105
97
- struct RetrievalResult {
98
- std::string source;
99
- std::string specializedName;
100
- std::vector<int > parameters;
101
- Grid grid;
102
- Block block;
103
- };
104
-
105
106
private:
106
107
/* *
107
108
* SearchKernel (through SearchKernelImpl) searches op in the cache
@@ -161,7 +162,7 @@ class CudaCache : public Cache<CudaCache, CudaCachedEntry> {
161
162
* Returns the cache entry that matches op (id, isl options, target device)
162
163
* and inputs' shapes.
163
164
*/
164
- std::unique_ptr<RetrievalResult > retrieveKernel (
165
+ std::unique_ptr<CudaCacheRetrievalResult > retrieveKernel (
165
166
const std::string& id,
166
167
const CudaMappingOptions& options,
167
168
const std::vector<const DLTensor*>& inputs,
@@ -219,10 +220,16 @@ struct OptionsCachedEntry {
219
220
std::vector<Values> values;
220
221
};
221
222
223
+ struct OptionsCacheRetrievalResult { // Element type of the vector returned by OptionsCache::retrieveOptionsAndRuntimes; also exposed as OptionsCache::RetrievalResult.
224
+ CudaMappingOptions options; // one set of mapping options held by the cache
225
+ std::vector<Duration> recordedRuntimes; // presumably the runtimes recorded for `options` -- TODO confirm against recordRuntime's implementation
226
+ };
227
+
222
228
class OptionsCache : public Cache <OptionsCache, OptionsCachedEntry> {
223
229
public:
224
230
using ProtobufType = OptionsCacheProto;
225
231
using CachedEntry = OptionsCachedEntry;
232
+ using RetrievalResult = OptionsCacheRetrievalResult;
226
233
static std::shared_ptr<OptionsCache>& getGlobalSharedCache ();
227
234
228
235
private:
@@ -256,10 +263,6 @@ class OptionsCache : public Cache<OptionsCache, OptionsCachedEntry> {
256
263
OptionsCache (const OptionsCacheProto& buf);
257
264
258
265
OptionsCacheProto toProtobuf () const ;
259
- struct RetrievalResult {
260
- CudaMappingOptions options;
261
- std::vector<Duration> recordedRuntimes;
262
- };
263
266
264
267
// returns the sum of cache entry sizes (that is, a single cache entry can have
265
268
// multiple options and profiling information associated with it)
@@ -272,7 +275,7 @@ class OptionsCache : public Cache<OptionsCache, OptionsCachedEntry> {
272
275
const std::vector<const DLTensor*>& outputs,
273
276
Duration runtime);
274
277
275
- std::vector<RetrievalResult > retrieveOptionsAndRuntimes (
278
+ std::vector<OptionsCacheRetrievalResult > retrieveOptionsAndRuntimes (
276
279
const std::string& id,
277
280
const std::vector<const DLTensor*>& inputs,
278
281
const std::vector<const DLTensor*>& outputs) const ;
@@ -339,13 +342,16 @@ struct ManualCudaCachedEntry {
339
342
Values values;
340
343
};
341
344
345
+ using ManualCudaCacheRetrievalResult = CudaCacheRetrievalResult; // ManualCudaCache::retrieveKernel returns the same record type as CudaCache; alias-declaration matches the other `using` aliases in this header.
346
+
342
347
/*
343
348
* ManualCudaCache stores the manually injected source of Cuda kernels
344
349
*/
345
350
class ManualCudaCache : public Cache <ManualCudaCache, ManualCudaCachedEntry> {
346
351
public:
347
352
using ProtobufType = ManualCudaCacheProto;
348
353
using CachedEntry = ManualCudaCachedEntry;
354
+ using RetrievalResult = ManualCudaCacheRetrievalResult;
349
355
static std::shared_ptr<ManualCudaCache>& getGlobalSharedCache ();
350
356
351
357
private:
@@ -401,7 +407,7 @@ class ManualCudaCache : public Cache<ManualCudaCache, ManualCudaCachedEntry> {
401
407
*Returns the cache entry that matches
402
408
*op(id, target device) and inputs' shapes.
403
409
*/
404
- std::unique_ptr<CudaCache::RetrievalResult > retrieveKernel (
410
+ std::unique_ptr<ManualCudaCacheRetrievalResult > retrieveKernel (
405
411
const std::string& id,
406
412
const std::vector<const DLTensor*>& inputs,
407
413
const std::vector<const DLTensor*>& outputs) const ;
0 commit comments