Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6d24f99

Browse files
Merge pull request #576 from nicolasvasilache/pr/debug
Drop the state duplicated between the tuner and the options cache, which resulted in intermittent disagreements between the result of tuning and the options loaded from the cache
2 parents d7442e5 + 71c013c commit 6d24f99

17 files changed

+77
-49
lines changed

python/benchmarks/python_overhead.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def matmul_bgrad(float(M,N) A, float(M,K) d_C) -> (d_B) {
111111
tuner_config)
112112
cache = MappingOptionsCache(cache_file.name)
113113
top10 = cache.load(mm, "matmul", (A, B), 10)
114-
assert top1.__str__() == top10[0].__str__()
114+
assert top1.__str__() == top10[0].__str__(), (
115+
"Expected {}\nGot {}\nTop10 are {}".format(
116+
top1, top10[0], "\n".join(opt for opt in top10)))
115117

116118
# Compile and run with the new options
117119
compilation_cache.compile("matmul", (A, B), top1)

tc/aten/aten_autotuner-inl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ ATenAutotuner<Backend, Search>::tune(
5353
const std::string& tcName,
5454
const std::vector<at::Tensor>& inputs,
5555
const std::vector<typename Backend::MappingOptionsType>& baseMappings,
56+
size_t topK,
5657
const tc::autotune::TuningParameterFixer& fixedParams) {
5758
// TODO: add checks that the inputs' memory lives on the proper Backend device
5859

@@ -91,6 +92,7 @@ ATenAutotuner<Backend, Search>::tune(
9192
rawInputsPerDevice,
9293
rawOutputsPerDevice,
9394
baseMappings,
95+
topK,
9496
fixedParams);
9597
}
9698
} // namespace aten

tc/aten/aten_autotuner.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class ATenAutotuner : public tc::autotune::Autotuner<Backend, SearchStrategy> {
8080
const std::string& tcEntryPoint,
8181
const std::vector<at::Tensor>& inputs,
8282
const std::vector<MappingOptionsType>& baseMappings,
83+
size_t topK = 1,
8384
const tc::autotune::TuningParameterFixer& fixedParams = {});
8485

8586
protected:

tc/autotuner/autotuner-inl.h

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
#include <atomic>
1717
#include <chrono>
18+
#include <functional>
1819
#include <numeric>
1920
#include <thread>
2021

