diff --git a/docs/source/framework/pytorch_integration/autotuning_layers.rst b/docs/source/framework/pytorch_integration/autotuning_layers.rst
index 0aa81dd5f..b9d65f579 100644
--- a/docs/source/framework/pytorch_integration/autotuning_layers.rst
+++ b/docs/source/framework/pytorch_integration/autotuning_layers.rst
@@ -58,7 +58,6 @@ You can read about all the parameters here - :ref:`autotuner_parameters`.
 - :code:`threads` - set this to number of CPU cores available.
 - :code:`generations` - 5 to 10 generations is a good number.
 - :code:`pop_size` - 10 is usually reasonable. You can try 10 to 20.
-- :code:`number_elites` - number of candidates preserved intact between generations. `1` is usually sufficient.
 - :code:`min_launch_total_threads` - If you have really input small sizes, set this to `1`.
 - :code:`gpus`: Number of gpus to use for autotuning. Default value is "0". Set this to "0,1" if you wish to use two gpus (for example).
 
@@ -70,7 +69,7 @@ kernel timing. You can adopt the following parameter settings as starters for au
 .. code::
 
     settings = {
-        "threads": 32, "generations": 2, "pop_size": 10, "number_elites": 1
+        "threads": 32, "generations": 2, "pop_size": 10
     }
 
 * The good defaults that run for a bit longer (in exchange for better performance):
@@ -78,7 +77,7 @@ kernel timing. You can adopt the following parameter settings as starters for au
 .. code::
 
     settings = {
-        "threads": 32, "generations": 5, "pop_size": 10, "number_elites": 1
+        "threads": 32, "generations": 5, "pop_size": 10
     }
 
 
@@ -87,7 +86,7 @@ kernel timing. You can adopt the following parameter settings as starters for au
 .. code::
 
     settings = {
-        "threads": 32, "generations": 25, "pop_size": 100, "number_elites": 10
+        "threads": 32, "generations": 25, "pop_size": 100
     }
 
 
diff --git a/docs/source/tutorials/tutorial_tensordot_with_tc.rst b/docs/source/tutorials/tutorial_tensordot_with_tc.rst
index 6a98695df..52cc3deac 100644
--- a/docs/source/tutorials/tutorial_tensordot_with_tc.rst
+++ b/docs/source/tutorials/tutorial_tensordot_with_tc.rst
@@ -132,7 +132,7 @@ later.
 You can control the amount of autotuning by changing the autotuner parameters. See
 :ref:`autotune_parameters` for how to change the settings.
 
-For the setting ``settings={"generations": 25, "pop_size": 100, "number_elites": 10}``, we
+For the setting ``settings={"generations": 25, "pop_size": 100}``, we
 get a decent kernel performance as shown in the screenshot below (tuned on one M40 GPU):
 
 .. figure:: ../_static/img/autotuning-py.jpg
diff --git a/tc/autotuner/autotuner-inl.h b/tc/autotuner/autotuner-inl.h
index c2792d113..9cfc80525 100644
--- a/tc/autotuner/autotuner-inl.h
+++ b/tc/autotuner/autotuner-inl.h
@@ -79,16 +79,16 @@ TuningHarness<Backend>::bestMappingOptions() const {
 }
 
 template <typename Backend>
-template <typename SearchStrategy>
-void TuningHarness<Backend>::doCompile(SearchStrategy& searchStrategy) {
+template <typename Candidates>
+void TuningHarness<Backend>::doCompile(Candidates& candidates) {
   // Atomically fetch and add the next job until there are no jobs left
   while (true) {
     auto current = currentCompilationJob_.fetch_add(1);
-    if (current >= searchStrategy.population.size()) {
+    if (current >= candidates.size()) {
       break;
     }
     std::unique_ptr<typename Backend::ExecutorType> pExecutor(nullptr);
-    auto pConf = searchStrategy.population.at(current).get();
+    auto pConf = candidates.at(current).get();
     auto options = makeOptions(baseMapping_, *pConf);
     try {
       if (FLAGS_debug_tuner) {
@@ -243,56 +243,76 @@ void TuningHarness<Backend>::runOneIteration(
     size_t iteration) {
   // Define tensors per device once globally
   auto devices = detail::parseDevices(FLAGS_tuner_devices);
-  CHECK(executors_.empty());
-  CHECK(configurations_.empty());
-
-  {
-    // Initialize for this round
-    currentCompilationJob_.store(0);
-    numEvaluations_.store(0);
-    Printer printer(
-        iteration,
-        searchStrategy.population.size(),
-        currentCompilationJob_,
-        numEvaluations_);
-    auto logIterations = FLAGS_tuner_gen_log_generations;
-    ScopeGuard sgPrinter([logIterations, &printer]() {
-      printer.stop();
-      if (logIterations) {
-        printer.printAll();
-      }
-    });
-
-    // Just spawn and join new threads for each iteration
-    std::vector<std::thread> cpuCompilationThreads;
-    cpuCompilationThreads.reserve(FLAGS_tuner_threads);
-    ScopeGuard sgCompilationThreads([&cpuCompilationThreads]() {
-      for (auto& cpuCompilationThread : cpuCompilationThreads) {
-        cpuCompilationThread.join();
-      }
-    });
-    for (size_t i = 0; i < FLAGS_tuner_threads; ++i) {
-      cpuCompilationThreads.emplace_back(
-          [this, &searchStrategy]() { this->doCompile(searchStrategy); });
-    }
+  for (uint64_t step = 0; step < searchStrategy.stepsPerIteration; ++step) {
+    {
+      CHECK(executors_.empty());
+      CHECK(configurations_.empty());
+      auto& candidates = searchStrategy.candidatesOfStep(step);
+      auto firstNew = std::partition(
+          candidates.begin(),
+          candidates.end(),
+          [](const std::unique_ptr<CandidateConfiguration>& c) {
+            return c->runtime != Duration::zero();
+          });
+      GeneticSearch::Population newCandidates(
+          std::distance(firstNew, candidates.end()));
+      std::move(firstNew, candidates.end(), newCandidates.begin());
+      ScopeGuard candidatesSG([&]() {
+        std::move(newCandidates.begin(), newCandidates.end(), firstNew);
+      });
 
-    // Just spawn and join new threads for each device
-    std::vector<std::thread> workerThreads;
-    workerThreads.reserve(devices.size());
-    LOG_IF(INFO, tc::FLAGS_debug_tuner)
-        << "Start evaluation: " << devices.size() << " " << executors_.size()
-        << " " << configurations_.size();
-    ScopeGuard sgDeviceWorkerThreads([&workerThreads]() {
-      for (auto& workerThread : workerThreads) {
-        workerThread.join();
+      if (not newCandidates.empty()) {
+        auto populationSize = newCandidates.size();
+        // Initialize for this round
+        currentCompilationJob_.store(0);
+        numEvaluations_.store(0);
+        Printer printer(
+            iteration,
+            step,
+            populationSize,
+            currentCompilationJob_,
+            numEvaluations_);
+        auto logIterations = FLAGS_tuner_gen_log_generations;
+        ScopeGuard sgPrinter([logIterations, &printer]() {
+          printer.stop();
+          if (logIterations) {
+            printer.printAll();
+          }
+        });
+
+        // Just spawn and join new threads for each iteration
+        std::vector<std::thread> cpuCompilationThreads;
+        cpuCompilationThreads.reserve(FLAGS_tuner_threads);
+        ScopeGuard sgCompilationThreads([&cpuCompilationThreads]() {
+          for (auto& cpuCompilationThread : cpuCompilationThreads) {
+            cpuCompilationThread.join();
+          }
+        });
+        for (size_t i = 0; i < FLAGS_tuner_threads; ++i) {
+          cpuCompilationThreads.emplace_back(
+              [this, &newCandidates]() { this->doCompile(newCandidates); });
+        }
+
+        // Just spawn and join new threads for each device
+        std::vector<std::thread> workerThreads;
+        workerThreads.reserve(devices.size());
+        LOG_IF(INFO, tc::FLAGS_debug_tuner)
+            << "Start evaluation: " << devices.size() << " "
+            << executors_.size() << " " << configurations_.size();
+        ScopeGuard sgDeviceWorkerThreads([&workerThreads]() {
+          for (auto& workerThread : workerThreads) {
+            workerThread.join();
+          }
+        });
+        for (auto device : devices) {
+          workerThreads.emplace_back(
+              [this, device, populationSize, &printer]() {
+                this->doEvaluate(device, populationSize, printer);
+              });
+        }
       }
-    });
-    auto populationSize = searchStrategy.population.size();
-    for (auto device : devices) {
-      workerThreads.emplace_back([this, device, populationSize, &printer]() {
-        this->doEvaluate(device, populationSize, printer);
-      });
     }
+    searchStrategy.finishStep(step);
   }
 
   // At this point everything is synchronized because out of scope, done
@@ -303,7 +323,6 @@ void TuningHarness<Backend>::runOneIteration(
     infoPrinter << bestMappingOptions();
     LOG_LINE_BY_LINE(INFO, ssInfo);
   }
-  searchStrategy.updateParameters();
 }
 
 } // namespace detail
@@ -460,13 +479,15 @@ Autotuner<Backend, SearchStrategy>::tune(
   });
 
   // searchStrategy is passed to tuningHarness.run()
+  // XXX: this not generic
   SearchStrategy searchStrategy(
       configs,
       FLAGS_tuner_gen_generations,
       FLAGS_tuner_gen_pop_size,
       FLAGS_tuner_gen_crossover_rate,
       FLAGS_tuner_gen_mutation_rate,
-      FLAGS_tuner_gen_number_elites);
+      FLAGS_tuner_gen_mating_pool_size,
+      FLAGS_tuner_gen_selection_pool_size);
 
   // Create a tuning harness
   detail::TuningHarness<Backend> tuningHarness(
diff --git a/tc/autotuner/genetic_search.cc b/tc/autotuner/genetic_search.cc
index 416528e88..a74c0b31c 100644
--- a/tc/autotuner/genetic_search.cc
+++ b/tc/autotuner/genetic_search.cc
@@ -16,6 +16,8 @@
 
 #include "tc/autotuner/genetic_search.h"
 
+#include <algorithm>
+#include <cmath>
 #include <numeric>
 #include <random>
 
@@ -31,11 +33,8 @@ void randomizeParameter(Parameter& param, RNG& rng) {
   param.selectOption(paramIndex);
 }
 
-template <typename RNG>
-void randomizePopulation(
-    GeneticSearch::Population::iterator begin,
-    GeneticSearch::Population::iterator end,
-    RNG& rng) {
+template <typename Iterator, typename RNG>
+void randomizePopulation(Iterator begin, Iterator end, RNG& rng) {
   for (auto candidate = begin; candidate != end; ++candidate) {
     auto& conf = (*candidate)->configuration;
     do {
@@ -72,9 +71,35 @@ void mutate(
   }
 }
 
-void normalizeVector(std::vector<double>& v) {
+double mean(std::vector<double>& v) {
+  if (v.empty()) {
+    throw std::invalid_argument("Cannot compute the mean of an empty vector.");
+  }
   auto sum = std::accumulate(v.begin(), v.end(), 0.0);
+  return sum / v.size();
+}
+
+double stdv(std::vector<double>& v, double mean) {
+  std::vector<double> diffs(v.size());
+  std::transform(v.begin(), v.end(), diffs.begin(), [mean](double val) {
+    return val - mean;
+  });
+
+  auto squareSum =
+      std::inner_product(diffs.begin(), diffs.end(), diffs.begin(), 0.0);
+  return std::sqrt(squareSum / v.size());
+}
+
+void sigmaScale(std::vector<double>& v) {
+  auto m = mean(v);
+  auto s = stdv(v, m);
+  std::transform(v.begin(), v.end(), v.begin(), [m, s](double val) {
+    return std::max(val - (m - 2 * s), 0.0);
+  });
+}
+
+void normalizeVector(std::vector<double>& v) {
+  auto sum = std::accumulate(v.begin(), v.end(), 0.0);
   std::transform(
       v.begin(), v.end(), v.begin(), [sum](double v) { return v / sum; });
 }
@@ -90,6 +115,7 @@ std::vector<double> computeNormalizedFitness(
       [](const std::unique_ptr<CandidateConfiguration>& c) {
         return 1.0 / c->runtime.toMicroSeconds();
       });
+  sigmaScale(fitness);
   normalizeVector(fitness);
   return fitness;
 }
@@ -131,7 +157,8 @@ void dropInvalidConfigurations(GeneticSearch::Population& population) {
 } // namespace
 
 #define VALIDATE()                                   \
-  CHECK_LT(numberElites, maxPopulationSize);         \
+  CHECK_LT(maxPopulationSize, matingPoolSize);       \
+  CHECK_LT(maxPopulationSize, selectionPoolSize);    \
   CHECK(mutationRate >= 0 and mutationRate <= 100)   \
       << "the mutation rate (" << mutationRate       \
       << ") should be in the [0,100] interval";      \
@@ -159,14 +186,16 @@ GeneticSearch::GeneticSearch(
     size_t populationSize,
     uint8_t crossOverRate,
     uint8_t mutationRate,
-    size_t numberElites)
+    size_t matingPoolSize,
+    size_t selectionPoolSize)
     : population(),
       lastBestConf(confs[0]),
       numGenerations(numGenerations),
       maxPopulationSize(populationSize),
+      matingPoolSize(matingPoolSize),
+      selectionPoolSize(selectionPoolSize),
      crossOverRate(crossOverRate),
       mutationRate(mutationRate),
-      numberElites(numberElites),
       rng{std::random_device{}()} {
   restoreRngState(rng);
   VALIDATE();
@@ -223,19 +252,33 @@ TuningConfiguration GeneticSearch::crossover(
   return a;
 }
 
-void GeneticSearch::breed() {
-  auto accFitness = computeAccumulatedFitness(population);
-  Population new_population;
-  new_population.reserve(maxPopulationSize);
-  for (auto& p : population) {
-    new_population.push_back(
-        make_unique<CandidateConfiguration>(p->configuration));
+std::vector<TuningConfiguration> GeneticSearch::stochasticUniversalSampling(
+    const std::vector<double>& fitness) const {
+  std::vector<TuningConfiguration> matingPool;
+  matingPool.reserve(matingPoolSize);
+
+  auto r = std::uniform_real_distribution<double>(0, 1.0 / matingPoolSize)(rng);
+  size_t count = 0;
+  size_t i = 0;
+  while (count < matingPoolSize) {
+    while (r <= fitness[i]) {
+      matingPool.push_back(population[i]->configuration);
+      r += 1.0 / matingPoolSize;
+      ++count;
+    }
+    ++i;
   }
+  return matingPool;
+}
+
+void GeneticSearch::breed() {
+  auto matingPool =
+      stochasticUniversalSampling(computeAccumulatedFitness(population));
 
-  auto select = [&]() -> const TuningConfiguration& {
-    auto limit = std::uniform_real_distribution<double>{}(rng);
-    auto lb = std::lower_bound(accFitness.begin(), accFitness.end(), limit);
-    return population.at(std::distance(accFitness.begin(), lb))->configuration;
+  auto select = [&]() -> TuningConfiguration& {
+    auto idx = std::uniform_int_distribution<size_t>{
+        size_t(0), matingPool.size() - 1}(rng);
+    return matingPool.at(idx);
   };
   auto shouldCrossOver = [&]() -> bool {
     /*
@@ -247,25 +290,44 @@ void GeneticSearch::breed() {
     return dist(rng);
   };
 
-  while (new_population.size() < maxPopulationSize) {
+  while (selectionPool.size() < selectionPoolSize) {
     if (shouldCrossOver()) {
       auto parent1 = select();
       auto parent2 = select();
       auto parent3 = select();
-      new_population.emplace_back(make_unique<CandidateConfiguration>(
+      selectionPool.emplace_back(make_unique<CandidateConfiguration>(
           crossover(parent1, parent2, parent3)));
     } else {
-      new_population.emplace_back(
-          make_unique<CandidateConfiguration>(select()));
+      selectionPool.emplace_back(make_unique<CandidateConfiguration>(select()));
     }
   }
-  population = std::move(new_population);
 }
 
-void GeneticSearch::updateParameters() {
-  dropInvalidConfigurations(population);
+bool GeneticSearch::resetPopulationIfNotEnoughCandidates() {
+  if (population.size() < minCandidatesForBreeding) {
+    LOG_IF(ERROR, FLAGS_debug_tuner)
+        << population.size() << " out of " << maxPopulationSize
+        << " candidates were valid and are not enough to form a new "
+           "generation. Likely, most of the tuning runs during this "
+           "generation were pruned for lack of parallelism in the "
+           "generated code. You can relax this constraints by setting "
+           "--tuner_min_launch_total_threads=1. This is mostly relevant "
+           "when autotuning a TC operating on small tensors. The next "
+           "generation will be randomly initialized.";
+    selectionPool.clear();
+    for (size_t i = 0; i < selectionPoolSize; ++i) {
+      selectionPool.emplace_back(
+          make_unique<CandidateConfiguration>(lastBestConf));
+    }
+    // Don't lose the first one which was the best from before
+    randomizePopulation(selectionPool.begin() + 1, selectionPool.end(), rng);
+    return true;
+  }
+  return false;
+}
 
-  // Sort population before taking any decision
+namespace {
+void sortByRuntime(GeneticSearch::Population& population) {
   std::sort(
       population.begin(),
       population.end(),
@@ -275,38 +337,71 @@ void GeneticSearch::updateParameters() {
         checkRuntimeRecorded(a->runtime);
         checkRuntimeRecorded(b->runtime);
         return a->runtime < b->runtime;
       });
+}
+} // namespace
 
-  // Update failsafe lastBestConf
+void GeneticSearch::generateSelectionPool() {
+  dropInvalidConfigurations(population);
+  sortByRuntime(population);
   lastBestConf =
       population.size() > 0 ? population.front()->configuration : lastBestConf;
+  if (resetPopulationIfNotEnoughCandidates()) {
+    return;
+  }
+  selectionPool.clear();
+  selectionPool.emplace_back(make_unique<CandidateConfiguration>(lastBestConf));
+  breed();
+  for (size_t i = 1; i < selectionPool.size(); ++i) {
+    mutate(*selectionPool[i], mutationRate, mutateIterations, rng);
+  }
+}
 
-  if (population.size() < minCandidatesForBreeding) {
-    LOG_IF(ERROR, FLAGS_debug_tuner)
-        << population.size() << " out of " << maxPopulationSize
-        << " candidates were valid and are not enough to form a new "
-           "generation. Likely, most of the tuning runs during this "
-           "generation were pruned for lack of parallelism in the "
-           "generated code. You can relax this constraints by setting "
-           "--tuner_min_launch_total_threads=1. This is mostly relevant "
-           "when autotuning a TC operating on small tensors. The next "
-           "generation will be randomly initialized.";
-    population.resize(0);
-    for (size_t i = 0; i < maxPopulationSize; ++i) {
-      population.emplace_back(
+void GeneticSearch::selectSurvivors() {
+  dropInvalidConfigurations(selectionPool);
+  sortByRuntime(selectionPool);
+  population.clear();
+  std::transform(
+      selectionPool.begin(),
+      selectionPool.begin() + std::min(selectionPool.size(), maxPopulationSize),
+      std::back_inserter(population),
+      [](const std::unique_ptr<CandidateConfiguration>& c) {
+        CHECK(c);
+        return make_unique<CandidateConfiguration>(*c);
+      });
+
+  if (selectionPool.size() < maxPopulationSize) {
+    auto numberMissing = maxPopulationSize - selectionPool.size();
+
+    for (size_t i = 0; i < numberMissing; ++i) {
+      selectionPool.emplace_back(
           make_unique<CandidateConfiguration>(lastBestConf));
     }
-    // Don't lose the first one which was the best from before
-    CHECK_LT(0u, population.size());
-    randomizePopulation(population.begin() + 1, population.end(), rng);
-    return;
+    randomizePopulation(
+        selectionPool.rbegin(), selectionPool.rbegin() + numberMissing, rng);
   }
+}
 
-  breed();
-  for (size_t i = numberElites; i < population.size(); ++i) {
-    mutate(*population[i], mutationRate, mutateIterations, rng);
+GeneticSearch::Population& GeneticSearch::candidatesOfStep(uint64_t step) {
+  if (step > 1) {
+    throw std::invalid_argument("GeneticSearch has only 2 steps.");
+  }
+  if (step == 0) {
+    return population;
+  } else {
+    return selectionPool;
   }
 }
 
+void GeneticSearch::finishStep(uint64_t step) {
+  if (step > 1) {
+    throw std::invalid_argument("GeneticSearch has only 2 steps.");
+  }
+  if (step == 0) {
+    generateSelectionPool();
+  } else {
+    selectSurvivors();
+  }
+}
+
 } // namespace autotune
 } // namespace tc
diff --git a/tc/autotuner/genetic_search.h b/tc/autotuner/genetic_search.h
index 5c5f7f02d..e139ae8e5 100644
--- a/tc/autotuner/genetic_search.h
+++ b/tc/autotuner/genetic_search.h
@@ -74,13 +74,20 @@ class GeneticSearch {
       size_t populationSize,
       uint8_t crossOverRate,
       uint8_t mutationRate,
-      size_t numberElites);
+      size_t matingPoolSize,
+      size_t selectionPoolSize);
 
-  void updateParameters();
+  void generateSelectionPool();
+  void selectSurvivors();
 
  private:
+  std::vector<TuningConfiguration> stochasticUniversalSampling(
+      const std::vector<double>& fitness) const;
+
   void breed();
 
+  bool resetPopulationIfNotEnoughCandidates();
+
   TuningConfiguration crossover(
       TuningConfiguration&,
       TuningConfiguration&,
@@ -92,13 +99,19 @@ class GeneticSearch {
   using Population = std::vector<std::unique_ptr<CandidateConfiguration>>;
 
+  Population& candidatesOfStep(uint64_t);
+  void finishStep(uint64_t);
+
   Population population;
+  Population selectionPool;
   TuningConfiguration lastBestConf;
   const size_t numGenerations;
   const size_t maxPopulationSize;
+  const size_t matingPoolSize;
+  const size_t selectionPoolSize;
   const uint8_t crossOverRate;
   const uint8_t mutationRate;
-  const size_t numberElites;
+  const size_t stepsPerIteration = 2;
 
   /*
    * c++11 seeding is (apparently) not of the highest quality:
diff --git a/tc/autotuner/utils.cc b/tc/autotuner/utils.cc
index 147db9690..9cfbf69b6 100644
--- a/tc/autotuner/utils.cc
+++ b/tc/autotuner/utils.cc
@@ -64,7 +64,7 @@ void Printer::printLoop() {
     std::this_thread::sleep_for(std::chrono::seconds(1));
 
     std::stringstream ss;
-    ss << "Iteration " << iteration_;
+    ss << "Iteration.Step " << iteration_ << '.' << step_;
     ss << "\tJobs(Compiled, Evaluated)/total ("
        << std::min(total_, currentCompilationJob_.load()) << ", "
       << std::min(total_, numEvaluations_.load()) << ")/" << total_;
@@ -100,10 +100,12 @@ void Printer::printLoop() {
 
 Printer::Printer(
     size_t iteration,
+    size_t step,
     size_t total,
     const std::atomic_size_t& currentCompilationJob,
     const std::atomic_size_t& numEvaluations)
     : iteration_(iteration),
+      step_(step),
      printerThread_([this]() { printLoop(); }),
       total_(total),
       currentCompilationJob_(currentCompilationJob),
diff --git a/tc/autotuner/utils.h b/tc/autotuner/utils.h
index b1ea673b8..824f9496a 100644
--- a/tc/autotuner/utils.h
+++ b/tc/autotuner/utils.h
@@ -47,6 +47,7 @@ class Printer {
  public:
   Printer(
       size_t iteration,
+      size_t step,
      size_t total,
       const std::atomic_size_t& currentCompilationJob,
       const std::atomic_size_t& numEvaluations);
@@ -61,6 +62,7 @@ class Printer {
   void printLoop();
 
   size_t iteration_;
+  size_t step_;
   std::vector<Duration> runtimes_;
   mutable std::mutex runtimesMtx_;
diff --git a/tc/core/flags.cc b/tc/core/flags.cc
index 351a431b0..30f0d6882 100644
--- a/tc/core/flags.cc
+++ b/tc/core/flags.cc
@@ -63,6 +63,17 @@ DEFINE_uint32(
     tuner_gen_pop_size,
     100,
     "Population size for genetic autotuning");
+
+DEFINE_uint32(
+    tuner_gen_mating_pool_size,
+    300,
+    "Mating pool size for genetic autotuning");
+
+DEFINE_uint32(
+    tuner_gen_selection_pool_size,
+    300,
+    "Selection pool size for genetic autotuning");
+
 DEFINE_uint32(
     tuner_gen_crossover_rate,
     80,
@@ -75,10 +86,6 @@ DEFINE_uint32(
     tuner_gen_generations,
     25,
     "How many generations to run genetic tuning for");
-DEFINE_uint32(
-    tuner_gen_number_elites,
-    10,
-    "The number of best candidates that are preserved intact between generations");
 DEFINE_uint32(tuner_threads, 1, "Number of CPU threads to use when autotuning");
 DEFINE_string(
     tuner_devices,
diff --git a/tc/core/flags.h b/tc/core/flags.h
index a4df4a74d..a486dc89b 100644
--- a/tc/core/flags.h
+++ b/tc/core/flags.h
@@ -41,10 +41,11 @@ DECLARE_uint32(benchmark_iterations);
 // Used in autotuning
 DECLARE_uint32(tuner_max_unroll_size);
 DECLARE_uint32(tuner_gen_pop_size);
+DECLARE_uint32(tuner_gen_mating_pool_size);
+DECLARE_uint32(tuner_gen_selection_pool_size);
 DECLARE_uint32(tuner_gen_crossover_rate);
 DECLARE_uint32(tuner_gen_mutation_rate);
 DECLARE_uint32(tuner_gen_generations);
-DECLARE_uint32(tuner_gen_number_elites);
 DECLARE_uint32(tuner_threads);
 DECLARE_string(tuner_devices);
 DECLARE_bool(tuner_print_best);
diff --git a/tc/core/utils/time.h b/tc/core/utils/time.h
index 1686363b8..5b5421f69 100644
--- a/tc/core/utils/time.h
+++ b/tc/core/utils/time.h
@@ -70,6 +70,10 @@ struct Duration {
     return lhs.val_ == rhs.val_;
   }
 
+  friend inline bool operator!=(const Duration& lhs, const Duration& rhs) {
+    return lhs.val_ != rhs.val_;
+  }
+
  private:
   std::chrono::microseconds val_;
 };
diff --git a/tc/examples/tensordot.cc b/tc/examples/tensordot.cc
index 4c49806ef..b15c253b1 100644
--- a/tc/examples/tensordot.cc
+++ b/tc/examples/tensordot.cc
@@ -100,7 +100,8 @@ TEST(TensorDotGPU, SimpleAutotune) {
 
 // From root, run with:
 // ./build/examples/tensordot --tuner_threads=10 --tuner_gen_pop_size=10
-// --tuner_gen_generations=3 --tuner_gen_number_elites=4
+// --tuner_gen_generations=3 --tuner_gen_mating_pool_size=20
+// --tuner_gen_selection_pool_size=20
 // --proto_path="/tmp/tensordot"
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
diff --git a/tensor_comprehensions/pybinds/pybind_autotuner.cc b/tensor_comprehensions/pybinds/pybind_autotuner.cc
index 79c949ab8..6002b1be4 100644
--- a/tensor_comprehensions/pybinds/pybind_autotuner.cc
+++ b/tensor_comprehensions/pybinds/pybind_autotuner.cc
@@ -100,9 +100,14 @@ PYBIND11_MODULE(autotuner, m) {
             tc::FLAGS_tuner_gen_generations = generations;
           })
       .def(
-          "number_elites",
-          [](ATenCudaTuner& instance, uint32_t& number_elites) {
-            tc::FLAGS_tuner_gen_number_elites = number_elites;
+          "mating_pool_size",
+          [](ATenCudaTuner& instance, uint32_t& mating_pool_size) {
+            tc::FLAGS_tuner_gen_mating_pool_size = mating_pool_size;
+          })
+      .def(
+          "selection_pool_size",
+          [](ATenCudaTuner& instance, uint32_t& selection_pool_size) {
+            tc::FLAGS_tuner_gen_selection_pool_size = selection_pool_size;
           })
       .def(
           "threads",
diff --git a/tensor_comprehensions/tc_unit.py b/tensor_comprehensions/tc_unit.py
index 1865120c2..cf8df8d2d 100644
--- a/tensor_comprehensions/tc_unit.py
+++ b/tensor_comprehensions/tc_unit.py
@@ -193,6 +193,9 @@ def __init__(self, tc_lang, **kwargs):
     def set_autotuner_parameters(
         self, pop_size=20, crossover_rate=80, mutation_rate=7,
         generations=10, number_elites=1, threads=8, gpus="0", restore_from_proto=False,
+        restore_number=10, log_generations=False
+        mating_pool_size=60, selection_pool_size=60,
+        threads=8, gpus="0", restore_from_proto=False,
         restore_number=10, log_generations=False, save_best_candidates_count=10,
         tuner_min_launch_total_threads=64, **kwargs
     ):
@@ -200,7 +203,8 @@ def set_autotuner_parameters(
         self.autotuner.crossover_rate(crossover_rate)
         self.autotuner.mutation_rate(mutation_rate)
         self.autotuner.generations(generations)
-        self.autotuner.number_elites(number_elites)
+        self.autotuner.mating_pool_size(mating_pool_size)
+        self.autotuner.selection_pool_size(selection_pool_size)
         self.autotuner.threads(threads)
         self.autotuner.gpus(gpus)
         self.autotuner.restore_from_proto(restore_from_proto)
@@ -550,9 +554,6 @@ def autotune(self, *inputs, **kwargs):
             mutation_rate (int):
                 rate at which candidate options are randomly changed (mutated). Default 7
 
-            number_elites (int):
-                number of best candidates that are preserved intact between generations (without any mutations). Default 10
-
             threads (int):
                 The number of threads that are used to compile different candidates in parallel. Default 1
 
diff --git a/test/cuda/test_autotuner.cc b/test/cuda/test_autotuner.cc
index e08f7b52b..cf012e3da 100644
--- a/test/cuda/test_autotuner.cc
+++ b/test/cuda/test_autotuner.cc
@@ -51,7 +51,9 @@ struct ATenCompilationUnitTest : public ::testing::Test {
       tc::FLAGS_tuner_gen_pop_size = 8;
       tc::FLAGS_tuner_gen_generations = 5;
       tc::FLAGS_tuner_threads = std::min(8u, tc::FLAGS_tuner_gen_pop_size);
-      tc::FLAGS_tuner_gen_number_elites = tc::FLAGS_tuner_gen_pop_size / 4;
+      tc::FLAGS_tuner_gen_mating_pool_size = tc::FLAGS_tuner_gen_pop_size * 3;
+      tc::FLAGS_tuner_gen_selection_pool_size =
+          tc::FLAGS_tuner_gen_pop_size * 3;
     }
   }
diff --git a/test_python/common.py b/test_python/common.py
index f6ee68246..1c5409b76 100644
--- a/test_python/common.py
+++ b/test_python/common.py
@@ -36,7 +36,7 @@ def assert_almost_equal(self, diff, inputs, operations, precision=1e-7):
 
     def autotune_store(self, cache_file, lang, tc_name, inputs, tc_type):
         tuner = TcAutotuner(
-            lang, threads=16, pop_size=10, number_elites=1, generations=1
+            lang, threads=16, pop_size=10, generations=1
        )
         best_options = tuner.tune_and_store(
             tc_name, inputs, mapping_options=tc_type, cache_file=cache_file
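
Note for Python callers (illustrative only, not part of the patch): the sketch below mirrors the TcAutotuner call from test_python/common.py and simply supplies the two new pool-size knobs in place of the removed number_elites. Parameter names follow set_autotuner_parameters() in tc_unit.py; the TC definition, the input shapes, the "naive" starting mapping options, the cache path, and the top-level import are assumptions made for the example.

    # Hypothetical migration sketch; assumes TcAutotuner is importable from the
    # top-level package, as the test helpers above use it.
    import torch
    from tensor_comprehensions import TcAutotuner

    lang = """
    def matmul(float(M, K) A, float(K, N) B) -> (C) {
        C(m, n) +=! A(m, k) * B(k, n)
    }
    """
    inputs = [torch.randn(8, 32).cuda(), torch.randn(32, 16).cuda()]

    # mating_pool_size / selection_pool_size replace number_elites; both must be
    # strictly larger than pop_size to pass the new VALIDATE() checks in
    # genetic_search.cc.
    tuner = TcAutotuner(
        lang, threads=16, pop_size=10, generations=1,
        mating_pool_size=30, selection_pool_size=30,
    )
    best_options = tuner.tune_and_store(
        "matmul", inputs, mapping_options="naive", cache_file="/tmp/matmul_cache"
    )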