Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 5e21c1a

Browse files
authored
Merge pull request #259 from facebookresearch/cache_keys
Use canonicalized TC as part of cache keys
2 parents d88c6c3 + 38ed719 commit 5e21c1a

File tree

8 files changed

+69
-16
lines changed

8 files changed

+69
-16
lines changed

include/tc/autotuner/utils/utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ std::vector<CudaMappingOptions> restoreCandidates(
4949
const std::vector<const DLTensor*>& inputs,
5050
const std::vector<const DLTensor*>& outputs);
5151

52+
// Overload of restoreCandidates that takes the TC as a parsed lang::TreeRef
// rather than as a string. The implementation derives its lookup key from
// the semantically checked, canonicalized form of `tc`, together with the
// given input and output tensor descriptors.
std::vector<CudaMappingOptions> restoreCandidates(
53+
const lang::TreeRef& tc,
54+
const std::vector<const DLTensor*>& inputs,
55+
const std::vector<const DLTensor*>& outputs);
56+
5257
llvm::Optional<CudaMappingOptions> getBestOptions(
5358
const std::string& id,
5459
const std::vector<const DLTensor*>& inputs,

include/tc/core/tc_executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class TcExecutor {
124124

125125
tc2halide::HalideComponents halideComponents_;
126126
lang::TreeRef tcTree_;
127+
std::string cacheKeyId;
127128
};
128129

129130
// templating to match both const and non-const DLTensor pointers

src/autotuner/genetic_autotuner.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ std::vector<CudaMappingOptions> GeneticAutotuner::load(
8080
ExecutionEngine<CudaTcExecutor> ee;
8181
ee.define(tc_);
8282
auto outputs = ee.inferOutputTensorInfo(tcName, inputs);
83-
return tc::autotune::restoreCandidates(tcName, inputs, outputs);
83+
return tc::autotune::restoreCandidates(
84+
tcNameMap_.at(tcName), inputs, outputs);
8485
}
8586

8687
namespace {

src/autotuner/utils/utils.cc

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include "tc/autotuner/utils/utils.h"
2121
#include "tc/core/cuda/cuda_compilation_cache.h"
2222
#include "tc/core/utils/math.h"
23+
#include "tc/lang/canonicalize.h"
24+
#include "tc/lang/parser.h"
25+
#include "tc/lang/sema.h"
26+
#include "tc/lang/tree.h"
2327

2428
namespace tc {
2529
namespace autotune {
@@ -70,11 +74,20 @@ std::vector<OptionsWithMedianTime> getOptionsAndMedianRuntimes(
7074
return c;
7175
}
7276

77+
namespace {
78+
std::string canonicalTC(const lang::TreeRef& tc) {
79+
std::stringstream ss;
80+
ss << lang::canonicalize(tc);
81+
return ss.str();
82+
}
83+
} // namespace
84+
7385
std::vector<CudaMappingOptions> restoreCandidates(
74-
const std::string& id,
86+
const lang::TreeRef& tc,
7587
const std::vector<const DLTensor*>& inputs,
7688
const std::vector<const DLTensor*>& outputs) {
77-
auto candidates = getOptionsAndMedianRuntimes(id, inputs, outputs);
89+
auto candidates = getOptionsAndMedianRuntimes(
90+
canonicalTC(lang::Sema().checkFunction(tc)), inputs, outputs);
7891
LOG_IF(INFO, candidates.size() < FLAGS_tuner_gen_restore_number)
7992
<< "Requested " << FLAGS_tuner_gen_restore_number
8093
<< " candidates but there are only " << candidates.size() << " in cache.";
@@ -96,6 +109,13 @@ std::vector<CudaMappingOptions> restoreCandidates(
96109
return res;
97110
}
98111

112+
std::vector<CudaMappingOptions> restoreCandidates(
113+
const std::string& tc,
114+
const std::vector<const DLTensor*>& inputs,
115+
const std::vector<const DLTensor*>& outputs) {
116+
return restoreCandidates(lang::Parser(tc).parseFunction(), inputs, outputs);
117+
}
118+
99119
llvm::Optional<CudaMappingOptions> getBestOptions(
100120
const std::string& id,
101121
const std::vector<const DLTensor*>& inputs,

src/core/cuda/cuda_tc_executor.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
5656
auto cachedOp = [&]() -> std::unique_ptr<CudaCache::RetrievalResult> {
5757
if (ManualCudaCache::cacheEnabled()) {
5858
auto rr = ManualCudaCache::getCache()->retrieveKernel(
59-
// TODO:replace this with pretty printed TC
60-
executionInfo_.kernelName,
59+
cacheKeyId,
6160
extractRawPtrs(executionInfo_.inputsInfo),
6261
extractRawPtrs(executionInfo_.outputsInfo));
6362
if (rr) {
@@ -72,7 +71,7 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
7271
<< "options string is empty, are you trying compile "
7372
<< "a dummy CudaTcExecutor?";
7473
return CudaCache::getCache()->retrieveKernel(
75-
executionInfo_.kernelName, // TODO:replace this with pretty printed TC
74+
cacheKeyId,
7675
options,
7776
extractRawPtrs(executionInfo_.inputsInfo),
7877
extractRawPtrs(executionInfo_.outputsInfo));
@@ -94,7 +93,7 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
9493
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "original grid: " << grid;
9594
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "original block: " << block;
9695
CudaCache::getCache()->cacheKernel(
97-
executionInfo_.kernelName, // TODO:replace this with pretty printed TC
96+
cacheKeyId,
9897
options,
9998
extractRawPtrs(executionInfo_.inputsInfo),
10099
extractRawPtrs(executionInfo_.outputsInfo),
@@ -213,8 +212,7 @@ Duration CudaTcExecutor::run(
213212
profile);
214213
if (profile and OptionsCache::cacheEnabled()) {
215214
OptionsCache::getCache()->recordRuntime(
216-
// TODO:replace this with pretty printed TC
217-
executionInfo_.kernelName,
215+
cacheKeyId,
218216
CudaMappingOptions(executionInfo_.options),
219217
inputs,
220218
constPtrs(outputs),

src/core/tc_executor.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
*/
1616
#include "tc/core/tc_executor.h"
1717

18+
#include <sstream>
1819
#include <string>
1920

2021
#include "tc/core/utils/dlpack.h"
22+
#include "tc/lang/canonicalize.h"
2123
#include "tc/lang/parser.h"
2224
#include "tc/lang/sema.h"
2325

@@ -30,6 +32,13 @@ int toTypeToken(DLDataType dtype) {
3032
return lang::TypeInfo(lang::TypeInfo::Code(dtype.code), dtype.bits)
3133
.toScalarToken();
3234
}
35+
36+
std::string canonicalizedTc(const lang::TreeRef tcDefinition) {
37+
std::stringstream ss;
38+
ss << canonicalize(lang::Sema().checkFunction(tcDefinition));
39+
return ss.str();
40+
}
41+
3342
} // namespace
3443

3544
TcExecutor::TcExecutor(
@@ -49,6 +58,7 @@ TcExecutor::TcExecutor(
4958
// TODO: check if this is wrong, packed tensors may have 0 strides stored
5059
executionInfo_.outputsInfo =
5160
tc::inferOutputTensorInfo(halideComponents_, inputsInfo);
61+
cacheKeyId = canonicalizedTc(tcDefinition);
5262
}
5363

5464
TcExecutor::~TcExecutor() {}

test/test_autotuner_utility.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ TEST(DivisorsAndPowers, Default) {
4343
}
4444

4545
std::vector<CudaMappingOptions> restoreCandidates(
46-
const std::string& kernelName,
46+
const std::string& tc,
4747
std::vector<at::Tensor>& inputs,
4848
std::vector<at::Tensor>& outputs) {
4949
auto inputsPair = toConstDlpackTensors(inputs);
@@ -54,13 +54,23 @@ std::vector<CudaMappingOptions> restoreCandidates(
5454
});
5555

5656
return tc::autotune::restoreCandidates(
57-
kernelName, inputsPair.first, outputsPair.first);
57+
tc, inputsPair.first, outputsPair.first);
5858
}
5959

6060
TEST(RestoreCandidates, NoCache) {
6161
std::vector<at::Tensor> inputs{at::CUDA(at::kFloat).rand({10, 16}),
6262
at::CUDA(at::kFloat).rand({16, 20})};
63-
ASSERT_THROW(restoreCandidates("bla", inputs, inputs), std::runtime_error);
63+
static constexpr auto tc = R"(
64+
def tc2(float(M,N) A, float(N,K) B) -> (output) {
65+
output(m, k) +=! A(m, nn) * B(nn, k) + 1
66+
})";
67+
ASSERT_THROW(restoreCandidates(tc, inputs, inputs), std::runtime_error);
68+
}
69+
70+
// The first argument of restoreCandidates is now TC source text, not a
// kernel name. A string that is not a TC definition ("bla") is expected to be
// rejected with lang::ErrorReport.
TEST(RestoreCandidates, NotATCid) {
  at::Tensor lhs = at::CUDA(at::kFloat).rand({10, 16});
  at::Tensor rhs = at::CUDA(at::kFloat).rand({16, 20});
  std::vector<at::Tensor> inputs{lhs, rhs};
  // The same tensors are deliberately passed as both inputs and outputs.
  ASSERT_THROW(restoreCandidates("bla", inputs, inputs), lang::ErrorReport);
}
6575

6676
static constexpr auto tc_ = R"(
@@ -89,7 +99,7 @@ TEST(RestoreCandidates, NoRuntimeRecorded) {
8999
atCompl.run("matmul", inputs, outputs_, handle);
90100

91101
FLAGS_tuner_gen_restore_number = 1;
92-
ASSERT_EQ(restoreCandidates("matmul", inputs, outputs_).size(), 0);
102+
ASSERT_EQ(restoreCandidates(tc_, inputs, outputs_).size(), 0);
93103
}
94104

95105
TEST(RestoreCandidates, Hit) {
@@ -110,11 +120,11 @@ TEST(RestoreCandidates, Hit) {
110120
atCompl.run("matmul", inputs, outputs_, handle, true);
111121

112122
FLAGS_tuner_gen_restore_number = 2;
113-
auto restored = restoreCandidates("matmul", inputs, outputs_);
123+
auto restored = restoreCandidates(tc_, inputs, outputs_);
114124
ASSERT_EQ(restored.size(), 2);
115125

116126
FLAGS_tuner_gen_restore_number = 1;
117-
restored = restoreCandidates("matmul", inputs, outputs_);
127+
restored = restoreCandidates(tc_, inputs, outputs_);
118128
ASSERT_EQ(restored.size(), 1);
119129
}
120130

test/test_tc_mapper.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
#include "tc/core/cuda/cuda_compilation_cache.h"
2525
#include "tc/core/cuda/cuda_tc_executor.h"
2626
#include "tc/core/scope_guard.h"
27+
#include "tc/lang/canonicalize.h"
28+
#include "tc/lang/sema.h"
29+
#include "tc/lang/tree.h"
2730

2831
#include "test_harness_aten_cuda.h"
2932

@@ -59,7 +62,12 @@ struct TcMapperTest : public ::testing::Test {
5962
tc::deleteDlmTensors(outputDLTensorsPair.second);
6063
});
6164
auto cached = tc::CudaCache::getCache()->retrieveKernel(
62-
name,
65+
[&]() {
66+
std::stringstream ss;
67+
ss << lang::canonicalize(
68+
lang::Sema().checkFunction(lang::Parser(tc).parseFunction()));
69+
return ss.str();
70+
}(),
6371
mappingOptions,
6472
inputDLTensorsPair.first,
6573
outputDLTensorsPair.first);

0 commit comments

Comments
 (0)