Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 11375df

Browse files
nicolasvasilache and Theodoros Theodoridis
authored and committed
Hoist entries to base class
This changeset moves the vector of entries from the specialized caches to the base cache class. This will help refactoring and simplification, and will allow removing all the searchKernel methods in the subsequent changesets. In particular, this requires breaking out the dependent CachedEntry types and making them standalone types that can be passed as template parameters to the base Cache class. Private visibility and friends are also removed.
1 parent 24f649a commit 11375df

File tree

4 files changed

+189
-203
lines changed

4 files changed

+189
-203
lines changed

tc/core/compilation_cache.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct TensorInfo {
5656
};
5757
} // namespace detail
5858

59-
template <typename CC>
59+
template <typename CC, typename CachedEntryType>
6060
class Cache {
6161
public:
6262
static void enableCache();
@@ -68,6 +68,12 @@ class Cache {
6868
static std::shared_ptr<CC> getCache();
6969
static bool cacheEnabled();
7070

71+
typename std::vector<CachedEntryType>::const_iterator begin() const {
72+
return entries_.begin();
73+
}
74+
typename std::vector<CachedEntryType>::const_iterator end() const {
75+
return entries_.end();
76+
}
7177
size_t size() const;
7278
void clear();
7379

@@ -78,6 +84,8 @@ class Cache {
7884
protected:
7985
// XXX:this should be a std or boost shared_mutex
8086
mutable std::mutex mtx_;
87+
88+
std::vector<CachedEntryType> entries_;
8189
};
8290

8391
class CacheEntrySameKeyDifferentValue : public std::invalid_argument {

tc/core/cuda/cuda_compilation_cache-inl.h

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,28 @@
2525

2626
namespace tc {
2727

28-
template <typename CC>
29-
void Cache<CC>::enableCache() {
28+
template <typename CC, typename CachedEntryType>
29+
void Cache<CC, CachedEntryType>::enableCache() {
3030
CC::getGlobalSharedCache() = std::make_shared<CC>();
3131
}
3232

33-
template <typename CC>
34-
void Cache<CC>::disableCache() {
33+
template <typename CC, typename CachedEntryType>
34+
void Cache<CC, CachedEntryType>::disableCache() {
3535
CC::getGlobalSharedCache() = nullptr;
3636
}
3737

38-
template <typename CC>
39-
std::shared_ptr<CC> Cache<CC>::getCache() {
38+
template <typename CC, typename CachedEntryType>
39+
std::shared_ptr<CC> Cache<CC, CachedEntryType>::getCache() {
4040
if (not cacheEnabled()) {
4141
throw std::runtime_error(
4242
"EnableCache or LoadCacheFromProtobuf must be called before using the cache.");
4343
}
4444
return CC::getGlobalSharedCache();
4545
}
4646

47-
template <typename CC>
48-
void Cache<CC>::dumpCacheToProtobuf(const std::string& filename) {
47+
template <typename CC, typename CachedEntryType>
48+
void Cache<CC, CachedEntryType>::dumpCacheToProtobuf(
49+
const std::string& filename) {
4950
std::fstream serialized(
5051
filename, std::ios::binary | std::ios::trunc | std::ios::out);
5152
if (!serialized) {
@@ -56,8 +57,9 @@ void Cache<CC>::dumpCacheToProtobuf(const std::string& filename) {
5657
}
5758
}
5859

59-
template <typename CC>
60-
void Cache<CC>::loadCacheFromProtobuf(const std::string& filename) {
60+
template <typename CC, typename CachedEntryType>
61+
void Cache<CC, CachedEntryType>::loadCacheFromProtobuf(
62+
const std::string& filename) {
6163
typename CC::Protobuf buf;
6264
struct stat buffer = {0};
6365
if (stat(filename.c_str(), &buffer) == 0) {
@@ -67,28 +69,28 @@ void Cache<CC>::loadCacheFromProtobuf(const std::string& filename) {
6769
loadCacheFromProtobuf(buf);
6870
}
6971

70-
template <typename CC>
72+
template <typename CC, typename CachedEntryType>
7173
template <typename Protobuf>
72-
void Cache<CC>::loadCacheFromProtobuf(const Protobuf& buf) {
74+
void Cache<CC, CachedEntryType>::loadCacheFromProtobuf(const Protobuf& buf) {
7375
static_assert(
7476
std::is_same<Protobuf, typename CC::Protobuf>::value,
7577
"LoadCacheFromProtobuf called with invalide protobuf type.");
7678
CC::getGlobalSharedCache() = std::make_shared<CC>(buf);
7779
}
7880

79-
template <typename CC>
80-
bool Cache<CC>::cacheEnabled() {
81+
template <typename CC, typename CachedEntryType>
82+
bool Cache<CC, CachedEntryType>::cacheEnabled() {
8183
return CC::getGlobalSharedCache() != nullptr;
8284
}
8385

84-
template <typename CC>
85-
size_t Cache<CC>::size() const {
86+
template <typename CC, typename CachedEntryType>
87+
size_t Cache<CC, CachedEntryType>::size() const {
8688
std::lock_guard<std::mutex> lock(mtx_);
8789
return static_cast<const CC*>(this)->entries_.size();
8890
}
8991

90-
template <typename CC>
91-
void Cache<CC>::clear() {
92+
template <typename CC, typename CachedEntryType>
93+
void Cache<CC, CachedEntryType>::clear() {
9294
std::lock_guard<std::mutex> lock(mtx_);
9395
numberAttemptedRetrievals = numberSuccessfulRetrievals = numberCacheAttemps =
9496
0;
@@ -186,5 +188,4 @@ auto ManualCudaCache::searchKernelImpl(
186188
}
187189
return nullptr;
188190
}
189-
190191
} // namespace tc

tc/core/cuda/cuda_compilation_cache.cc

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ CudaCache::CudaCache(const CudaCacheProto& buf) {
7878
entries_.emplace_back(entry_buf);
7979
}
8080

81-
CudaCache::CachedEntry::CachedEntry(
81+
CudaCachedEntry::CudaCachedEntry(
8282
const std::string& id,
8383
const std::string& kernelSpecializedName,
8484
const std::vector<int>& kernelParameters,
@@ -98,7 +98,7 @@ CudaCache::CachedEntry::CachedEntry(
9898
values{cudaSource, kernelSpecializedName, kernelParameters, grid, block} {
9999
}
100100

101-
CudaCache::CachedEntry::CachedEntry(const CudaCacheEntryProto& buf)
101+
CudaCachedEntry::CudaCachedEntry(const CudaCacheEntryProto& buf)
102102
: key{buf.id(),
103103
CudaMappingOptions{buf.kernel_options()},
104104
ProtoToTensorInfoVector(buf.inputs()),
@@ -146,23 +146,23 @@ void CudaCache::cacheKernel(
146146
CudaGPUInfo::GPUInfo().GetCudaDeviceStr());
147147
}
148148

149-
CudaCache::CachedEntry* CudaCache::searchKernel(
149+
CudaCachedEntry* CudaCache::searchKernel(
150150
const std::string& id,
151151
const CudaMappingOptions& options,
152152
const std::vector<detail::TensorInfo>& inputs,
153153
const std::vector<detail::TensorInfo>& outputs) {
154154
return searchKernelImpl(*this, id, options, inputs, outputs);
155155
}
156156

157-
CudaCache::CachedEntry* CudaCache::searchKernel(
157+
CudaCachedEntry* CudaCache::searchKernel(
158158
const std::string& id,
159159
const CudaMappingOptions& options,
160160
const std::vector<const DLTensor*>& inputs,
161161
const std::vector<const DLTensor*>& outputs) {
162162
return searchKernelImpl(*this, id, options, inputs, outputs);
163163
}
164164

165-
const CudaCache::CachedEntry* CudaCache::searchKernel(
165+
const CudaCachedEntry* CudaCache::searchKernel(
166166
const std::string& id,
167167
const CudaMappingOptions& options,
168168
const std::vector<const DLTensor*>& inputs,
@@ -351,21 +351,21 @@ void OptionsCache::recordRuntime(
351351
v->recordedRuntimes.push_back(runtime);
352352
}
353353

354-
OptionsCache::CachedEntry* OptionsCache::searchKernel(
354+
OptionsCachedEntry* OptionsCache::searchKernel(
355355
const std::string& id,
356356
const std::vector<const DLTensor*>& inputs,
357357
const std::vector<const DLTensor*>& outputs) {
358358
return searchKernelImpl(*this, id, inputs, outputs);
359359
}
360360

361-
const OptionsCache::CachedEntry* OptionsCache::searchKernel(
361+
const OptionsCachedEntry* OptionsCache::searchKernel(
362362
const std::string& id,
363363
const std::vector<const DLTensor*>& inputs,
364364
const std::vector<const DLTensor*>& outputs) const {
365365
return searchKernelImpl(*this, id, inputs, outputs);
366366
}
367367

368-
OptionsCache::CachedEntry::CachedEntry(
368+
OptionsCachedEntry::OptionsCachedEntry(
369369
const std::string& id,
370370
const std::vector<const DLTensor*>& inputs,
371371
const std::vector<const DLTensor*>& outputs,
@@ -376,7 +376,7 @@ OptionsCache::CachedEntry::CachedEntry(
376376
values.emplace_back(options, runtime);
377377
}
378378

379-
OptionsCache::CachedEntry::Key::Key(
379+
OptionsCachedEntry::Key::Key(
380380
const std::string& id,
381381
const std::vector<const DLTensor*>& inputs_,
382382
const std::vector<const DLTensor*>& outputs_,
@@ -388,7 +388,7 @@ OptionsCache::CachedEntry::Key::Key(
388388
deviceStr,
389389
gitVersion) {}
390390

391-
OptionsCache::CachedEntry::Key::Key(
391+
OptionsCachedEntry::Key::Key(
392392
const std::string& id,
393393
std::vector<detail::TensorInfo>&& inputs_,
394394
std::vector<detail::TensorInfo>&& outputs_,
@@ -400,12 +400,12 @@ OptionsCache::CachedEntry::Key::Key(
400400
deviceStr(deviceStr),
401401
gitVersion(gitVersion) {}
402402

403-
OptionsCache::CachedEntry::Values::Values(
403+
OptionsCachedEntry::Values::Values(
404404
const CudaMappingOptions& options,
405405
Duration runtime)
406406
: mappingOptions(options), recordedRuntimes{runtime} {}
407407

408-
OptionsCache::CachedEntry::Values::Values(
408+
OptionsCachedEntry::Values::Values(
409409
const CudaMappingOptions& options,
410410
std::vector<Duration>&& runtimes)
411411
: mappingOptions(options), recordedRuntimes(std::move(runtimes)) {}
@@ -416,29 +416,21 @@ OptionsCache::OptionsCache(const OptionsCacheProto& buf) {
416416
entries_.emplace_back(entry_buf);
417417
}
418418

419-
decltype(OptionsCache::entries_)::const_iterator OptionsCache::begin() const {
420-
return entries_.begin();
421-
}
422-
423-
decltype(OptionsCache::entries_)::const_iterator OptionsCache::end() const {
424-
return entries_.end();
425-
}
426-
427-
OptionsCache::CachedEntry::CachedEntry(const OptionsCacheEntryProto& buf)
419+
OptionsCachedEntry::OptionsCachedEntry(const OptionsCacheEntryProto& buf)
428420
: key(buf.id(),
429421
ProtoToTensorInfoVector(buf.inputs()),
430422
ProtoToTensorInfoVector(buf.outputs()),
431423
buf.device_str(),
432424
buf.git_version()) {
433425
if (buf.values_size() == 0) {
434426
throw std::invalid_argument(
435-
"OptionsCache::CachedEntry invalid protobuf: each entry should have at least one value field.");
427+
"OptionsCachedEntry invalid protobuf: each entry should have at least one value field.");
436428
}
437429

438430
for (const auto& value : buf.values()) {
439431
if (value.recorded_runtimes_size() == 0) {
440432
throw std::invalid_argument(
441-
"OptionsCache::CachedEntry invalid protobuf: each entry value should have at least one recorded runtime.");
433+
"OptionsCachedEntry invalid protobuf: each entry value should have at least one recorded runtime.");
442434
}
443435
std::vector<Duration> runtimes;
444436
runtimes.reserve(value.recorded_runtimes_size());
@@ -464,7 +456,7 @@ OptionsCacheProto OptionsCache::toProtobuf() const {
464456
return buf;
465457
}
466458

467-
OptionsCacheEntryProto OptionsCache::CachedEntry::toProtobuf() const {
459+
OptionsCacheEntryProto OptionsCachedEntry::toProtobuf() const {
468460
OptionsCacheEntryProto buf;
469461
buf.set_id(key.id);
470462
std::transform(
@@ -509,7 +501,7 @@ CudaCacheProto CudaCache::toProtobuf() const {
509501
return buf;
510502
}
511503

512-
CudaCacheEntryProto CudaCache::CachedEntry::toProtobuf() const {
504+
CudaCacheEntryProto CudaCachedEntry::toProtobuf() const {
513505
CudaCacheEntryProto buf;
514506
buf.set_id(key.id);
515507
*buf.mutable_kernel_options() = key.mappingOptions.proto();
@@ -560,14 +552,14 @@ std::unique_ptr<CudaCache::RetrievalResult> ManualCudaCache::retrieveKernel(
560552
entry->values.block});
561553
}
562554

563-
ManualCudaCache::CachedEntry* ManualCudaCache::searchKernel(
555+
ManualCudaCachedEntry* ManualCudaCache::searchKernel(
564556
const std::string& id,
565557
const std::vector<const DLTensor*>& inputs,
566558
const std::vector<const DLTensor*>& outputs) {
567559
return searchKernelImpl(*this, id, inputs, outputs);
568560
}
569561

570-
const ManualCudaCache::CachedEntry* ManualCudaCache::searchKernel(
562+
const ManualCudaCachedEntry* ManualCudaCache::searchKernel(
571563
const std::string& id,
572564
const std::vector<const DLTensor*>& inputs,
573565
const std::vector<const DLTensor*>& outputs) const {
@@ -606,7 +598,7 @@ void ManualCudaCache::cacheKernel(
606598
cudaSource,
607599
CudaGPUInfo::GPUInfo().GetCudaDeviceStr());
608600
}
609-
ManualCudaCache::CachedEntry::CachedEntry(
601+
ManualCudaCachedEntry::ManualCudaCachedEntry(
610602
const std::string& id,
611603
const std::string& kernelSpecializedName,
612604
const std::vector<int>& kernelParameters,

0 commit comments

Comments (0)