Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Drop duplicate state in tuner / options cache which resulted in intermittent disagreements between tune and load from cache #576

Merged
merged 3 commits into from
Jul 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/benchmarks/python_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def matmul_bgrad(float(M,N) A, float(M,K) d_C) -> (d_B) {
tuner_config)
cache = MappingOptionsCache(cache_file.name)
top10 = cache.load(mm, "matmul", (A, B), 10)
assert top1.__str__() == top10[0].__str__()
assert top1.__str__() == top10[0].__str__(), (
"Expected {}\nGot {}\nTop10 are {}".format(
top1, top10[0], "\n".join(opt for opt in top10)))

# Compile and run with the new options
compilation_cache.compile("matmul", (A, B), top1)
Expand Down
2 changes: 2 additions & 0 deletions tc/aten/aten_autotuner-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ATenAutotuner<Backend, Search>::tune(
const std::string& tcName,
const std::vector<at::Tensor>& inputs,
const std::vector<typename Backend::MappingOptionsType>& baseMappings,
size_t topK,
const tc::autotune::TuningParameterFixer& fixedParams) {
// TODO: some checks that inputs memory lives on the proper Backend device

Expand Down Expand Up @@ -91,6 +92,7 @@ ATenAutotuner<Backend, Search>::tune(
rawInputsPerDevice,
rawOutputsPerDevice,
baseMappings,
topK,
fixedParams);
}
} // namespace aten
Expand Down
1 change: 1 addition & 0 deletions tc/aten/aten_autotuner.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ATenAutotuner : public tc::autotune::Autotuner<Backend, SearchStrategy> {
const std::string& tcEntryPoint,
const std::vector<at::Tensor>& inputs,
const std::vector<MappingOptionsType>& baseMappings,
size_t topK = 1,
const tc::autotune::TuningParameterFixer& fixedParams = {});

protected:
Expand Down
49 changes: 24 additions & 25 deletions tc/autotuner/autotuner-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#include <atomic>
#include <chrono>
#include <functional>
#include <numeric>
#include <thread>

