@@ -55,6 +55,90 @@ void WriteProtobufArray(const Array& arr, Buf* buf) {
      arr.begin(), arr.end());
  buf->Swap(&data);
}
+
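+// Helper shared by the caches below: linearly scans a cache's entry vector
+// for an entry whose key matches the given id, input/output tensors,
+// (optionally) mapping options, and the current CUDA device string; warns
+// when the matching entry was serialized by a different TC git version and
+// returns nullptr if nothing matches.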
+template <typename CachedEntryType, typename TensorType>
+const CachedEntryType* searchKernel(
+    const std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
+  auto it = std::find_if(
+      entries.begin(), entries.end(), [&](const CachedEntryType& c) {
+        using tc::operator==;
+        return id == c.key.id && inputs == c.key.inputs &&
+            outputs == c.key.outputs && gpuStr == c.key.deviceStr;
+      });
+  if (it != entries.end()) {
+    if (it->key.gitVersion != tc::git_version) {
+      std::cerr << "[WARNING] Proto version doesn't match. TC git version is: "
+                << tc::git_version
+                << " and Proto version is: " << it->key.gitVersion
+                << ". This proto might be incompatible"
+                << " with your TC binary and can break. Please autotune"
+                << " against the correct TC version." << std::endl;
+    }
+    return &*it;
+  }
+  return nullptr;
+}
+
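+// Non-const overload: reuses the const overload above and casts constness
+// away from the result, which is safe because 'entries' itself is non-const.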
+template <typename CachedEntryType, typename TensorType>
+CachedEntryType* searchKernel(
+    std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  return const_cast<CachedEntryType*>(searchKernel(
+      static_cast<const std::vector<CachedEntryType>&>(entries),
+      id,
+      inputs,
+      outputs));
+}
+
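+// Overload that additionally requires the entry's stored CudaMappingOptions
+// to match.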
+template <typename CachedEntryType, typename TensorType>
+const CachedEntryType* searchKernel(
+    const std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const CudaMappingOptions& options,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
+  auto it = std::find_if(
+      entries.begin(), entries.end(), [&](const CachedEntryType& c) {
+        using tc::operator==;
+        return id == c.key.id && options == c.key.mappingOptions &&
+            inputs == c.key.inputs && outputs == c.key.outputs &&
+            gpuStr == c.key.deviceStr;
+      });
+  if (it != entries.end()) {
+    if (it->key.gitVersion != tc::git_version) {
+      std::cerr << "[WARNING] Proto version doesn't match. TC git version is: "
+                << tc::git_version
+                << " and Proto version is: " << it->key.gitVersion
+                << ". This proto might be incompatible"
+                << " with your TC binary and can break. Please autotune"
+                << " against the correct TC version." << std::endl;
+    }
+    return &*it;
+  }
+  return nullptr;
+}
+
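+// Non-const overload, forwarding to the const overload as above.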
+template <typename CachedEntryType, typename TensorType>
+CachedEntryType* searchKernel(
+    std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const CudaMappingOptions& options,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  return const_cast<CachedEntryType*>(searchKernel(
+      static_cast<const std::vector<CachedEntryType>&>(entries),
+      id,
+      options,
+      inputs,
+      outputs));
+}
} // namespace

std::shared_ptr<CudaCache>& CudaCache::getGlobalSharedCache() {
@@ -115,6 +199,7 @@ void CudaCache::cacheKernel(CudaCachedEntry&& entry) {
  std::lock_guard<std::mutex> lock(mtx_);
  ++numberCacheAttemps;
  auto retrievedEntry = searchKernel(
+      entries_,
      entry.key.id,
      entry.key.mappingOptions,
      entry.key.inputs,
@@ -133,38 +218,14 @@ void CudaCache::cacheKernel(CudaCachedEntry&& entry) {
  entries_.emplace_back(entry);
}

-CudaCachedEntry* CudaCache::searchKernel(
-    const std::string& id,
-    const CudaMappingOptions& options,
-    const std::vector<detail::TensorInfo>& inputs,
-    const std::vector<detail::TensorInfo>& outputs) {
-  return searchKernelImpl(*this, id, options, inputs, outputs);
-}
-
-CudaCachedEntry* CudaCache::searchKernel(
-    const std::string& id,
-    const CudaMappingOptions& options,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) {
-  return searchKernelImpl(*this, id, options, inputs, outputs);
-}
-
-const CudaCachedEntry* CudaCache::searchKernel(
-    const std::string& id,
-    const CudaMappingOptions& options,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) const {
-  return searchKernelImpl(*this, id, options, inputs, outputs);
-}
-
std::unique_ptr<CudaCacheRetrievalResult> CudaCache::retrieveKernel(
    const std::string& id,
    const CudaMappingOptions& options,
    const std::vector<const DLTensor*>& inputs,
    const std::vector<const DLTensor*>& outputs) const {
  std::lock_guard<std::mutex> lock(mtx_);
  ++numberAttemptedRetrievals;
-  auto entry = searchKernel(id, options, inputs, outputs);
+  auto entry = searchKernel(entries_, id, options, inputs, outputs);
  if (not entry) {
    return nullptr;
  }
@@ -182,6 +243,7 @@ void CudaCache::removeEntriesNotInOptionsCache(const OptionsCache& oc) {
  for (const auto& entry : oc) {
    for (const auto& options : entry.values) {
      auto cudaEntry = searchKernel(
+          entries_,
          entry.key.id,
          options.mappingOptions,
          entry.key.inputs,
@@ -222,7 +284,7 @@ OptionsCache::retrieveOptionsAndRuntimes(
    const std::vector<const DLTensor*>& outputs) const {
  std::lock_guard<std::mutex> lock(mtx_);
  ++numberAttemptedRetrievals;
-  auto ret = searchKernel(id, inputs, outputs);
+  auto ret = searchKernel(entries_, id, inputs, outputs);
  if (not ret) {
    return {};
  }
@@ -244,7 +306,7 @@ std::vector<CudaMappingOptions> OptionsCache::retrieveTopKOptions(
    const std::vector<const DLTensor*>& inputs,
    const std::vector<const DLTensor*>& outputs,
    size_t k) const {
-  auto candidates = searchKernel(id, inputs, outputs);
+  auto candidates = searchKernel(entries_, id, inputs, outputs);
  std::lock_guard<std::mutex> lock(mtx_);
  ++numberAttemptedRetrievals;
  if (not candidates) {
@@ -319,7 +381,7 @@ void OptionsCache::recordRuntime(
  ++numberCacheAttemps;
  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();

-  auto kernel = searchKernel(id, inputs, outputs);
+  auto kernel = searchKernel(entries_, id, inputs, outputs);
  if (not kernel) {
    entries_.emplace_back(id, inputs, outputs, gpuStr, options, runtime);
    return;
@@ -338,20 +400,6 @@ void OptionsCache::recordRuntime(
  v->recordedRuntimes.push_back(runtime);
}

-OptionsCachedEntry* OptionsCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
-const OptionsCachedEntry* OptionsCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) const {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
OptionsCachedEntry::OptionsCachedEntry(
    const std::string& id,
    const std::vector<const DLTensor*>& inputs,
@@ -526,7 +574,7 @@ std::unique_ptr<CudaCacheRetrievalResult> ManualCudaCache::retrieveKernel(
    const std::vector<const DLTensor*>& outputs) const {
  std::lock_guard<std::mutex> lock(mtx_);
  ++numberAttemptedRetrievals;
-  auto entry = searchKernel(id, inputs, outputs);
+  auto entry = searchKernel(entries_, id, inputs, outputs);
  if (not entry) {
    return nullptr;
  }
@@ -539,32 +587,11 @@ std::unique_ptr<CudaCacheRetrievalResult> ManualCudaCache::retrieveKernel(
      entry->values.block});
}

-ManualCudaCachedEntry* ManualCudaCache::searchKernel(
-    const std::string& id,
-    const std::vector<detail::TensorInfo>& inputs,
-    const std::vector<detail::TensorInfo>& outputs) {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
-ManualCudaCachedEntry* ManualCudaCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
-const ManualCudaCachedEntry* ManualCudaCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) const {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
void ManualCudaCache::cacheKernel(ManualCudaCachedEntry&& entry) {
  std::lock_guard<std::mutex> lock(mtx_);
  ++numberCacheAttemps;
  auto retrievedEntry =
-      searchKernel(entry.key.id, entry.key.inputs, entry.key.outputs);
+      searchKernel(entries_, entry.key.id, entry.key.inputs, entry.key.outputs);
  if (retrievedEntry) {
    retrievedEntry->values.grid = entry.values.grid;
    retrievedEntry->values.block = entry.values.block;