This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 5f39adf

nicolasvasilache authored and Theodoros Theodoridis committed
Make searchKernel a non-member locally scoped function
This simplifies the cache API and reduces the number of necessary implementations.
1 parent 9499c50 commit 5f39adf
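
A note for reviewers on the pattern this commit adopts: the new free searchKernel comes as a const/non-const overload pair, but only the const overload contains the lookup logic; the non-const one delegates through const_cast, which is safe here because the vector it casts constness away from was non-const at the call site. A minimal, self-contained sketch of the idiom (the Entry type and search name below are hypothetical stand-ins, not code from this commit):

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for the cached-entry classes in this commit.
struct Entry {
  std::string id;
};

// The const overload owns the actual search logic, written exactly once.
const Entry* search(const std::vector<Entry>& entries, const std::string& id) {
  auto it = std::find_if(entries.begin(), entries.end(), [&](const Entry& e) {
    return e.id == id;
  });
  return it == entries.end() ? nullptr : &*it;
}

// The non-const overload forwards to the const one. The const_cast is safe
// because the caller's vector was non-const to begin with.
Entry* search(std::vector<Entry>& entries, const std::string& id) {
  return const_cast<Entry*>(
      search(static_cast<const std::vector<Entry>&>(entries), id));
}

int main() {
  std::vector<Entry> cache{{"matmul"}, {"conv"}};
  if (Entry* hit = search(cache, "conv")) {
    hit->id = "conv2d"; // mutable access via the non-const overload
  }
  return 0;
}
```

In the commit itself this two-overload pattern appears twice, once per key shape (with and without CudaMappingOptions).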

3 files changed: +92 -248 lines changed

tc/core/cuda/cuda_compilation_cache-inl.h

Lines changed: 0 additions & 92 deletions
@@ -96,96 +96,4 @@ void Cache<CC, CachedEntryType>::clear() {
       0;
   static_cast<CC*>(this)->entries_.clear();
 }
-
-template <typename C, typename InputTy> // deduces whether C is const or
-                                        // non-const
-auto CudaCache::searchKernelImpl(
-    C& c,
-    const std::string& id,
-    const CudaMappingOptions& options,
-    const std::vector<InputTy>& inputs,
-    const std::vector<InputTy>& outputs)
-    -> decltype(c.searchKernel(id, options, inputs, outputs)) {
-  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
-  auto it = std::find_if(
-      c.entries_.begin(), c.entries_.end(), [&](const CachedEntry& c) {
-        using tc::operator==;
-        return id == c.key.id && options == c.key.mappingOptions &&
-            inputs == c.key.inputs && outputs == c.key.outputs &&
-            gpuStr == c.key.deviceStr;
-      });
-  if (it != c.entries_.end()) {
-    if (it->key.gitVersion != tc::git_version) {
-      std::cerr << "[WARNING] Proto version doesn't match. TC git version is: "
-                << tc::git_version
-                << " and Proto version is: " << it->key.gitVersion
-                << " .This proto might be incompatible"
-                << " with your TC binary and can break. Please autotune"
-                << " against the correct TC version." << std::endl;
-    }
-    return &*it;
-  }
-  return nullptr;
-}
-
-// deduces whether C is const or non-const
-template <typename C>
-auto OptionsCache::searchKernelImpl(
-    C& c,
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs)
-    -> decltype(c.searchKernel(id, inputs, outputs)) {
-  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
-  auto it = std::find_if(
-      c.entries_.begin(), c.entries_.end(), [&](const CachedEntry& c) {
-        using tc::operator==;
-        return id == c.key.id && inputs == c.key.inputs &&
-            outputs == c.key.outputs && gpuStr == c.key.deviceStr;
-      });
-  if (it != c.entries_.end()) {
-    if (it->key.gitVersion != tc::git_version) {
-      std::cerr << "[WARNING] Proto version doesn't match. TC git version is: "
-                << tc::git_version
-                << " and Proto version is: " << it->key.gitVersion
-                << " .This proto might be incompatible"
-                << " with your TC binary and can break. Please autotune"
-                << " against the correct TC version." << std::endl;
-      ;
-    }
-    return &*it;
-  }
-  return nullptr;
-}
-
-// deduces whether C is const or non-const
-template <typename C, typename TensorTy>
-auto ManualCudaCache::searchKernelImpl(
-    C& c,
-    const std::string& id,
-    const std::vector<TensorTy>& inputs,
-    const std::vector<TensorTy>& outputs)
-    -> decltype(c.searchKernel(id, inputs, outputs)) {
-  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
-  auto it = std::find_if(
-      c.entries_.begin(), c.entries_.end(), [&](const CachedEntry& c) {
-        using tc::operator==;
-        return id == c.key.id && inputs == c.key.inputs &&
-            outputs == c.key.outputs && gpuStr == c.key.deviceStr;
-      });
-  if (it != c.entries_.end()) {
-    std::cout << "RETURNING IT: " << it->key.gitVersion << std::endl;
-    if (it->key.gitVersion != tc::git_version) {
-      std::cerr << "[WARNING] Proto version doesn't match. TC git version is: "
-                << tc::git_version
-                << " and Proto version is: " << it->key.gitVersion
-                << " .This proto might be incompatible"
-                << " with your TC binary and can break. Please autotune"
-                << " against the correct TC version." << std::endl;
-      ;
-    }
-    return &*it;
-  }
-  return nullptr;
-}
 } // namespace tc

tc/core/cuda/cuda_compilation_cache.cc

Lines changed: 92 additions & 65 deletions
@@ -55,6 +55,90 @@ void WriteProtobufArray(const Array& arr, Buf* buf) {
       arr.begin(), arr.end());
   buf->Swap(&data);
 }
+
+template <typename CachedEntryType, typename TensorType>
+const CachedEntryType* searchKernel(
+    const std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
+  auto it = std::find_if(
+      entries.begin(), entries.end(), [&](const CachedEntryType& c) {
+        using tc::operator==;
+        return id == c.key.id && inputs == c.key.inputs &&
+            outputs == c.key.outputs && gpuStr == c.key.deviceStr;
+      });
+  if (it != entries.end()) {
+    if (it->key.gitVersion != tc::git_version) {
+      std::cerr << "[WARNING] Proto version doesn't match. TC git version is: "
+                << tc::git_version
+                << " and Proto version is: " << it->key.gitVersion
+                << " .This proto might be incompatible"
+                << " with your TC binary and can break. Please autotune"
+                << " against the correct TC version." << std::endl;
+    }
+    return &*it;
+  }
+  return nullptr;
+}
+
+template <typename CachedEntryType, typename TensorType>
+CachedEntryType* searchKernel(
+    std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  return const_cast<CachedEntryType*>(searchKernel(
+      static_cast<const std::vector<CachedEntryType>&>(entries),
+      id,
+      inputs,
+      outputs));
+}
+
+template <typename CachedEntryType, typename TensorType>
+const CachedEntryType* searchKernel(
+    const std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const CudaMappingOptions& options,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
+  auto it = std::find_if(
+      entries.begin(), entries.end(), [&](const CachedEntryType& c) {
+        using tc::operator==;
+        return id == c.key.id && options == c.key.mappingOptions &&
+            inputs == c.key.inputs && outputs == c.key.outputs &&
+            gpuStr == c.key.deviceStr;
+      });
+  if (it != entries.end()) {
+    if (it->key.gitVersion != tc::git_version) {
+      std::cerr << "[WARNING] Proto version doesn't match. TC git version is: "
+                << tc::git_version
+                << " and Proto version is: " << it->key.gitVersion
+                << " .This proto might be incompatible"
+                << " with your TC binary and can break. Please autotune"
+                << " against the correct TC version." << std::endl;
+    }
+    return &*it;
+  }
+  return nullptr;
+}
+
+template <typename CachedEntryType, typename TensorType>
+CachedEntryType* searchKernel(
+    std::vector<CachedEntryType>& entries,
+    const std::string& id,
+    const CudaMappingOptions& options,
+    const std::vector<TensorType>& inputs,
+    const std::vector<TensorType>& outputs) {
+  return const_cast<CachedEntryType*>(searchKernel(
+      static_cast<const std::vector<CachedEntryType>&>(entries),
+      id,
+      options,
+      inputs,
+      outputs));
+}
 } // namespace

 std::shared_ptr<CudaCache>& CudaCache::getGlobalSharedCache() {
@@ -115,6 +199,7 @@ void CudaCache::cacheKernel(CudaCachedEntry&& entry) {
   std::lock_guard<std::mutex> lock(mtx_);
   ++numberCacheAttemps;
   auto retrievedEntry = searchKernel(
+      entries_,
       entry.key.id,
       entry.key.mappingOptions,
       entry.key.inputs,
@@ -133,38 +218,14 @@ void CudaCache::cacheKernel(CudaCachedEntry&& entry) {
   entries_.emplace_back(entry);
 }

-CudaCachedEntry* CudaCache::searchKernel(
-    const std::string& id,
-    const CudaMappingOptions& options,
-    const std::vector<detail::TensorInfo>& inputs,
-    const std::vector<detail::TensorInfo>& outputs) {
-  return searchKernelImpl(*this, id, options, inputs, outputs);
-}
-
-CudaCachedEntry* CudaCache::searchKernel(
-    const std::string& id,
-    const CudaMappingOptions& options,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) {
-  return searchKernelImpl(*this, id, options, inputs, outputs);
-}
-
-const CudaCachedEntry* CudaCache::searchKernel(
-    const std::string& id,
-    const CudaMappingOptions& options,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) const {
-  return searchKernelImpl(*this, id, options, inputs, outputs);
-}
-
 std::unique_ptr<CudaCacheRetrievalResult> CudaCache::retrieveKernel(
     const std::string& id,
     const CudaMappingOptions& options,
     const std::vector<const DLTensor*>& inputs,
     const std::vector<const DLTensor*>& outputs) const {
   std::lock_guard<std::mutex> lock(mtx_);
   ++numberAttemptedRetrievals;
-  auto entry = searchKernel(id, options, inputs, outputs);
+  auto entry = searchKernel(entries_, id, options, inputs, outputs);
   if (not entry) {
     return nullptr;
   }
@@ -182,6 +243,7 @@ void CudaCache::removeEntriesNotInOptionsCache(const OptionsCache& oc) {
   for (const auto& entry : oc) {
     for (const auto& options : entry.values) {
       auto cudaEntry = searchKernel(
+          entries_,
           entry.key.id,
           options.mappingOptions,
           entry.key.inputs,
@@ -222,7 +284,7 @@ OptionsCache::retrieveOptionsAndRuntimes(
     const std::vector<const DLTensor*>& outputs) const {
   std::lock_guard<std::mutex> lock(mtx_);
   ++numberAttemptedRetrievals;
-  auto ret = searchKernel(id, inputs, outputs);
+  auto ret = searchKernel(entries_, id, inputs, outputs);
   if (not ret) {
     return {};
   }
@@ -244,7 +306,7 @@ std::vector<CudaMappingOptions> OptionsCache::retrieveTopKOptions(
     const std::vector<const DLTensor*>& inputs,
     const std::vector<const DLTensor*>& outputs,
     size_t k) const {
-  auto candidates = searchKernel(id, inputs, outputs);
+  auto candidates = searchKernel(entries_, id, inputs, outputs);
   std::lock_guard<std::mutex> lock(mtx_);
   ++numberAttemptedRetrievals;
   if (not candidates) {
@@ -319,7 +381,7 @@ void OptionsCache::recordRuntime(
   ++numberCacheAttemps;
   auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();

-  auto kernel = searchKernel(id, inputs, outputs);
+  auto kernel = searchKernel(entries_, id, inputs, outputs);
   if (not kernel) {
     entries_.emplace_back(id, inputs, outputs, gpuStr, options, runtime);
     return;
@@ -338,20 +400,6 @@ void OptionsCache::recordRuntime(
   v->recordedRuntimes.push_back(runtime);
 }

-OptionsCachedEntry* OptionsCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
-const OptionsCachedEntry* OptionsCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) const {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
 OptionsCachedEntry::OptionsCachedEntry(
     const std::string& id,
     const std::vector<const DLTensor*>& inputs,
@@ -526,7 +574,7 @@ std::unique_ptr<CudaCacheRetrievalResult> ManualCudaCache::retrieveKernel(
     const std::vector<const DLTensor*>& outputs) const {
   std::lock_guard<std::mutex> lock(mtx_);
   ++numberAttemptedRetrievals;
-  auto entry = searchKernel(id, inputs, outputs);
+  auto entry = searchKernel(entries_, id, inputs, outputs);
   if (not entry) {
     return nullptr;
   }
@@ -539,32 +587,11 @@ std::unique_ptr<CudaCacheRetrievalResult> ManualCudaCache::retrieveKernel(
         entry->values.block});
 }

-ManualCudaCachedEntry* ManualCudaCache::searchKernel(
-    const std::string& id,
-    const std::vector<detail::TensorInfo>& inputs,
-    const std::vector<detail::TensorInfo>& outputs) {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
-ManualCudaCachedEntry* ManualCudaCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
-const ManualCudaCachedEntry* ManualCudaCache::searchKernel(
-    const std::string& id,
-    const std::vector<const DLTensor*>& inputs,
-    const std::vector<const DLTensor*>& outputs) const {
-  return searchKernelImpl(*this, id, inputs, outputs);
-}
-
 void ManualCudaCache::cacheKernel(ManualCudaCachedEntry&& entry) {
   std::lock_guard<std::mutex> lock(mtx_);
   ++numberCacheAttemps;
   auto retrievedEntry =
-      searchKernel(entry.key.id, entry.key.inputs, entry.key.outputs);
+      searchKernel(entries_, entry.key.id, entry.key.inputs, entry.key.outputs);
   if (retrievedEntry) {
     retrievedEntry->values.grid = entry.values.grid;
     retrievedEntry->values.block = entry.values.block;

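A closing observation on the "reduces the number of necessary implementations" claim: the removed member overloads had to be written per cache class and per tensor type, whereas the free-function templates deduce both the entry type and the tensor type, so the three per-class overload families (the 92 lines of -inl.h templates plus the member definitions above) collapse into four local templates. A toy sketch of that deduction (EntryA, EntryB, and findById are hypothetical, for illustration only):

```cpp
#include <string>
#include <vector>

// Two hypothetical caches with structurally similar but distinct entry types.
struct KeyA { std::string id; };
struct EntryA { KeyA key; };
struct KeyB { std::string id; };
struct EntryB { KeyB key; };

// One template definition serves any entry type whose key exposes `id`,
// mirroring how the commit's free searchKernel serves all three caches.
template <typename EntryType>
const EntryType* findById(
    const std::vector<EntryType>& entries,
    const std::string& id) {
  for (const auto& e : entries) {
    if (e.key.id == id) {
      return &e;
    }
  }
  return nullptr;
}

int main() {
  std::vector<EntryA> a{{{"conv"}}};
  std::vector<EntryB> b{{{"matmul"}}};
  // Both instantiations come from the same source definition.
  return (findById(a, "conv") && findById(b, "matmul")) ? 0 : 1;
}
```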