Expand Down Expand Up @@ -48,8 +49,6 @@ TuningHarness<Backend>::TuningHarness(
baseMapping_(baseMapping),
inputs_(inputs),
outputs_(outputs),
bestTime_(Duration::max()),
bestMappingOptions_(baseMapping),
optionsCache_(optionsCache) {}

template <typename Backend>
Expand All @@ -67,13 +66,6 @@ void TuningHarness<Backend>::stopAfterCurrentIteration() {
stopRequested_ = true;
}

template <typename Backend>
const typename Backend::MappingOptionsType&
TuningHarness<Backend>::bestMappingOptions() const {
  // Accessor for the best candidate recorded so far; the mutex serializes
  // access with the tuning threads that write bestMappingOptions_.
  // NOTE(review): the returned reference outlives the lock — callers must
  // not rely on it staying stable while tuning is still running.
  std::lock_guard<std::mutex> guard(bestTimeMutex_);
  return bestMappingOptions_;
}

#define LOG_LINE_BY_LINE(GSTREAM, ISTREAM) \
for (std::string line; std::getline(ISTREAM, line);) { \
LOG(GSTREAM) << line; \
Expand Down Expand Up @@ -180,11 +172,14 @@ void TuningHarness<Backend>::doEvaluate(

std::vector<Duration> runtimes{Duration::max()};
try {
Duration bestTimeSoFar(Duration::max());
{
std::lock_guard<std::mutex> lock(bestTimeMutex_);
bestTimeSoFar = bestTime_;
}
auto vBest = optionsCache_->getTopKEntries(
lang::canonicalTc(tcTree_),
makeTensorInfoVector(inputs),
makeTensorInfoVector(outputs),
Backend::backendString(),
1);
Duration bestTimeSoFar =
(vBest.size() > 0) ? vBest[0].second : Duration::max();
auto prune = detail::skipExecutionOrWarmup<Backend>(
*pExecutor, outputs, inputs, bestTimeSoFar);
if (prune) {
Expand Down Expand Up @@ -234,15 +229,6 @@ void TuningHarness<Backend>::doEvaluate(
Backend::backendString(),
options,
prof);

// Save best time under lock
{
std::lock_guard<std::mutex> lock(bestTimeMutex_);
if (prof < bestTime_) {
bestTime_ = prof;
bestMappingOptions_ = options;
}
}
} // end while
}

Expand Down Expand Up @@ -310,7 +296,14 @@ void TuningHarness<Backend>::runOneIteration(
LOG(INFO) << "[TUNER][ITERATION LOG] best option so far:";
std::stringstream ssInfo;
typename Backend::MappingOptionsCppPrinter infoPrinter(ssInfo);
infoPrinter << bestMappingOptions();
auto vBest = optionsCache_->getTopKOptions(
lang::canonicalTc(tcTree_),
makeTensorInfoVector(inputs_.begin()->second),
makeTensorInfoVector(outputs_.begin()->second),
Backend::backendString(),
1);
CHECK_GT(vBest.size(), 0);
infoPrinter << vBest[0];
LOG_LINE_BY_LINE(INFO, ssInfo);
}
searchStrategy.updateParameters();
Expand Down Expand Up @@ -426,6 +419,7 @@ Autotuner<Backend, SearchStrategy>::tune(
const std::unordered_map<size_t, std::vector<const DLConstTensor*>>& inputs,
std::unordered_map<size_t, std::vector<const DLTensor*>>& outputs,
const std::vector<typename Backend::MappingOptionsType>& baseMappings,
size_t topK,
const TuningParameterFixer& fixedParams) {
std::map<std::string, lang::TreeRef> tcEntryPointMap(tc::detail::parse(tc));
TC_CHECK_EQ(tcEntryPointMap.count(tcEntryPoint), 1u)
Expand Down Expand Up @@ -511,7 +505,12 @@ Autotuner<Backend, SearchStrategy>::tune(
std::rethrow_exception(tuningHarnessThreadEx);
}

return {tuningHarness.bestMappingOptions()};
return optionsCache->getTopKOptions(
lang::canonicalTc(tcEntryPointMap.at(tcEntryPoint)),
makeTensorInfoVector(inputs.begin()->second),
makeTensorInfoVector(outputs.begin()->second),
Backend::backendString(),
topK);
}
} // namespace autotune
} // namespace tc
9 changes: 1 addition & 8 deletions tc/autotuner/autotuner.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ class TuningHarness {
/// TODO: we should detect when we come from python and exit properly in C++.
void stopAfterCurrentIteration();

/// Under lock, returns the best mapping options found so far
const MappingOptionsType& bestMappingOptions() const;

private:
/// Traverse one iteration of candidates in parallel and evaluate their
/// runtimes
Expand All @@ -92,7 +89,6 @@ class TuningHarness {
/// This way it is easy to implement multi-threaded termination by just
/// taking an atomic counter and pushing/popping the queues under lock until
/// we have evaluated searchStrategy->population.size() compilation results.
mutable std::mutex bestTimeMutex_;
std::mutex executorsMutex_;
std::atomic_bool stopRequested_;
std::atomic_size_t currentCompilationJob_;
Expand All @@ -112,10 +108,6 @@ class TuningHarness {
const std::unordered_map<size_t, std::vector<const DLConstTensor*>> inputs_;
std::unordered_map<size_t, std::vector<const DLTensor*>> outputs_;

// results
Duration bestTime_;
MappingOptionsType bestMappingOptions_;

// backing options cache
std::shared_ptr<OptionsCache<Backend>> optionsCache_;
};
Expand Down Expand Up @@ -165,6 +157,7 @@ class Autotuner {
inputs,
std::unordered_map<size_t, std::vector<const DLTensor*>>& outputs,
const std::vector<MappingOptionsType>& baseMapping,
size_t topK = 1,
const TuningParameterFixer& fixedParams = TuningParameterFixer());

public:
Expand Down
28 changes: 23 additions & 5 deletions tc/autotuner/options_cache-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "tc/core/check.h"
#include "tc/core/compiler.h"
#include "tc/core/functional.h"
#include "tc/core/tensor.h"
#include "tc/core/utils/math.h"
#include "tc/core/utils/time.h"
Expand Down Expand Up @@ -235,8 +236,8 @@ std::vector<OptionsWithMedianAndRuntimes<Backend>> sortedOptions(
} // namespace detail

template <typename Backend>
std::vector<typename Backend::MappingOptionsType>
OptionsCache<Backend>::getTopKOptions(
std::vector<std::pair<typename Backend::MappingOptionsType, Duration>>
OptionsCache<Backend>::getTopKEntries(
const lang::CanonicalTcString& tc,
const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs,
Expand All @@ -249,15 +250,32 @@ OptionsCache<Backend>::getTopKOptions(
if (sorted.size() == 0u) {
return {};
}
std::vector<typename Backend::MappingOptionsType> res;
res.reserve(K);
std::vector<std::pair<typename Backend::MappingOptionsType, Duration>> res;
res.reserve(std::min(K, sorted.size()));
for (size_t i = 0; i < std::min(K, sorted.size()); ++i) {
res.push_back(sorted[i].mappingOptions);
res.push_back(std::make_pair(sorted[i].mappingOptions, sorted[i].median));
}
++numberSuccessfulRetrievals;
return res;
}

template <typename Backend>
std::vector<typename Backend::MappingOptionsType>
OptionsCache<Backend>::getTopKOptions(
    const lang::CanonicalTcString& tc,
    const std::vector<TensorInfo>& inputs,
    const std::vector<TensorInfo>& outputs,
    const std::string& backendStr,
    size_t K) const {
  /// Returns up to K mapping options for the given TC/inputs/outputs/backend,
  /// in the order produced by getTopKEntries (best entry first). The result
  /// may contain fewer than K elements — or be empty — when the cache holds
  /// fewer matching entries.
  // Fetch (options, runtime) pairs, then strip the Duration half. A plain
  // reserve + loop copies each options object exactly once; the previous
  // std::function/Map formulation took the pair by value, copying every
  // cached entry an extra time and adding type-erased call overhead.
  const auto entries = getTopKEntries(tc, inputs, outputs, backendStr, K);
  std::vector<typename Backend::MappingOptionsType> options;
  options.reserve(entries.size());
  for (const auto& entry : entries) {
    options.push_back(entry.first);
  }
  return options;
}

template <typename Backend>
std::unordered_set<OptionsCacheKey, OptionsCacheKeyHash>
OptionsCache<Backend>::getKeys() const {
Expand Down
13 changes: 13 additions & 0 deletions tc/autotuner/options_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,19 @@ struct OptionsCache {
/// particular TC/inputs/outputs/device. Note that the result may be empty
/// (in particular if problem size is small and pruning threshold is too high
/// for the problem size).
/// \returns a vector of pair<mapping options, Duration>
std::vector<std::pair<typename Backend::MappingOptionsType, Duration>>
getTopKEntries(
const lang::CanonicalTcString& tc,
const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs,
const std::string& backendStr,
size_t K) const;

/// Returns the top-K mapping options that have the best median runtime for a
/// particular TC/inputs/outputs/device. Note that the result may be empty
/// (in particular if problem size is small and pruning threshold is too high
/// for the problem size).
/// \returns a vector of mapping options
std::vector<typename Backend::MappingOptionsType> getTopKOptions(
const lang::CanonicalTcString& tc,
Expand Down
2 changes: 1 addition & 1 deletion tc/benchmarks/benchmark_fixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ struct Benchmark : public ::testing::Test {
geneticAutotuneATen(tc);
auto bestOptions = [&]() {
auto options = geneticAutotuneATen.tune(
kernelName, inputs, {baseMapping}, fixedParams);
kernelName, inputs, {baseMapping}, 1, fixedParams);
TC_CHECK_GE(options.size(), 1u) << "Benchmark mode: at least one "
<< "options expected";
return options[0];
Expand Down
2 changes: 0 additions & 2 deletions tc/core/polyhedral/functional.h → tc/core/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <vector>

namespace tc {
namespace polyhedral {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be done in the commit that moves functional.h

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks

namespace functional {

template <typename I>
Expand Down Expand Up @@ -178,5 +177,4 @@ R MapReduce(std::function<R(R, I, bool)> fun, const std::vector<I>& vec) {
}

} // namespace functional
} // namespace polyhedral
} // namespace tc
2 changes: 1 addition & 1 deletion tc/core/polyhedral/cuda/mapped_scop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
#include "tc/core/check.h"
#include "tc/core/cuda/cuda_libraries.h"
#include "tc/core/flags.h"
#include "tc/core/functional.h"
#include "tc/core/gpu.h"
#include "tc/core/polyhedral/cuda/codegen.h"
#include "tc/core/polyhedral/cuda/mapping_types.h"
#include "tc/core/polyhedral/cuda/memory_promotion_heuristic.h"
#include "tc/core/polyhedral/cuda/tighten_launch_bounds.h"
#include "tc/core/polyhedral/exceptions.h"
#include "tc/core/polyhedral/functional.h"
#include "tc/core/polyhedral/schedule_transforms.h"
#include "tc/core/polyhedral/schedule_tree_matcher.h"
#include "tc/core/polyhedral/schedule_utils.h"
Expand Down
2 changes: 1 addition & 1 deletion tc/core/polyhedral/schedule_transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include "tc/core/check.h"
#include "tc/core/constants.h"
#include "tc/core/polyhedral/functional.h"
#include "tc/core/functional.h"
#include "tc/core/polyhedral/mapping_types.h"
#include "tc/core/polyhedral/schedule_tree_elem.h"
#include "tc/core/polyhedral/schedule_tree_matcher.h"
Expand Down
2 changes: 1 addition & 1 deletion tc/core/polyhedral/schedule_transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <unordered_set>
#include <vector>

#include "tc/core/polyhedral/functional.h"
#include "tc/core/functional.h"
#include "tc/core/polyhedral/mapping_types.h"
#include "tc/core/polyhedral/options.h"
#include "tc/core/polyhedral/schedule_tree.h"
Expand Down
2 changes: 1 addition & 1 deletion tc/core/polyhedral/schedule_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

#include "tc/core/check.h"
#include "tc/core/constants.h"
#include "tc/core/polyhedral/functional.h"
#include "tc/core/functional.h"
#include "tc/core/polyhedral/schedule_tree_elem.h"
#include "tc/core/scope_guard.h"
#include "tc/external/isl.h"
Expand Down
2 changes: 1 addition & 1 deletion tc/core/polyhedral/scop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#include <vector>

#include "tc/core/check.h"
#include "tc/core/functional.h"
#include "tc/core/halide2isl.h"
#include "tc/core/polyhedral/body.h"
#include "tc/core/polyhedral/functional.h"
#include "tc/core/polyhedral/memory_promotion.h"
#include "tc/core/polyhedral/schedule_isl_conversion.h"
#include "tc/core/polyhedral/schedule_transforms.h"
Expand Down
1 change: 1 addition & 0 deletions tensor_comprehensions/pybinds/tclib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "tc/core/cuda/cuda_backend.h"
#include "tc/core/cuda/cuda_tc_executor.h"
#include "tc/core/flags.h"
#include "tc/core/functional.h"
#include "tc/core/tensor.h"
#include "tc/lang/canonicalize.h"

Expand Down
3 changes: 2 additions & 1 deletion test/cuda/test_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ struct ATenCompilationUnitTest : public ::testing::Test {
if (FLAGS_no_memory_promotion) {
fix.fixUseSharedMemory(false).fixUsePrivateMemory(false);
}
auto options = geneticAutotuneATen.tune(name, inputs, {baseMapping}, fix);
auto options = geneticAutotuneATen.tune(
name, inputs, {baseMapping}, std::numeric_limits<size_t>::max(), fix);
if (options.size() > 0) {
return options[0];
}
Expand Down
2 changes: 1 addition & 1 deletion test/test_cuda_mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
#include "tc/core/constants.h"
#include "tc/core/cuda/cuda_libraries.h"
#include "tc/core/cuda/cuda_mapping_options.h"
#include "tc/core/functional.h"
#include "tc/core/polyhedral/cuda/codegen.h"
#include "tc/core/polyhedral/cuda/mapped_scop.h"
#include "tc/core/polyhedral/functional.h"
#include "tc/core/polyhedral/mapping_types.h"
#include "tc/core/polyhedral/schedule_isl_conversion.h"
#include "tc/core/polyhedral/schedule_transforms.h"
Expand Down