@@ -111,39 +111,26 @@ CudaCachedEntry::CudaCachedEntry(const CudaCacheEntryProto& buf)
111
111
Grid (buf.grid_dims ()),
112
112
Block (buf.block_dims ())} {}
113
113
114
- void CudaCache::cacheKernel (
115
- const std::string& id,
116
- const CudaMappingOptions& options,
117
- const std::vector<const DLTensor*>& inputs,
118
- const std::vector<const DLTensor*>& outputs,
119
- const std::string& kernelSpecializedName,
120
- const std::vector<int >& kernelParameters,
121
- const std::string& cudaSource,
122
- const Grid& grid,
123
- const Block& block) {
114
+ void CudaCache::cacheKernel (CudaCachedEntry&& entry) {
124
115
std::lock_guard<std::mutex> lock (mtx_);
125
116
++numberCacheAttemps;
126
- auto entry = searchKernel (id, options, inputs, outputs);
127
- if (entry) {
128
- if (entry->values .cudaSource == cudaSource or entry->values .grid == grid or
129
- entry->values .block == block) {
117
+ auto retrievedEntry = searchKernel (
118
+ entry.key .id ,
119
+ entry.key .mappingOptions ,
120
+ entry.key .inputs ,
121
+ entry.key .outputs );
122
+ if (retrievedEntry) {
123
+ if (retrievedEntry->values .cudaSource == entry.values .cudaSource or
124
+ retrievedEntry->values .grid == entry.values .grid or
125
+ retrievedEntry->values .block == entry.values .block ) {
130
126
throw CacheEntrySameKeyDifferentValue (
131
- " CudaCache::CacheKernel: a kernel matching the id, options and inputs was previously cached with different cuda source or block or grid dimensions." );
127
+ " CudaCache::CacheKernel: a kernel matching the id, options and "
128
+ " inputs was previously cached with different cuda source or block "
129
+ " or grid dimensions." );
132
130
}
133
131
return ;
134
132
}
135
-
136
- entries_.emplace_back (
137
- id,
138
- kernelSpecializedName,
139
- kernelParameters,
140
- grid,
141
- block,
142
- options,
143
- inputs,
144
- outputs,
145
- cudaSource,
146
- CudaGPUInfo::GPUInfo ().GetCudaDeviceStr ());
133
+ entries_.emplace_back (entry);
147
134
}
148
135
149
136
CudaCachedEntry* CudaCache::searchKernel (
@@ -552,6 +539,13 @@ std::unique_ptr<CudaCacheRetrievalResult> ManualCudaCache::retrieveKernel(
552
539
entry->values .block });
553
540
}
554
541
542
+ ManualCudaCachedEntry* ManualCudaCache::searchKernel (
543
+ const std::string& id,
544
+ const std::vector<detail::TensorInfo>& inputs,
545
+ const std::vector<detail::TensorInfo>& outputs) {
546
+ return searchKernelImpl (*this , id, inputs, outputs);
547
+ }
548
+
555
549
ManualCudaCachedEntry* ManualCudaCache::searchKernel (
556
550
const std::string& id,
557
551
const std::vector<const DLTensor*>& inputs,
@@ -566,38 +560,23 @@ const ManualCudaCachedEntry* ManualCudaCache::searchKernel(
566
560
return searchKernelImpl (*this , id, inputs, outputs);
567
561
}
568
562
569
- void ManualCudaCache::cacheKernel (
570
- const std::string& id,
571
- const std::vector<const DLTensor*>& inputs,
572
- const std::vector<const DLTensor*>& outputs,
573
- const std::string& kernelSpecializedName,
574
- const std::vector<int >& kernelParameters,
575
- const std::string& cudaSource,
576
- const Grid& grid,
577
- const Block& block) {
563
+ void ManualCudaCache::cacheKernel (ManualCudaCachedEntry&& entry) {
578
564
std::lock_guard<std::mutex> lock (mtx_);
579
565
++numberCacheAttemps;
580
- auto entry = searchKernel (id, inputs, outputs);
581
- if (entry) {
582
- entry->values .grid = grid;
583
- entry->values .block = block;
584
- entry->values .cudaSource = cudaSource;
585
- entry->values .kernelSpecializedName = kernelSpecializedName;
586
- entry->values .kernelParameters = kernelParameters;
566
+ auto retrievedEntry =
567
+ searchKernel (entry.key .id , entry.key .inputs , entry.key .outputs );
568
+ if (retrievedEntry) {
569
+ retrievedEntry->values .grid = entry.values .grid ;
570
+ retrievedEntry->values .block = entry.values .block ;
571
+ retrievedEntry->values .cudaSource = entry.values .cudaSource ;
572
+ retrievedEntry->values .kernelSpecializedName =
573
+ entry.values .kernelSpecializedName ;
574
+ retrievedEntry->values .kernelParameters = entry.values .kernelParameters ;
587
575
return ;
588
576
}
589
-
590
- entries_.emplace_back (
591
- id,
592
- kernelSpecializedName,
593
- kernelParameters,
594
- grid,
595
- block,
596
- inputs,
597
- outputs,
598
- cudaSource,
599
- CudaGPUInfo::GPUInfo ().GetCudaDeviceStr ());
577
+ entries_.emplace_back (entry);
600
578
}
579
+
601
580
ManualCudaCachedEntry::ManualCudaCachedEntry (
602
581
const std::string& id,
603
582
const std::string& kernelSpecializedName,
0 commit comments