Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 9499c50

Browse files
nicolasvasilache and Theodoros Theodoridis
authored and committed
Refactor cacheKernel call
This changeset makes the cacheKernel methods take a CacheEntry&&. This simplifies the API and gets rid of the discrepancy in the parameter ordering between the constructor of a CacheEntry and the passing of the same parameters to cacheKernel.
1 parent 68b01ed commit 9499c50

File tree

5 files changed

+191
-201
lines changed

5 files changed

+191
-201
lines changed

tc/core/cuda/cuda_compilation_cache.cc

Lines changed: 33 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -111,39 +111,26 @@ CudaCachedEntry::CudaCachedEntry(const CudaCacheEntryProto& buf)
111111
Grid(buf.grid_dims()),
112112
Block(buf.block_dims())} {}
113113

114-
void CudaCache::cacheKernel(
115-
const std::string& id,
116-
const CudaMappingOptions& options,
117-
const std::vector<const DLTensor*>& inputs,
118-
const std::vector<const DLTensor*>& outputs,
119-
const std::string& kernelSpecializedName,
120-
const std::vector<int>& kernelParameters,
121-
const std::string& cudaSource,
122-
const Grid& grid,
123-
const Block& block) {
114+
void CudaCache::cacheKernel(CudaCachedEntry&& entry) {
124115
std::lock_guard<std::mutex> lock(mtx_);
125116
++numberCacheAttemps;
126-
auto entry = searchKernel(id, options, inputs, outputs);
127-
if (entry) {
128-
if (entry->values.cudaSource == cudaSource or entry->values.grid == grid or
129-
entry->values.block == block) {
117+
auto retrievedEntry = searchKernel(
118+
entry.key.id,
119+
entry.key.mappingOptions,
120+
entry.key.inputs,
121+
entry.key.outputs);
122+
if (retrievedEntry) {
123+
if (retrievedEntry->values.cudaSource == entry.values.cudaSource or
124+
retrievedEntry->values.grid == entry.values.grid or
125+
retrievedEntry->values.block == entry.values.block) {
130126
throw CacheEntrySameKeyDifferentValue(
131-
"CudaCache::CacheKernel: a kernel matching the id, options and inputs was previously cached with different cuda source or block or grid dimensions.");
127+
"CudaCache::CacheKernel: a kernel matching the id, options and "
128+
"inputs was previously cached with different cuda source or block "
129+
"or grid dimensions.");
132130
}
133131
return;
134132
}
135-
136-
entries_.emplace_back(
137-
id,
138-
kernelSpecializedName,
139-
kernelParameters,
140-
grid,
141-
block,
142-
options,
143-
inputs,
144-
outputs,
145-
cudaSource,
146-
CudaGPUInfo::GPUInfo().GetCudaDeviceStr());
133+
entries_.emplace_back(entry);
147134
}
148135

149136
CudaCachedEntry* CudaCache::searchKernel(
@@ -552,6 +539,13 @@ std::unique_ptr<CudaCacheRetrievalResult> ManualCudaCache::retrieveKernel(
552539
entry->values.block});
553540
}
554541

542+
ManualCudaCachedEntry* ManualCudaCache::searchKernel(
543+
const std::string& id,
544+
const std::vector<detail::TensorInfo>& inputs,
545+
const std::vector<detail::TensorInfo>& outputs) {
546+
return searchKernelImpl(*this, id, inputs, outputs);
547+
}
548+
555549
ManualCudaCachedEntry* ManualCudaCache::searchKernel(
556550
const std::string& id,
557551
const std::vector<const DLTensor*>& inputs,
@@ -566,38 +560,23 @@ const ManualCudaCachedEntry* ManualCudaCache::searchKernel(
566560
return searchKernelImpl(*this, id, inputs, outputs);
567561
}
568562

569-
void ManualCudaCache::cacheKernel(
570-
const std::string& id,
571-
const std::vector<const DLTensor*>& inputs,
572-
const std::vector<const DLTensor*>& outputs,
573-
const std::string& kernelSpecializedName,
574-
const std::vector<int>& kernelParameters,
575-
const std::string& cudaSource,
576-
const Grid& grid,
577-
const Block& block) {
563+
void ManualCudaCache::cacheKernel(ManualCudaCachedEntry&& entry) {
578564
std::lock_guard<std::mutex> lock(mtx_);
579565
++numberCacheAttemps;
580-
auto entry = searchKernel(id, inputs, outputs);
581-
if (entry) {
582-
entry->values.grid = grid;
583-
entry->values.block = block;
584-
entry->values.cudaSource = cudaSource;
585-
entry->values.kernelSpecializedName = kernelSpecializedName;
586-
entry->values.kernelParameters = kernelParameters;
566+
auto retrievedEntry =
567+
searchKernel(entry.key.id, entry.key.inputs, entry.key.outputs);
568+
if (retrievedEntry) {
569+
retrievedEntry->values.grid = entry.values.grid;
570+
retrievedEntry->values.block = entry.values.block;
571+
retrievedEntry->values.cudaSource = entry.values.cudaSource;
572+
retrievedEntry->values.kernelSpecializedName =
573+
entry.values.kernelSpecializedName;
574+
retrievedEntry->values.kernelParameters = entry.values.kernelParameters;
587575
return;
588576
}
589-
590-
entries_.emplace_back(
591-
id,
592-
kernelSpecializedName,
593-
kernelParameters,
594-
grid,
595-
block,
596-
inputs,
597-
outputs,
598-
cudaSource,
599-
CudaGPUInfo::GPUInfo().GetCudaDeviceStr());
577+
entries_.emplace_back(entry);
600578
}
579+
601580
ManualCudaCachedEntry::ManualCudaCachedEntry(
602581
const std::string& id,
603582
const std::string& kernelSpecializedName,

tc/core/cuda/cuda_compilation_cache.h

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -147,16 +147,7 @@ class CudaCache : public Cache<CudaCache, CudaCachedEntry> {
147147
* target device are the same then this is a noop
148148
* Else (cudaSource, grid, block) is stored in the cache
149149
*/
150-
void cacheKernel(
151-
const std::string& id,
152-
const CudaMappingOptions& options,
153-
const std::vector<const DLTensor*>& inputs,
154-
const std::vector<const DLTensor*>& outputs,
155-
const std::string& kernelSpecializedName,
156-
const std::vector<int>& kernelParameters,
157-
const std::string& cudaSource,
158-
const Grid& grid,
159-
const Block& block);
150+
void cacheKernel(CudaCachedEntry&& entry);
160151

161152
/**
162153
* Returns the cache entry that matches op (id, isl options, target device)
@@ -393,15 +384,7 @@ class ManualCudaCache : public Cache<ManualCudaCache, ManualCudaCachedEntry> {
393384
*target device). If the key already exist in the cache,
394385
*the values are replaced.
395386
*/
396-
void cacheKernel(
397-
const std::string& id,
398-
const std::vector<const DLTensor*>& inputs,
399-
const std::vector<const DLTensor*>& outputs,
400-
const std::string& kernelSpecializedName,
401-
const std::vector<int>& kernelParameters,
402-
const std::string& cudaSource,
403-
const Grid& grid,
404-
const Block& block);
387+
void cacheKernel(ManualCudaCachedEntry&& entry);
405388

406389
/*
407390
*Returns the cache entry that matches

tc/core/cuda/cuda_tc_executor.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,17 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
9292
if (CudaCache::cacheEnabled()) {
9393
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "original grid: " << grid;
9494
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "original block: " << block;
95-
CudaCache::getCache()->cacheKernel(
95+
CudaCache::getCache()->cacheKernel(CudaCachedEntry(
9696
cacheKeyId_,
97+
kernelSpecializedName,
98+
executionInfo_.kernelParams,
99+
grid,
100+
block,
97101
options,
98102
extractRawPtrs(executionInfo_.inputsInfo),
99103
extractRawPtrs(executionInfo_.outputsInfo),
100-
kernelSpecializedName,
101-
executionInfo_.kernelParams,
102104
cudaSource,
103-
grid,
104-
block);
105+
CudaGPUInfo::GPUInfo().GetCudaDeviceStr()));
105106
}
106107
}
107108

tensor_comprehensions/pybinds/pybind_engine.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include "pybind_utils.h"
2727
#include "tc/aten/aten_compiler.h"
28+
#include "tc/core/cuda/cuda.h"
2829
#include "tc/core/cuda/cuda_compilation_cache.h"
2930
#include "tc/core/cuda/cuda_mapping_options.h"
3031
#include "tc/core/cuda/cuda_tc_executor.h"
@@ -126,14 +127,16 @@ PYBIND11_MODULE(tc, m) {
126127
[&]() { tc::deleteDlmTensors(tensorsPair.second); });
127128
auto outTensorInfo = instance.inferOutputTensorInfo(name, atInputs);
128129
tc::ManualCudaCache::getCache()->cacheKernel(
129-
name,
130-
tensorsPair.first,
131-
outTensorInfo,
132-
injectedKernelName,
133-
{},
134-
cudaSource,
135-
tc::Grid(grid),
136-
tc::Block(block));
130+
tc::ManualCudaCachedEntry(
131+
name,
132+
injectedKernelName,
133+
{},
134+
tc::Grid(grid),
135+
tc::Block(block),
136+
tensorsPair.first,
137+
outTensorInfo,
138+
cudaSource,
139+
tc::CudaGPUInfo::GPUInfo().GetCudaDeviceStr()));
137140
});
138141
}
139142

0 commit comments

Comments (0)