Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit fc9bb4f

Browse files
Merge pull request #264 from facebookresearch/fix_tuner_cache_and_tensordot
Fix bug introduced by #259
2 parents 4229ac4 + 85ff857 commit fc9bb4f

File tree

11 files changed

+56
-91
lines changed

11 files changed

+56
-91
lines changed

benchmarks/benchmark_fixture.h

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "tc/core/cuda/cuda_tc_executor.h"
3737
#include "tc/core/flags.h"
3838
#include "tc/core/scope_guard.h"
39+
#include "tc/lang/canonicalize.h"
3940

4041
#include <cublas_v2.h> // Must be the same as Caffe2
4142
#include <cuda_runtime_api.h>
@@ -69,25 +70,6 @@ std::vector<const DLTensor*> inferOutputTensorInfo(
6970
return atCompl.inferOutputTensorInfo(name, inputs);
7071
}
7172

72-
tc::CudaMappingOptions loadOptionsFromProto(
73-
const std::string cacheFilename,
74-
const std::string& name,
75-
const std::vector<at::Tensor>& inputs,
76-
const std::vector<const DLTensor*>& outputs) {
77-
tc::OptionsCache::enableCache();
78-
tc::OptionsCache::loadCacheFromProtobuf(cacheFilename);
79-
tc::CudaCache::enableCache();
80-
tc::CudaCache::loadCacheFromProtobuf(tc::makeCudaFilename(cacheFilename));
81-
tc::FLAGS_tuner_gen_restore_number = 1;
82-
83-
auto mappingOptions = [&]() {
84-
auto inputsPair = tc::toConstDlpackTensors(inputs);
85-
tc::ScopeGuard g([&]() { tc::deleteDlmTensors(inputsPair.second); });
86-
return tc::autotune::restoreCandidates(name, inputsPair.first, outputs);
87-
}();
88-
return mappingOptions[0];
89-
}
90-
9173
struct Benchmark : public ::testing::Test {
9274
void SetUp() {
9375
if (!FLAGS_disable_version_checks) {
@@ -289,7 +271,8 @@ struct Benchmark : public ::testing::Test {
289271
auto inputsPair = tc::toConstDlpackTensors(inputs);
290272
auto outputs = atCompl.inferOutputTensorInfo(name, inputs);
291273
tc::ScopeGuard g([&]() { tc::deleteDlmTensors(inputsPair.second); });
292-
return tc::autotune::restoreCandidates(name, inputsPair.first, outputs);
274+
return tc::autotune::restoreCandidates(
275+
lang::canonicalTc(tc), inputsPair.first, outputs);
293276
}();
294277
auto handle = atCompl.compile(name, inputs, mappingOptions[0]);
295278
std::vector<at::Tensor> outputs;

include/tc/autotuner/utils/utils.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,14 @@
2121
#include "tc/core/cuda/cuda.h"
2222
#include "tc/core/cuda/cuda_mapping_options.h"
2323
#include "tc/core/utils/dlpack.h"
24+
#include "tc/lang/canonicalize.h"
25+
#include "tc/lang/tree.h"
2426

2527
#include <llvm/ADT/Optional.h>
2628

2729
namespace tc {
2830
namespace autotune {
2931

30-
struct OptionsWithMedianTime {
31-
CudaMappingOptions options;
32-
Duration medianRuntime;
33-
};
34-
3532
/// Returns all the powers of 2 up to the first one that is larger than val
3633
/// and the result of ceil(val/pow2) for each of those powers of 2 (except for
3734
/// the larger one)
@@ -40,25 +37,29 @@ std::vector<std::size_t> powers2andCeilDivisors(std::size_t val);
4037
template <typename Vector, typename... Vectors>
4138
Vector mergeVectors(Vector&& v, Vectors&&... vs);
4239

43-
std::vector<OptionsWithMedianTime> getOptionsAndMedianRuntimes(
44-
const std::string& id,
45-
const std::vector<const DLTensor*>& inputs);
46-
47-
std::vector<CudaMappingOptions> restoreCandidates(
48-
const std::string& id,
49-
const std::vector<const DLTensor*>& inputs,
50-
const std::vector<const DLTensor*>& outputs);
51-
40+
/// The following API allows interacting with the autotuner caches.
41+
/// Caches generally take arbitrary strings for keys.
42+
/// The autotuner uses a canonicalized TC expression to load / store into
43+
/// caches. Add a layer of type safety to interact with these.
5244
std::vector<CudaMappingOptions> restoreCandidates(
53-
const lang::TreeRef& tc,
45+
const lang::CanonicalTcString& tc,
5446
const std::vector<const DLTensor*>& inputs,
5547
const std::vector<const DLTensor*>& outputs);
5648

5749
llvm::Optional<CudaMappingOptions> getBestOptions(
58-
const std::string& id,
50+
const lang::CanonicalTcString& id,
5951
const std::vector<const DLTensor*>& inputs,
6052
const std::vector<const DLTensor*>& outputs);
6153

54+
struct OptionsWithMedianTime {
55+
CudaMappingOptions options;
56+
Duration medianRuntime;
57+
};
58+
59+
std::vector<OptionsWithMedianTime> getOptionsAndMedianRuntimes(
60+
const lang::CanonicalTcString& id,
61+
const std::vector<const DLTensor*>& inputs);
62+
6263
} // namespace autotune
6364
} // namespace tc
6465

include/tc/core/tc_executor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#include "tc/core/utils/dlpack.h"
2626
#include "tc/core/utils/time.h"
2727

28-
#include "tc/lang/parser.h"
28+
#include "tc/lang/canonicalize.h"
2929

3030
namespace tc {
3131

@@ -124,7 +124,7 @@ class TcExecutor {
124124

125125
tc2halide::HalideComponents halideComponents_;
126126
lang::TreeRef tcTree_;
127-
std::string cacheKeyId;
127+
lang::CanonicalTcString cacheKeyId_;
128128
};
129129

130130
// templating to match both const and non-const DLTensor pointers

include/tc/lang/canonicalize.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@
1717

1818
#include <string>
1919

20+
#include "tc/lang/parser.h"
21+
#include "tc/lang/sema.h"
2022
#include "tc/lang/tree.h"
2123
#include "tc/lang/tree_views.h"
2224

2325
namespace lang {
2426

2527
// takes a tree after semantic analysis and create
2628
// a canonicalized version that is agnostic to the choice of identifiers
27-
TreeRef canonicalize(TreeRef tree) {
29+
inline TreeRef canonicalize(TreeRef tree) {
2830
struct Context {
2931
std::unordered_map<std::string, std::string> identMap;
3032
std::string rename(const std::string& name) {
@@ -53,4 +55,19 @@ TreeRef canonicalize(TreeRef tree) {
5355
Context ctx;
5456
return ctx.apply(tree);
5557
}
58+
59+
struct CanonicalTcString : public std::string {
60+
explicit CanonicalTcString(const std::string& s) : std::string(s) {}
61+
};
62+
63+
inline CanonicalTcString canonicalTc(const lang::TreeRef& tc) {
64+
std::stringstream ss;
65+
// TODO: use tcFormat when more robust
66+
ss << lang::canonicalize(lang::Sema().checkFunction(tc));
67+
return CanonicalTcString(ss.str());
68+
}
69+
70+
inline CanonicalTcString canonicalTc(const std::string& tc) {
71+
return canonicalTc(lang::Parser(tc).parseFunction());
72+
}
5673
} // namespace lang

src/autotuner/genetic_autotuner.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ std::vector<CudaMappingOptions> GeneticAutotuner::load(
8181
ee.define(tc_);
8282
auto outputs = ee.inferOutputTensorInfo(tcName, inputs);
8383
return tc::autotune::restoreCandidates(
84-
tcNameMap_.at(tcName), inputs, outputs);
84+
canonicalTc(tcNameMap_.at(tcName)), inputs, outputs);
8585
}
8686

8787
namespace {
@@ -186,7 +186,7 @@ llvm::Optional<CudaMappingOptions> GeneticAutotuner::tune(
186186

187187
CHECK_GT(inputs.size(), 0);
188188
return tc::autotune::getBestOptions(
189-
tcName, inputs.begin()->second, outputPtrs);
189+
canonicalTc(tcNameMap_.at(tcName)), inputs.begin()->second, outputPtrs);
190190
}
191191

192192
} // namespace detail

src/autotuner/utils/utils.cc

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
#include "tc/core/cuda/cuda_compilation_cache.h"
2222
#include "tc/core/utils/math.h"
2323
#include "tc/lang/canonicalize.h"
24-
#include "tc/lang/parser.h"
25-
#include "tc/lang/sema.h"
26-
#include "tc/lang/tree.h"
2724

2825
namespace tc {
2926
namespace autotune {
@@ -56,7 +53,7 @@ std::vector<std::size_t> powers2andCeilDivisors(std::size_t val) {
5653
}
5754

5855
std::vector<OptionsWithMedianTime> getOptionsAndMedianRuntimes(
59-
const std::string& id,
56+
const lang::CanonicalTcString& id,
6057
const std::vector<const DLTensor*>& inputs,
6158
const std::vector<const DLTensor*>& outputs) {
6259
auto candidates =
@@ -74,20 +71,11 @@ std::vector<OptionsWithMedianTime> getOptionsAndMedianRuntimes(
7471
return c;
7572
}
7673

77-
namespace {
78-
std::string canonicalTC(const lang::TreeRef& tc) {
79-
std::stringstream ss;
80-
ss << lang::canonicalize(tc);
81-
return ss.str();
82-
}
83-
} // namespace
84-
8574
std::vector<CudaMappingOptions> restoreCandidates(
86-
const lang::TreeRef& tc,
75+
const lang::CanonicalTcString& tc,
8776
const std::vector<const DLTensor*>& inputs,
8877
const std::vector<const DLTensor*>& outputs) {
89-
auto candidates = getOptionsAndMedianRuntimes(
90-
canonicalTC(lang::Sema().checkFunction(tc)), inputs, outputs);
78+
auto candidates = getOptionsAndMedianRuntimes(tc, inputs, outputs);
9179
LOG_IF(INFO, candidates.size() < FLAGS_tuner_gen_restore_number)
9280
<< "Requested " << FLAGS_tuner_gen_restore_number
9381
<< " candidates but there are only " << candidates.size() << " in cache.";
@@ -109,15 +97,8 @@ std::vector<CudaMappingOptions> restoreCandidates(
10997
return res;
11098
}
11199

112-
std::vector<CudaMappingOptions> restoreCandidates(
113-
const std::string& tc,
114-
const std::vector<const DLTensor*>& inputs,
115-
const std::vector<const DLTensor*>& outputs) {
116-
return restoreCandidates(lang::Parser(tc).parseFunction(), inputs, outputs);
117-
}
118-
119100
llvm::Optional<CudaMappingOptions> getBestOptions(
120-
const std::string& id,
101+
const lang::CanonicalTcString& id,
121102
const std::vector<const DLTensor*>& inputs,
122103
const std::vector<const DLTensor*>& outputs) {
123104
auto bestOptions =

src/core/cuda/cuda_tc_executor.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
5656
auto cachedOp = [&]() -> std::unique_ptr<CudaCache::RetrievalResult> {
5757
if (ManualCudaCache::cacheEnabled()) {
5858
auto rr = ManualCudaCache::getCache()->retrieveKernel(
59-
cacheKeyId,
59+
cacheKeyId_,
6060
extractRawPtrs(executionInfo_.inputsInfo),
6161
extractRawPtrs(executionInfo_.outputsInfo));
6262
if (rr) {
@@ -71,7 +71,7 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
7171
<< "options string is empty, are you trying compile "
7272
<< "a dummy CudaTcExecutor?";
7373
return CudaCache::getCache()->retrieveKernel(
74-
cacheKeyId,
74+
cacheKeyId_,
7575
options,
7676
extractRawPtrs(executionInfo_.inputsInfo),
7777
extractRawPtrs(executionInfo_.outputsInfo));
@@ -93,7 +93,7 @@ void CudaTcExecutor::compile(const tc::CudaMappingOptions& options) {
9393
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "original grid: " << grid;
9494
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "original block: " << block;
9595
CudaCache::getCache()->cacheKernel(
96-
cacheKeyId,
96+
cacheKeyId_,
9797
options,
9898
extractRawPtrs(executionInfo_.inputsInfo),
9999
extractRawPtrs(executionInfo_.outputsInfo),
@@ -212,7 +212,7 @@ Duration CudaTcExecutor::run(
212212
profile);
213213
if (profile and OptionsCache::cacheEnabled()) {
214214
OptionsCache::getCache()->recordRuntime(
215-
cacheKeyId,
215+
cacheKeyId_,
216216
CudaMappingOptions(executionInfo_.options),
217217
inputs,
218218
constPtrs(outputs),

src/core/tc_executor.cc

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
#include "tc/core/utils/dlpack.h"
2222
#include "tc/lang/canonicalize.h"
23-
#include "tc/lang/parser.h"
24-
#include "tc/lang/sema.h"
2523

2624
namespace tc {
2725

@@ -32,13 +30,6 @@ int toTypeToken(DLDataType dtype) {
3230
return lang::TypeInfo(lang::TypeInfo::Code(dtype.code), dtype.bits)
3331
.toScalarToken();
3432
}
35-
36-
std::string canonicalizedTc(const lang::TreeRef tcDefinition) {
37-
std::stringstream ss;
38-
ss << canonicalize(lang::Sema().checkFunction(tcDefinition));
39-
return ss.str();
40-
}
41-
4233
} // namespace
4334

4435
TcExecutor::TcExecutor(
@@ -49,7 +40,8 @@ TcExecutor::TcExecutor(
4940
: identifier(id),
5041
inputsInfo(dlutils::makeDLTensorVector(inputsInfo)),
5142
options(options),
52-
tcTree_(tcDefinition) {
43+
tcTree_(tcDefinition),
44+
cacheKeyId_(lang::canonicalTc(tcDefinition)) {
5345
executionInfo_.kernelName = lang::Def(tcTree_).name().name();
5446
halideComponents_ =
5547
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tcTree_);
@@ -58,7 +50,6 @@ TcExecutor::TcExecutor(
5850
// TODO: check if this is wrong, packed tensors may have 0 strides stored
5951
executionInfo_.outputsInfo =
6052
tc::inferOutputTensorInfo(halideComponents_, inputsInfo);
61-
cacheKeyId = canonicalizedTc(tcDefinition);
6253
}
6354

6455
TcExecutor::~TcExecutor() {}

test/cuda/test_autotuner_utility.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "tc/core/cuda/cuda_compilation_cache.h"
2121
#include "tc/core/cuda/cuda_tc_executor.h"
2222
#include "tc/core/scope_guard.h"
23+
#include "tc/lang/canonicalize.h"
2324

2425
using namespace tc;
2526
using namespace autotune;
@@ -54,7 +55,7 @@ std::vector<CudaMappingOptions> restoreCandidates(
5455
});
5556

5657
return tc::autotune::restoreCandidates(
57-
tc, inputsPair.first, outputsPair.first);
58+
lang::canonicalTc(tc), inputsPair.first, outputsPair.first);
5859
}
5960

6061
TEST(RestoreCandidates, NoCache) {

test/test_lang.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,6 @@ TreeRef loadText(const std::string& text) {
147147
return Sema().checkFunction(Parser(text).parseFunction());
148148
}
149149

150-
std::string canonicalText(const std::string& text) {
151-
std::stringstream ss;
152-
ss << canonicalize(loadText(text));
153-
return ss.str();
154-
}
155-
156150
void testTcFormat() {
157151
static std::ios_base::Init initIostreams;
158152
auto source = R"(def fun2(float(B, N, M) X, float(B, M, K) Y) -> (Q) {
@@ -334,7 +328,7 @@ int main(int argc, char** argv) {
334328
Q(b, ii, j) += X(b, ii, k) * Y(b, k, j)
335329
}
336330
)";
337-
ASSERT(canonicalText(option_one) == canonicalText(option_two));
331+
ASSERT(lang::canonicalTc(option_one) == lang::canonicalTc(option_two));
338332

339333
testTcFormat();
340334

0 commit comments

Comments
 (0)