@@ -48,8 +49,6 @@ TuningHarness<Backend>::TuningHarness(
4849
baseMapping_(baseMapping),
4950
inputs_(inputs),
5051
outputs_(outputs),
51-
bestTime_(Duration::max()),
52-
bestMappingOptions_(baseMapping),
5352
optionsCache_(optionsCache) {}
5453

5554
template <typename Backend>
@@ -67,13 +66,6 @@ void TuningHarness<Backend>::stopAfterCurrentIteration() {
6766
stopRequested_ = true;
6867
}
6968

70-
template <typename Backend>
71-
const typename Backend::MappingOptionsType&
72-
TuningHarness<Backend>::bestMappingOptions() const {
73-
std::lock_guard<std::mutex> lock(bestTimeMutex_);
74-
return bestMappingOptions_;
75-
}
76-
7769
#define LOG_LINE_BY_LINE(GSTREAM, ISTREAM) \
7870
for (std::string line; std::getline(ISTREAM, line);) { \
7971
LOG(GSTREAM) << line; \
@@ -180,11 +172,14 @@ void TuningHarness<Backend>::doEvaluate(
180172

181173
std::vector<Duration> runtimes{Duration::max()};
182174
try {
183-
Duration bestTimeSoFar(Duration::max());
184-
{
185-
std::lock_guard<std::mutex> lock(bestTimeMutex_);
186-
bestTimeSoFar = bestTime_;
187-
}
175+
auto vBest = optionsCache_->getTopKEntries(
176+
lang::canonicalTc(tcTree_),
177+
makeTensorInfoVector(inputs),
178+
makeTensorInfoVector(outputs),
179+
Backend::backendString(),
180+
1);
181+
Duration bestTimeSoFar =
182+
(vBest.size() > 0) ? vBest[0].second : Duration::max();
188183
auto prune = detail::skipExecutionOrWarmup<Backend>(
189184
*pExecutor, outputs, inputs, bestTimeSoFar);
190185
if (prune) {
@@ -234,15 +229,6 @@ void TuningHarness<Backend>::doEvaluate(
234229
Backend::backendString(),
235230
options,
236231
prof);
237-
238-
// Save best time under lock
239-
{
240-
std::lock_guard<std::mutex> lock(bestTimeMutex_);
241-
if (prof < bestTime_) {
242-
bestTime_ = prof;
243-
bestMappingOptions_ = options;
244-
}
245-
}
246232
} // end while
247233
}
248234

@@ -310,7 +296,14 @@ void TuningHarness<Backend>::runOneIteration(
310296
LOG(INFO) << "[TUNER][ITERATION LOG] best option so far:";
311297
std::stringstream ssInfo;
312298
typename Backend::MappingOptionsCppPrinter infoPrinter(ssInfo);
313-
infoPrinter << bestMappingOptions();
299+
auto vBest = optionsCache_->getTopKOptions(
300+
lang::canonicalTc(tcTree_),
301+
makeTensorInfoVector(inputs_.begin()->second),
302+
makeTensorInfoVector(outputs_.begin()->second),
303+
Backend::backendString(),
304+
1);
305+
CHECK_GT(vBest.size(), 0);
306+
infoPrinter << vBest[0];
314307
LOG_LINE_BY_LINE(INFO, ssInfo);
315308
}
316309
searchStrategy.updateParameters();
@@ -426,6 +419,7 @@ Autotuner<Backend, SearchStrategy>::tune(
426419
const std::unordered_map<size_t, std::vector<const DLConstTensor*>>& inputs,
427420
std::unordered_map<size_t, std::vector<const DLTensor*>>& outputs,
428421
const std::vector<typename Backend::MappingOptionsType>& baseMappings,
422+
size_t topK,
429423
const TuningParameterFixer& fixedParams) {
430424
std::map<std::string, lang::TreeRef> tcEntryPointMap(tc::detail::parse(tc));
431425
TC_CHECK_EQ(tcEntryPointMap.count(tcEntryPoint), 1u)
@@ -511,7 +505,12 @@ Autotuner<Backend, SearchStrategy>::tune(
511505
std::rethrow_exception(tuningHarnessThreadEx);
512506
}
513507

514-
return {tuningHarness.bestMappingOptions()};
508+
return optionsCache->getTopKOptions(
509+
lang::canonicalTc(tcEntryPointMap.at(tcEntryPoint)),
510+
makeTensorInfoVector(inputs.begin()->second),
511+
makeTensorInfoVector(outputs.begin()->second),
512+
Backend::backendString(),
513+
topK);
515514
}
516515
} // namespace autotune
517516
} // namespace tc

tc/autotuner/autotuner.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class TuningHarness {
6767
/// TODO: we should detect when we come from python and exit properly in C++.
6868
void stopAfterCurrentIteration();
6969

70-
/// Under lock, returns the best mapping options found so far
71-
const MappingOptionsType& bestMappingOptions() const;
72-
7370
private:
7471
/// Traverse one iteration of candidates in parallel and evaluate their
7572
/// runtimes
@@ -92,7 +89,6 @@ class TuningHarness {
9289
/// This way it is easy to implement multi-threaded termination by just
9390
/// taking an atomic counter and pushing/popping the queues under lock until
9491
/// we have evaluated searchStrategy->population.size() compilation results.
95-
mutable std::mutex bestTimeMutex_;
9692
std::mutex executorsMutex_;
9793
std::atomic_bool stopRequested_;
9894
std::atomic_size_t currentCompilationJob_;
@@ -112,10 +108,6 @@ class TuningHarness {
112108
const std::unordered_map<size_t, std::vector<const DLConstTensor*>> inputs_;
113109
std::unordered_map<size_t, std::vector<const DLTensor*>> outputs_;
114110

115-
// results
116-
Duration bestTime_;
117-
MappingOptionsType bestMappingOptions_;
118-
119111
// backing options cache
120112
std::shared_ptr<OptionsCache<Backend>> optionsCache_;
121113
};
@@ -165,6 +157,7 @@ class Autotuner {
165157
inputs,
166158
std::unordered_map<size_t, std::vector<const DLTensor*>>& outputs,
167159
const std::vector<MappingOptionsType>& baseMapping,
160+
size_t topK = 1,
168161
const TuningParameterFixer& fixedParams = TuningParameterFixer());
169162

170163
public:

tc/autotuner/options_cache-inl.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "tc/core/check.h"
3030
#include "tc/core/compiler.h"
31+
#include "tc/core/functional.h"
3132
#include "tc/core/tensor.h"
3233
#include "tc/core/utils/math.h"
3334
#include "tc/core/utils/time.h"
@@ -235,8 +236,8 @@ std::vector<OptionsWithMedianAndRuntimes<Backend>> sortedOptions(
235236
} // namespace detail
236237

237238
template <typename Backend>
238-
std::vector<typename Backend::MappingOptionsType>
239-
OptionsCache<Backend>::getTopKOptions(
239+
std::vector<std::pair<typename Backend::MappingOptionsType, Duration>>
240+
OptionsCache<Backend>::getTopKEntries(
240241
const lang::CanonicalTcString& tc,
241242
const std::vector<TensorInfo>& inputs,
242243
const std::vector<TensorInfo>& outputs,
@@ -249,15 +250,32 @@ OptionsCache<Backend>::getTopKOptions(
249250
if (sorted.size() == 0u) {
250251
return {};
251252
}
252-
std::vector<typename Backend::MappingOptionsType> res;
253-
res.reserve(K);
253+
std::vector<std::pair<typename Backend::MappingOptionsType, Duration>> res;
254+
res.reserve(std::min(K, sorted.size()));
254255
for (size_t i = 0; i < std::min(K, sorted.size()); ++i) {
255-
res.push_back(sorted[i].mappingOptions);
256+
res.push_back(std::make_pair(sorted[i].mappingOptions, sorted[i].median));
256257
}
257258
++numberSuccessfulRetrievals;
258259
return res;
259260
}
260261

262+
template <typename Backend>
263+
std::vector<typename Backend::MappingOptionsType>
264+
OptionsCache<Backend>::getTopKOptions(
265+
const lang::CanonicalTcString& tc,
266+
const std::vector<TensorInfo>& inputs,
267+
const std::vector<TensorInfo>& outputs,
268+
const std::string& backendStr,
269+
size_t K) const {
270+
auto vBest = getTopKEntries(tc, inputs, outputs, backendStr, K);
271+
using ReturnType = typename Backend::MappingOptionsType;
272+
using ValueType = typename decltype(vBest)::value_type;
273+
std::function<ReturnType(ValueType)> map = [](ValueType in) {
274+
return in.first;
275+
};
276+
return tc::functional::Map(map, vBest);
277+
}
278+
261279
template <typename Backend>
262280
std::unordered_set<OptionsCacheKey, OptionsCacheKeyHash>
263281
OptionsCache<Backend>::getKeys() const {

tc/autotuner/options_cache.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,19 @@ struct OptionsCache {
132132
/// particular TC/inputs/outputs/device. Note that the result may be empty
133133
/// (in particular if problem size is small and pruning threshold is too high
134134
/// for the problem size).
135+
/// \returns a vector of pair<mapping options, Duration>
136+
std::vector<std::pair<typename Backend::MappingOptionsType, Duration>>
137+
getTopKEntries(
138+
const lang::CanonicalTcString& tc,
139+
const std::vector<TensorInfo>& inputs,
140+
const std::vector<TensorInfo>& outputs,
141+
const std::string& backendStr,
142+
size_t K) const;
143+
144+
/// Returns the top-K mapping options that have the best median runtime for a
145+
/// particular TC/inputs/outputs/device. Note that the result may be empty
146+
/// (in particular if problem size is small and pruning threshold is too high
147+
/// for the problem size).
135148
/// \returns a vector of mapping options
136149
std::vector<typename Backend::MappingOptionsType> getTopKOptions(
137150
const lang::CanonicalTcString& tc,

tc/benchmarks/benchmark_fixture.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ struct Benchmark : public ::testing::Test {
188188
geneticAutotuneATen(tc);
189189
auto bestOptions = [&]() {
190190
auto options = geneticAutotuneATen.tune(
191-
kernelName, inputs, {baseMapping}, fixedParams);
191+
kernelName, inputs, {baseMapping}, 1, fixedParams);
192192
TC_CHECK_GE(options.size(), 1u) << "Benchmark mode: at least one "
193193
<< "options expected";
194194
return options[0];

tc/core/polyhedral/functional.h renamed to tc/core/functional.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <vector>
2222

2323
namespace tc {
24-
namespace polyhedral {
2524
namespace functional {
2625

2726
template <typename I>
@@ -178,5 +177,4 @@ R MapReduce(std::function<R(R, I, bool)> fun, const std::vector<I>& vec) {
178177
}
179178

180179
} // namespace functional
181-
} // namespace polyhedral
182180
} // namespace tc

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
#include "tc/core/check.h"
2727
#include "tc/core/cuda/cuda_libraries.h"
2828
#include "tc/core/flags.h"
29+
#include "tc/core/functional.h"
2930
#include "tc/core/gpu.h"
3031
#include "tc/core/polyhedral/cuda/codegen.h"
3132
#include "tc/core/polyhedral/cuda/mapping_types.h"
3233
#include "tc/core/polyhedral/cuda/memory_promotion_heuristic.h"
3334
#include "tc/core/polyhedral/cuda/tighten_launch_bounds.h"
3435
#include "tc/core/polyhedral/exceptions.h"
35-
#include "tc/core/polyhedral/functional.h"
3636
#include "tc/core/polyhedral/schedule_transforms.h"
3737
#include "tc/core/polyhedral/schedule_tree_matcher.h"
3838
#include "tc/core/polyhedral/schedule_utils.h"

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
#include "tc/core/check.h"
3232
#include "tc/core/constants.h"
33-
#include "tc/core/polyhedral/functional.h"
33+
#include "tc/core/functional.h"
3434
#include "tc/core/polyhedral/mapping_types.h"
3535
#include "tc/core/polyhedral/schedule_tree_elem.h"
3636
#include "tc/core/polyhedral/schedule_tree_matcher.h"

tc/core/polyhedral/schedule_transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include <unordered_set>
2222
#include <vector>
2323

24-
#include "tc/core/polyhedral/functional.h"
24+
#include "tc/core/functional.h"
2525
#include "tc/core/polyhedral/mapping_types.h"
2626
#include "tc/core/polyhedral/options.h"
2727
#include "tc/core/polyhedral/schedule_tree.h"

tc/core/polyhedral/schedule_tree.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
#include "tc/core/check.h"
3131
#include "tc/core/constants.h"
32-
#include "tc/core/polyhedral/functional.h"
32+
#include "tc/core/functional.h"
3333
#include "tc/core/polyhedral/schedule_tree_elem.h"
3434
#include "tc/core/scope_guard.h"
3535
#include "tc/external/isl.h"

tc/core/polyhedral/scop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
#include <vector>
2525

2626
#include "tc/core/check.h"
27+
#include "tc/core/functional.h"
2728
#include "tc/core/halide2isl.h"
2829
#include "tc/core/polyhedral/body.h"
29-
#include "tc/core/polyhedral/functional.h"
3030
#include "tc/core/polyhedral/memory_promotion.h"
3131
#include "tc/core/polyhedral/schedule_isl_conversion.h"
3232
#include "tc/core/polyhedral/schedule_transforms.h"

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "tc/core/cuda/cuda_backend.h"
3434
#include "tc/core/cuda/cuda_tc_executor.h"
3535
#include "tc/core/flags.h"
36+
#include "tc/core/functional.h"
3637
#include "tc/core/tensor.h"
3738
#include "tc/lang/canonicalize.h"
3839

test/cuda/test_autotuner.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ struct ATenCompilationUnitTest : public ::testing::Test {
7272
if (FLAGS_no_memory_promotion) {
7373
fix.fixUseSharedMemory(false).fixUsePrivateMemory(false);
7474
}
75-
auto options = geneticAutotuneATen.tune(name, inputs, {baseMapping}, fix);
75+
auto options = geneticAutotuneATen.tune(
76+
name, inputs, {baseMapping}, std::numeric_limits<size_t>::max(), fix);
7677
if (options.size() > 0) {
7778
return options[0];
7879
}

test/test_cuda_mapper.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
#include "tc/core/constants.h"
2828
#include "tc/core/cuda/cuda_libraries.h"
2929
#include "tc/core/cuda/cuda_mapping_options.h"
30+
#include "tc/core/functional.h"
3031
#include "tc/core/polyhedral/cuda/codegen.h"
3132
#include "tc/core/polyhedral/cuda/mapped_scop.h"
32-
#include "tc/core/polyhedral/functional.h"
3333
#include "tc/core/polyhedral/mapping_types.h"
3434
#include "tc/core/polyhedral/schedule_isl_conversion.h"
3535
#include "tc/core/polyhedral/schedule_transforms.h"

0 commit comments

Comments
 (0)