Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 5e076a7

Browse files
nicolasvasilache and Theodoros Theodoridis
authored and committed
Reorder code in cuda_compilation_cache.h and drop forward declaration
1 parent 3f5e0f3 commit 5e076a7

File tree

1 file changed

+92
-96
lines changed

1 file changed

+92
-96
lines changed

tc/core/cuda/cuda_compilation_cache.h

Lines changed: 92 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -34,102 +34,6 @@
3434

3535
namespace tc {
3636

37-
class OptionsCache;
38-
39-
////////////////////////////////////////////////////////////////////////////////
40-
// CudaCache
41-
////////////////////////////////////////////////////////////////////////////////
42-
struct CudaCachedEntry {
43-
CudaCachedEntry(
44-
const std::string& id,
45-
const std::string& kernelSpecializedName,
46-
const std::vector<int>& kernelParameters,
47-
const Grid& grid,
48-
const Block& block,
49-
const CudaMappingOptions& mappingOptions,
50-
const std::vector<const DLTensor*>& inputs,
51-
const std::vector<const DLTensor*>& outputs,
52-
const std::string& cudaSource,
53-
const std::string& deviceStr);
54-
55-
CudaCachedEntry(const CudaCacheEntryProto& buf);
56-
CudaCacheEntryProto toProtobuf() const;
57-
58-
struct Key {
59-
std::string id;
60-
CudaMappingOptions mappingOptions;
61-
std::vector<detail::TensorInfo> inputs;
62-
std::vector<detail::TensorInfo> outputs;
63-
std::string deviceStr;
64-
std::string gitVersion;
65-
};
66-
67-
struct Values {
68-
std::string cudaSource;
69-
std::string kernelSpecializedName;
70-
std::vector<int> kernelParameters;
71-
Grid grid;
72-
Block block;
73-
};
74-
Key key;
75-
Values values;
76-
};
77-
78-
struct CudaCacheRetrievalResult {
79-
std::string source;
80-
std::string specializedName;
81-
std::vector<int> parameters;
82-
Grid grid;
83-
Block block;
84-
};
85-
86-
/**
87-
* CudaCache stores the Cuda source of optimized kernels
88-
* A CudaCache holds multiple CudaCachedEntry's.
89-
* Each CudaCachedEntry is split to two conceptual parts the key and the values.
90-
* The values are:
91-
* the specialized (wrt inputs) Cuda source code,
92-
* the kernel's specialized name,
93-
* the kernel parameters,
94-
* the Cuda block and grid dimensions
95-
* The key is:
96-
* the kernel/op's unique id (string),
97-
* the specialized input dimensions,
98-
* the isl options when the kernel was optimized,
99-
* the target architecture (string),
100-
* tc's version (string),
101-
*/
102-
class CudaCache : public Cache<CudaCache, CudaCachedEntry> {
103-
public:
104-
using ProtobufType = CudaCacheProto;
105-
using CachedEntry = CudaCachedEntry;
106-
using RetrievalResult = CudaCacheRetrievalResult;
107-
static std::shared_ptr<CudaCache>& getGlobalSharedCache();
108-
109-
CudaCache() = default;
110-
CudaCache(const CudaCacheProto& buf);
111-
CudaCacheProto toProtobuf() const;
112-
113-
/**
114-
* If op was previously cached and the inputs' shape, isl options, and the
115-
* target device are the same then this is a noop
116-
* Else (cudaSource, grid, block) is stored in the cache
117-
*/
118-
void cacheKernel(CudaCachedEntry&& entry);
119-
120-
/**
121-
* Returns the cache entry that matches op (id, isl options, target device)
122-
* and inputs' shapes.
123-
*/
124-
std::unique_ptr<CudaCacheRetrievalResult> retrieveKernel(
125-
const std::string& id,
126-
const CudaMappingOptions& options,
127-
const std::vector<const DLTensor*>& inputs,
128-
const std::vector<const DLTensor*>& outputs) const;
129-
130-
void removeEntriesNotInOptionsCache(const OptionsCache& oc);
131-
};
132-
13337
////////////////////////////////////////////////////////////////////////////////
13438
// OptionsCache
13539
////////////////////////////////////////////////////////////////////////////////
@@ -231,6 +135,98 @@ class OptionsCache : public Cache<OptionsCache, OptionsCachedEntry> {
231135
void keepOnlyBestCandidates(size_t numberToKeep);
232136
};
233137

138+
////////////////////////////////////////////////////////////////////////////////
139+
// CudaCache
140+
////////////////////////////////////////////////////////////////////////////////
141+
struct CudaCachedEntry {
142+
CudaCachedEntry(
143+
const std::string& id,
144+
const std::string& kernelSpecializedName,
145+
const std::vector<int>& kernelParameters,
146+
const Grid& grid,
147+
const Block& block,
148+
const CudaMappingOptions& mappingOptions,
149+
const std::vector<const DLTensor*>& inputs,
150+
const std::vector<const DLTensor*>& outputs,
151+
const std::string& cudaSource,
152+
const std::string& deviceStr);
153+
154+
CudaCachedEntry(const CudaCacheEntryProto& buf);
155+
CudaCacheEntryProto toProtobuf() const;
156+
157+
struct Key {
158+
std::string id;
159+
CudaMappingOptions mappingOptions;
160+
std::vector<detail::TensorInfo> inputs;
161+
std::vector<detail::TensorInfo> outputs;
162+
std::string deviceStr;
163+
std::string gitVersion;
164+
};
165+
166+
struct Values {
167+
std::string cudaSource;
168+
std::string kernelSpecializedName;
169+
std::vector<int> kernelParameters;
170+
Grid grid;
171+
Block block;
172+
};
173+
Key key;
174+
Values values;
175+
};
176+
177+
struct CudaCacheRetrievalResult {
178+
std::string source;
179+
std::string specializedName;
180+
std::vector<int> parameters;
181+
Grid grid;
182+
Block block;
183+
};
184+
185+
/**
186+
* CudaCache stores the Cuda source of optimized kernels
187+
* A CudaCache holds multiple CudaCachedEntry's.
188+
* Each CudaCachedEntry is split to two conceptual parts the key and the values.
189+
* The values are:
190+
* the specialized (wrt inputs) Cuda source code,
191+
* the kernel's specialized name,
192+
* the kernel parameters,
193+
* the Cuda block and grid dimensions
194+
* The key is:
195+
* the kernel/op's unique id (string),
196+
* the specialized input dimensions,
197+
* the isl options when the kernel was optimized,
198+
* the target architecture (string),
199+
* tc's version (string),
200+
*/
201+
class CudaCache : public Cache<CudaCache, CudaCachedEntry> {
202+
public:
203+
typedef CudaCacheProto ProtobufType;
204+
static std::shared_ptr<CudaCache>& getGlobalSharedCache();
205+
206+
CudaCache() = default;
207+
CudaCache(const CudaCacheProto& buf);
208+
CudaCacheProto toProtobuf() const;
209+
210+
/**
211+
* If op was previously cached and the inputs' shape, isl options, and the
212+
* target device are the same then this is a noop
213+
* Else (cudaSource, grid, block) is stored in the cache
214+
*/
215+
void cacheKernel(CudaCachedEntry&& entry);
216+
217+
/**
218+
* Returns the cache entry that matches op (id, isl options, target device)
219+
* and inputs' shapes.
220+
*/
221+
std::unique_ptr<CudaCacheRetrievalResult> retrieveKernel(
222+
const std::string& id,
223+
const CudaMappingOptions& options,
224+
const std::vector<const DLTensor*>& inputs,
225+
const std::vector<const DLTensor*>& outputs) const;
226+
227+
void removeEntriesNotInOptionsCache(const OptionsCache& oc);
228+
};
229+
234230
////////////////////////////////////////////////////////////////////////////////
235231
// ManualCudaCache
236232
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)