Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 58207a9

Browse files
Merge pull request #506 from nicolasvasilache/pr/drop-implicit-tuner-cache-interactions
Drop implicit tuner cache interactions
2 parents 89906d9 + 2448f68 commit 58207a9

27 files changed

+243
-375
lines changed

README.md

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -41,29 +41,27 @@ def tensordot(float(N, C1, C2, H, W) I0,
4141
O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w)
4242
}
4343
)TC";
44-
tc::ATenCompilationUnit<tc::CudaBackend> atCompl;
45-
atCompl.define(tc);
4644

4745
// 2. Allocate tensors with random data.
4846
at::Tensor I0 = at::CUDA(at::kFloat).rand({32, 8, 16, 17, 25});
4947
at::Tensor I1 = at::CUDA(at::kFloat).rand({32, 16, 2, 17, 25});
5048

5149
// 3. Run autotuning with evolutionary search starting from a naive option.
52-
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions();
53-
tc::autotune::GeneticAutotunerATen geneticAutotuneATen(tc);
54-
auto bestOption = geneticAutotuneATen.tune(
55-
"/tmp/save_results", "tensordot", {I0, I1}, options);
56-
57-
// 4. Compile and run the TC with the best option.
58-
// Outputs get allocated; could also be pre-allocated and passed.
59-
auto handle = atCompl.compile("tensordot", {I0, I1}, bestOption.getValue());
60-
std::vector<at::Tensor> outputs;
61-
auto duration = atCompl.run("tensordot", {I0, I1}, outputs, handle, true);
62-
std::cout
63-
<< "tensordot size I0: " << I0.sizes() << ", "
64-
<< "size I1: " << I1.sizes() << " ran in: "
65-
<< std::chrono::duration_cast<std::chrono::microseconds>(duration).count()
66-
<< "us\n";
50+
auto naiveOptions = Backend::MappingOptionsType::makeNaiveMappingOptions();
51+
tc::aten::ATenAutotuner<tc::CudaBackend, tc::autotune::GeneticSearch>
52+
geneticAutotuneATen(tc);
53+
auto bestOption =
54+
geneticAutotuneATen.tune("tensordot", {I0, I1}, {naiveOptions});
55+
56+
// 4. Compile and run the TC with the best option after allocating output
57+
// tensors.
58+
auto pExecutor =
59+
tc::aten::compile<Backend>(tc, "tensordot", {I0, I1}, bestOption[0]);
60+
auto outputs = tc::aten::prepareOutputs(tc, "tensordot", {I0, I1});
61+
auto timings = tc::aten::profile(*pExecutor, {I0, I1}, outputs);
62+
std::cout << "tensordot size I0: " << I0.sizes() << ", "
63+
<< "size I1: " << I1.sizes()
64+
<< " ran in: " << timings.kernelRuntime.toMicroSeconds() << "us\n";
6765
}
6866
```
6967
@@ -76,15 +74,15 @@ for (auto sizes : std::vector<std::pair<at::IntList, at::IntList>>{
7674
{{4, 9, 7, 16, 14}, {4, 7, 3, 16, 14}},
7775
{{8, 5, 11, 10, 10}, {8, 11, 16, 10, 10}},
7876
}) {
79-
at::Tensor I0 = at::CUDA(at::kFloat).rand(sizes.first);
80-
at::Tensor I1 = at::CUDA(at::kFloat).rand(sizes.second);
81-
auto handle = atCompl.compile("tensordot", {I0, I1}, bestOption.getValue());
82-
std::vector<at::Tensor> outputs;
83-
auto duration = atCompl.run("tensordot", {I0, I1}, outputs, handle, true);
77+
at::Tensor I0 = makeATenTensor<Backend>(sizes.first);
78+
at::Tensor I1 = makeATenTensor<Backend>(sizes.second);
79+
auto pExecutor =
80+
tc::aten::compile<Backend>(tc, "tensordot", {I0, I1}, bestOption[0]);
81+
auto outputs = tc::aten::prepareOutputs(tc, "tensordot", {I0, I1});
82+
auto timings = tc::aten::profile(*pExecutor, {I0, I1}, outputs);
8483
std::cout << "tensordot size I0: " << I0.sizes() << ", "
85-
<< "size I1: " << I1.sizes() << " ran in: "
86-
<< std::chrono::duration_cast<std::chrono::microseconds>(duration)
87-
.count()
84+
<< "size I1: " << I1.sizes()
85+
<< " ran in: " << timings.kernelRuntime.toMicroSeconds()
8886
<< "us\n";
8987
}
9088
```
@@ -96,11 +94,9 @@ Putting it all together, one may see:
9694
[----------] Global test environment set-up.
9795
[----------] 1 test from TensorDot
9896
[ RUN ] TensorDot.SimpleAutotune
99-
Loading proto from: /tmp/save_results.options and /tmp/save_results.cuda
10097
Generation 0 Jobs(Compiled, GPU)/total (10, 10)/10 (best/median/worst)us: 226/4238/7345
10198
Generation 1 Jobs(Compiled, GPU)/total (10, 10)/10 (best/median/worst)us: 220/221/233
10299
Generation 2 Jobs(Compiled, GPU)/total (10, 10)/10 (best/median/worst)us: 220/221/234
103-
Dumping cache to /tmp/save_results.cuda/options
104100
tensordot size I0: [16, 8, 16, 17, 25], size I1: [16, 16, 2, 17, 25] ran in: 239us
105101
tensordot size I0: [4, 9, 7, 16, 14], size I1: [4, 7, 3, 16, 14] ran in: 56us
106102
tensordot size I0: [8, 5, 11, 10, 10], size I1: [8, 11, 16, 10, 10] ran in: 210us
@@ -112,32 +108,6 @@ tensordot size I0: [8, 5, 11, 10, 10], size I1: [8, 11, 16, 10, 10] ran in: 210u
112108
[ PASSED ] 1 test.
113109
```
114110
115-
Tuning results are then available and reusable in ```/tmp/save_results.cuda``` and ```/tmp/save_results.proto```.
116-
117-
Interestingly, note that running the same example again will start from the best saved results and improve upon them.
118-
Of course this has diminishing returns:
119-
```shell
120-
> build$ ./examples/example_simple
121-
[==========] Running 1 test from 1 test case.
122-
[----------] Global test environment set-up.
123-
[----------] 1 test from TensorDot
124-
[ RUN ] TensorDot.SimpleAutotune
125-
Loading proto from: /tmp/save_results.options and /tmp/save_results.cuda
126-
Generation 0 Jobs(Compiled, GPU)/total (10, 10)/10 (best/median/worst)us: 256/258/270
127-
Generation 1 Jobs(Compiled, GPU)/total (10, 10)/10 (best/median/worst)us: 158/255/616
128-
Generation 2 Jobs(Compiled, GPU)/total (10, 10)/10 (best/median/worst)us: 157/252/720
129-
Dumping cache to /tmp/save_results.cuda/options
130-
tensordot size I0: [16, 8, 16, 17, 25], size I1: [16, 16, 2, 17, 25] ran in: 172us
131-
tensordot size I0: [4, 9, 7, 16, 14], size I1: [4, 7, 3, 16, 14] ran in: 44us
132-
tensordot size I0: [8, 5, 11, 10, 10], size I1: [8, 11, 16, 10, 10] ran in: 88us
133-
[ OK ] TensorDot.SimpleAutotune (28232 ms)
134-
[----------] 1 test from TensorDot (28232 ms total)
135-
136-
[----------] Global test environment tear-down
137-
[==========] 1 test from 1 test case ran. (28232 ms total)
138-
[ PASSED ] 1 test.
139-
```
140-
141111
We have not yet characterized the precise fraction of peak performance we obtain but it is not uncommon to obtain 80%+ of peak shared memory bandwidth after autotuning. Solid register-level optimizations are still in the work but TC in its current form already addresses the productivity gap between the needs of research and the needs of production. Which is why we are excited to share it with the entire community and bring this collaborative effort in the open.
142112
143113
# Documentation

tc/aten/aten_autotuner-inl.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ std::vector<typename Backend::MappingOptionsType>
5252
ATenAutotuner<Backend, Search>::tune(
5353
const std::string& tcName,
5454
const std::vector<at::Tensor>& inputs,
55-
const typename Backend::MappingOptionsType& baseMapping,
56-
const std::string& cacheFileName,
55+
const std::vector<typename Backend::MappingOptionsType>& baseMappings,
5756
const tc::autotune::TuningParameterFixer& fixedParams) {
5857
// TODO: some checks that inputs memory lives on the proper Backend device
5958

@@ -91,8 +90,7 @@ ATenAutotuner<Backend, Search>::tune(
9190
tcName,
9291
rawInputsPerDevice,
9392
rawOutputsPerDevice,
94-
baseMapping,
95-
cacheFileName,
93+
baseMappings,
9694
fixedParams);
9795
}
9896
} // namespace aten

tc/aten/aten_autotuner.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ class ATenAutotuner : public tc::autotune::Autotuner<Backend, SearchStrategy> {
7979
std::vector<MappingOptionsType> tune(
8080
const std::string& tcEntryPoint,
8181
const std::vector<at::Tensor>& inputs,
82-
const MappingOptionsType& baseMapping,
83-
const std::string& cacheFileName = "",
82+
const std::vector<MappingOptionsType>& baseMappings,
8483
const tc::autotune::TuningParameterFixer& fixedParams = {});
8584

8685
protected:

tc/autotuner/autotuner-inl.h

Lines changed: 8 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -321,45 +321,6 @@ namespace {
321321
volatile std::sig_atomic_t sigint_ = 0;
322322
volatile std::sig_atomic_t sigterm_ = 0;
323323

324-
template <typename Backend>
325-
std::vector<typename Backend::MappingOptionsType> loadThroughCache(
326-
lang::TreeRef tree,
327-
std::shared_ptr<OptionsCache<Backend>> optionsCache,
328-
const std::string& cacheFileName,
329-
const std::vector<const DLConstTensor*>& inputs,
330-
const size_t numCandidates) {
331-
LOG_IF(INFO, FLAGS_debug_tuner)
332-
<< "Loading proto from: " << tc::makeOptionsFilename(cacheFileName)
333-
<< std::endl;
334-
if (!cacheFileName.empty()) {
335-
optionsCache->loadCacheFromFile(tc::makeOptionsFilename(cacheFileName));
336-
}
337-
auto outputs = tc::detail::inferOutputTensorInfo(tree, inputs);
338-
return optionsCache->getTopKOptions(
339-
canonicalTc(tree),
340-
makeTensorInfoVector(inputs),
341-
outputs,
342-
Backend::backendString(),
343-
numCandidates);
344-
}
345-
346-
template <typename Backend>
347-
void storeTopKInCache(
348-
const std::shared_ptr<OptionsCache<Backend>>& optionsCache,
349-
const std::string& cacheFilename) {
350-
if (cacheFilename.empty()) {
351-
LOG_IF(INFO, FLAGS_debug_tuner)
352-
<< "No filepath provided, not saving cache" << std::endl;
353-
} else {
354-
LOG_IF(INFO, FLAGS_debug_tuner)
355-
<< "Dumping cache to " << tc::makeOptionsFilename(cacheFilename)
356-
<< std::endl;
357-
OptionsCache<Backend> cache(*optionsCache);
358-
cache.pruneKeepTopK(tc::FLAGS_tuner_save_best_candidates_count);
359-
cache.storeCacheToFile(tc::makeOptionsFilename(cacheFilename));
360-
}
361-
}
362-
363324
void removeDuplicates(std::vector<size_t>& v) {
364325
std::sort(v.begin(), v.end());
365326
v.erase(std::unique(v.begin(), v.end()), v.end());
@@ -416,7 +377,7 @@ void setupTuningParameters(
416377

417378
template <typename Backend, typename SearchStrategy>
418379
Autotuner<Backend, SearchStrategy>::Autotuner()
419-
: optionsCache_(new OptionsCache<Backend>()) {}
380+
: optionsCache(new OptionsCache<Backend>()) {}
420381

421382
template <typename Backend, typename SearchStrategy>
422383
std::vector<typename Backend::MappingOptionsType>
@@ -425,8 +386,7 @@ Autotuner<Backend, SearchStrategy>::tune(
425386
const std::string& tcEntryPoint,
426387
const std::unordered_map<size_t, std::vector<const DLConstTensor*>>& inputs,
427388
std::unordered_map<size_t, std::vector<const DLTensor*>>& outputs,
428-
const typename Backend::MappingOptionsType& baseMapping,
429-
const std::string& cacheFileName,
389+
const std::vector<typename Backend::MappingOptionsType>& baseMappings,
430390
const TuningParameterFixer& fixedParams) {
431391
std::map<std::string, lang::TreeRef> tcEntryPointMap(tc::detail::parse(tc));
432392
TC_CHECK_EQ(tcEntryPointMap.count(tcEntryPoint), 1u)
@@ -438,28 +398,13 @@ Autotuner<Backend, SearchStrategy>::tune(
438398
setupTuningParameters(inputs.begin()->second, modelConfiguration);
439399
modelConfiguration.fixParameters(fixedParams);
440400

441-
// Build starting points from baseMapping + whatever we recover from cache
442-
std::vector<typename Backend::MappingOptionsType> startingPoints{baseMapping};
443-
auto restoredCandidates = loadThroughCache<Backend>(
444-
tcEntryPointMap.at(tcEntryPoint),
445-
optionsCache_,
446-
cacheFileName,
447-
inputs.begin()->second,
448-
FLAGS_tuner_gen_restore_number);
449-
if (restoredCandidates.size() > 0) {
450-
startingPoints.reserve(1 + restoredCandidates.size());
451-
std::move(
452-
restoredCandidates.begin(),
453-
restoredCandidates.end(),
454-
std::back_inserter(startingPoints));
455-
}
456-
457401
// Create initial configs based on options + model configuration
402+
const std::vector<typename Backend::MappingOptionsType> options{baseMappings};
458403
std::vector<TuningConfiguration> configs;
459-
configs.reserve(startingPoints.size());
404+
configs.reserve(options.size());
460405
std::transform(
461-
startingPoints.begin(),
462-
startingPoints.end(),
406+
options.begin(),
407+
options.end(),
463408
std::back_inserter(configs),
464409
[this, &fixedParams, &modelConfiguration](
465410
const typename Backend::MappingOptionsType& options) {
@@ -484,9 +429,9 @@ Autotuner<Backend, SearchStrategy>::tune(
484429
tcEntryPointMap.at(tcEntryPoint),
485430
inputs,
486431
outputs,
487-
baseMapping,
432+
options[0],
488433
fixedParams,
489-
optionsCache_);
434+
optionsCache);
490435

491436
// Setup handlers
492437
sigterm_ = 0;
@@ -505,10 +450,6 @@ Autotuner<Backend, SearchStrategy>::tune(
505450
try {
506451
tuningHarness.run(searchStrategy);
507452
} catch (const std::exception& e) {
508-
std::cerr << "Exception during autotuning: " << e.what()
509-
<< "\n dumping cache to "
510-
<< tc::makeOptionsFilename(cacheFileName) << std::endl;
511-
storeTopKInCache<Backend>(optionsCache_, cacheFileName);
512453
tuningHarnessThreadEx = std::current_exception();
513454
}
514455
tuningHarnessFinished = true;
@@ -517,11 +458,9 @@ Autotuner<Backend, SearchStrategy>::tune(
517458
std::this_thread::sleep_for(std::chrono::milliseconds(100));
518459
if (sigint_) {
519460
tuningHarness.stopAfterCurrentIteration();
520-
storeTopKInCache<Backend>(optionsCache_, cacheFileName);
521461
}
522462
if (sigterm_) {
523463
std::cerr << "Autotuning aborted." << std::endl;
524-
storeTopKInCache<Backend>(optionsCache_, cacheFileName);
525464
std::abort();
526465
}
527466
}
@@ -532,8 +471,6 @@ Autotuner<Backend, SearchStrategy>::tune(
532471
std::rethrow_exception(tuningHarnessThreadEx);
533472
}
534473

535-
storeTopKInCache<Backend>(optionsCache_, cacheFileName);
536-
537474
return {tuningHarness.bestMappingOptions()};
538475
}
539476
} // namespace autotune

tc/autotuner/autotuner.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,14 @@ class Autotuner {
164164
const std::unordered_map<size_t, std::vector<const DLConstTensor*>>&
165165
inputs,
166166
std::unordered_map<size_t, std::vector<const DLTensor*>>& outputs,
167-
const MappingOptionsType& baseMapping,
168-
const std::string& cacheFileName = "",
167+
const std::vector<MappingOptionsType>& baseMapping,
169168
const TuningParameterFixer& fixedParams = TuningParameterFixer());
170169

171-
private:
172-
std::shared_ptr<OptionsCache<Backend>> optionsCache_;
170+
public:
171+
/// This is accessed by multiple threads in the tuning harness.
172+
/// Even though manipulations are threadsafe, you want to be sure tuning
173+
/// has finished before accessing the optionsCache.
174+
std::shared_ptr<OptionsCache<Backend>> optionsCache;
173175
};
174176

175177
/// Helper functions that need specializing for various backends.

tc/autotuner/options_cache-inl.h

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <llvm/ADT/Optional.h>
2828

2929
#include "tc/core/check.h"
30+
#include "tc/core/compiler.h"
3031
#include "tc/core/tensor.h"
3132
#include "tc/core/utils/math.h"
3233
#include "tc/core/utils/time.h"
@@ -163,12 +164,10 @@ void OptionsCache<Backend>::storeCacheToFile(
163164
std::lock_guard<std::mutex> lock(mutex);
164165
std::fstream serialized(
165166
filename, std::ios::binary | std::ios::trunc | std::ios::out);
166-
if (!serialized.is_open()) {
167-
LOG(ERROR) << "Failed to open the output stream for dumping protobuf: "
168-
<< filename;
169-
} else {
170-
proto.SerializePartialToOstream(&serialized);
171-
}
167+
TC_CHECK(serialized.is_open(), std::invalid_argument)
168+
<< "Failed to open the output stream for dumping protobuf: "
169+
<< filename;
170+
proto.SerializePartialToOstream(&serialized);
172171
}
173172
}
174173

@@ -317,9 +316,37 @@ void OptionsCache<Backend>::fromProtobuf(
317316
}
318317
}
319318

320-
} // namespace autotune
319+
template <typename Backend>
320+
std::vector<typename Backend::MappingOptionsType> loadTopKFromCacheFile(
321+
const std::string& tc,
322+
const std::string& entryPoint,
323+
const std::string& cacheFilename,
324+
const std::vector<const DLConstTensor*>& inputs,
325+
size_t count) {
326+
OptionsCache<Backend> optionsCache;
327+
optionsCache.loadCacheFromFile(cacheFilename);
328+
auto outputs = tc::inferOutputTensorInfo(tc, entryPoint, inputs);
329+
return optionsCache.getTopKOptions(
330+
lang::canonicalTc(tc::detail::parse(tc).at(entryPoint)),
331+
tc::makeTensorInfoVector(inputs),
332+
outputs,
333+
Backend::backendString(),
334+
count);
335+
}
321336

322-
inline std::string makeOptionsFilename(const std::string& fn) {
323-
return fn + ".options";
337+
template <typename Backend>
338+
void appendTopKToCacheFile(
339+
const std::shared_ptr<OptionsCache<Backend>>& cache,
340+
const std::string& cacheFilename,
341+
uint32_t count) {
342+
OptionsCache<Backend> copy(*cache);
343+
copy.pruneKeepTopK(count);
344+
auto proto = copy.toProtobuf();
345+
OptionsCache<Backend> optionsCache;
346+
optionsCache.loadCacheFromFile(cacheFilename);
347+
optionsCache.fromProtobuf(proto);
348+
optionsCache.storeCacheToFile(cacheFilename);
324349
}
350+
351+
} // namespace autotune
325352
} // namespace tc

0 commit comments

Comments
 (0)