diff --git a/docs/source/framework/pytorch_integration/autotuning_layers.rst b/docs/source/framework/pytorch_integration/autotuning_layers.rst index 4574e4f85..f00ce658b 100644 --- a/docs/source/framework/pytorch_integration/autotuning_layers.rst +++ b/docs/source/framework/pytorch_integration/autotuning_layers.rst @@ -58,7 +58,6 @@ You can read about all the parameters here - :ref:`autotuner_parameters`. - :code:`threads` - set this to number of CPU cores available. - :code:`generations` - 5 to 10 generations is a good number. - :code:`pop_size` - 10 is usually reasonable. You can try 10 to 20. -- :code:`number_elites` - number of candidates preserved intact between generations. `1` is usually sufficient. - :code:`min_launch_total_threads` - If you have really input small sizes, set this to `1`. - :code:`gpus`: Number of gpus to use for autotuning. Default value is "0". Set this to "0,1" if you wish to use two gpus (for example). @@ -70,7 +69,7 @@ kernel timing. You can adopt the following parameter settings as starters for au .. code:: settings = { - "threads": 32, "generations": 2, "pop_size": 10, "number_elites": 1 + "threads": 32, "generations": 2, "pop_size": 10 } * The good defaults that run for a bit longer (in exchange for better performance): @@ -78,7 +77,7 @@ kernel timing. You can adopt the following parameter settings as starters for au .. code:: settings = { - "threads": 32, "generations": 5, "pop_size": 10, "number_elites": 1 + "threads": 32, "generations": 5, "pop_size": 10 } @@ -87,7 +86,7 @@ kernel timing. You can adopt the following parameter settings as starters for au .. code:: settings = { - "threads": 32, "generations": 25, "pop_size": 100, "number_elites": 10 + "threads": 32, "generations": 25, "pop_size": 100 } diff --git a/docs/source/tutorials/tutorial_tensordot_with_tc.rst b/docs/source/tutorials/tutorial_tensordot_with_tc.rst index 6a98695df..52cc3deac 100644 --- a/docs/source/tutorials/tutorial_tensordot_with_tc.rst +++ b/docs/source/tutorials/tutorial_tensordot_with_tc.rst @@ -132,7 +132,7 @@ later. You can control the amount of autotuning by changing the autotuner parameters. See :ref:`autotune_parameters` for how to change the settings. -For the setting ``settings={"generations": 25, "pop_size": 100, "number_elites": 10}``, we +For the setting ``settings={"generations": 25, "pop_size": 100}``, we get a decent kernel performance as shown in the screenshot below (tuned on one M40 GPU): .. figure:: ../_static/img/autotuning-py.jpg diff --git a/tc/autotuner/genetic_autotuner.cc b/tc/autotuner/genetic_autotuner.cc index a4493eef2..bbf53426b 100644 --- a/tc/autotuner/genetic_autotuner.cc +++ b/tc/autotuner/genetic_autotuner.cc @@ -120,7 +120,8 @@ llvm::Optional GeneticAutotuner::tune( FLAGS_tuner_gen_pop_size, FLAGS_tuner_gen_crossover_rate, FLAGS_tuner_gen_mutation_rate, - FLAGS_tuner_gen_number_elites, + FLAGS_tuner_gen_mating_pool_size, + FLAGS_tuner_gen_selection_pool_size, tcNameMap_.at(tcName), tcName, inputs, diff --git a/tc/autotuner/genetic_search.cc b/tc/autotuner/genetic_search.cc index c8cc01127..7d514eeb6 100644 --- a/tc/autotuner/genetic_search.cc +++ b/tc/autotuner/genetic_search.cc @@ -16,9 +16,13 @@ #include "tc/autotuner/genetic_search.h" +#include +#include #include #include +#include "tc/autotuner/utils/utils.h" + namespace tc { namespace autotune { @@ -72,13 +76,6 @@ void mutate( } } -void normalizeVector(std::vector& v) { - auto sum = std::accumulate(v.begin(), v.end(), 0.0); - - std::transform( - v.begin(), v.end(), v.begin(), [sum](double v) { return v / sum; }); -} - std::vector computeNormalizedFitness( const GeneticSearch::Population& population) { std::vector fitness; @@ -92,6 +89,7 @@ std::vector computeNormalizedFitness( std::chrono::duration_cast(c->runtime) .count(); }); + sigmaScale(fitness); normalizeVector(fitness); return fitness; } @@ -133,7 +131,8 @@ void dropInvalidConfigurations(GeneticSearch::Population& population) { } // namespace #define VALIDATE() \ - CHECK_LT(kNumberElites, kMaxPopulationSize); \ + CHECK_LT(kMaxPopulationSize, kMatingPoolSize); \ + CHECK_LT(kMaxPopulationSize, kSelectionPoolSize); \ CHECK(kMutationRate >= 0 and kMutationRate <= 100) \ << "the mutation rate (" << kMutationRate \ << ") should be in the [0,100] interval"; \ @@ -160,13 +159,15 @@ GeneticSearch::GeneticSearch( size_t n, uint8_t crossOverRate, uint8_t mutationRate, - size_t numberElites) + size_t matingPoolSize, + size_t selectionPoolSize) : population(), lastBestConf(confs[0]), kMaxPopulationSize(n), + kMatingPoolSize(matingPoolSize), + kSelectionPoolSize(selectionPoolSize), kCrossOverRate(crossOverRate), kMutationRate(mutationRate), - kNumberElites(numberElites), rng{std::random_device{}()} { restoreRngState(rng); VALIDATE(); @@ -192,13 +193,15 @@ GeneticSearch::GeneticSearch( size_t n, uint8_t crossOverRate, uint8_t mutationRate, - size_t numberElites) + size_t matingPoolSize, + size_t selectionPoolSize) : population(), lastBestConf(conf), kMaxPopulationSize(n), + kMatingPoolSize(matingPoolSize), + kSelectionPoolSize(selectionPoolSize), kCrossOverRate(crossOverRate), kMutationRate(mutationRate), - kNumberElites(numberElites), rng{std::random_device{}()} { restoreRngState(rng); VALIDATE(); @@ -246,19 +249,34 @@ TuningConfiguration GeneticSearch::crossover( return a; } -void GeneticSearch::breed() { - auto accFitness = computeAccumulatedFitness(population); - Population new_population; - new_population.reserve(kMaxPopulationSize); - for (auto& p : population) { - new_population.push_back( - make_unique(p->configuration)); +std::vector GeneticSearch::stochasticUniversalSampling( + const std::vector& fitness) const { + std::vector matingPool; + matingPool.reserve(kMatingPoolSize); + + auto r = + std::uniform_real_distribution(0, 1.0 / kMatingPoolSize)(rng); + size_t count = 0; + size_t i = 0; + while (count < kMatingPoolSize) { + while (r <= fitness[i]) { + matingPool.push_back(population[i]->configuration); + r += 1.0 / kMatingPoolSize; + ++count; + } + ++i; } + return matingPool; +} + +void GeneticSearch::breed() { + auto matingPool = + stochasticUniversalSampling(computeAccumulatedFitness(population)); auto select = [&]() -> TuningConfiguration& { - auto limit = std::uniform_real_distribution{}(rng); - auto lb = std::lower_bound(accFitness.begin(), accFitness.end(), limit); - return population.at(std::distance(accFitness.begin(), lb))->configuration; + auto idx = std::uniform_int_distribution{ + size_t(0), matingPool.size() - 1}(rng); + return matingPool.at(idx); }; auto shouldCrossOver = [&]() -> bool { /* @@ -270,45 +288,20 @@ void GeneticSearch::breed() { return dist(rng); }; - while (new_population.size() < kMaxPopulationSize) { + while (selectionPool.size() < kSelectionPoolSize) { if (shouldCrossOver()) { auto parent1 = select(); auto parent2 = select(); auto parent3 = select(); - new_population.emplace_back(make_unique( + selectionPool.emplace_back(make_unique( crossover(parent1, parent2, parent3))); } else { - new_population.emplace_back( - make_unique(select())); + selectionPool.emplace_back(make_unique(select())); } } - population = std::move(new_population); } -void GeneticSearch::updateParameters() { - dropInvalidConfigurations(population); - - // Sort population before taking any decision - std::sort( - population.begin(), - population.end(), - [](const std::unique_ptr& a, - const std::unique_ptr& b) { - checkRuntimeRecorded(a->runtime); - checkRuntimeRecorded(b->runtime); - return a->runtime < b->runtime; - }); - - // Update failsafe lastBestConf - lastBestConf = - population.size() > 0 ? population.front()->configuration : lastBestConf; - if (FLAGS_tuner_print_best) { - CudaMappingOptions options( - CudaMappingOptions::makeSingleThreadCudaMappingOptions()); - lastBestConf.applyToCudaMappingOptions(options); - LOG(INFO) << "Best so far:\n" << options; - } - +bool GeneticSearch::resetPopulationIfNotEnoughCandidates() { if (population.size() < kMinCandidatesForBreeding) { LOG_IF(ERROR, FLAGS_debug_tuner) << population.size() << " out of " << kMaxPopulationSize @@ -327,12 +320,81 @@ void GeneticSearch::updateParameters() { // Don't lose the first one which was the best from before CHECK_LT(0, population.size()); randomizePopulation(population.begin() + 1, population.end(), rng); - return; + + selectionPool.clear(); + for (size_t i = 0; i < kSelectionPoolSize; ++i) { + selectionPool.emplace_back( + make_unique(lastBestConf)); + } + randomizePopulation(selectionPool.begin() + 1, selectionPool.end(), rng); + return true; + } + return false; +} + +namespace { +void sortByRuntime(GeneticSearch::Population& population) { + std::sort( + population.begin(), + population.end(), + [](const std::unique_ptr& a, + const std::unique_ptr& b) { + checkRuntimeRecorded(a->runtime); + checkRuntimeRecorded(b->runtime); + return a->runtime < b->runtime; + }); +} +} // namespace + +void GeneticSearch::updateBestCandidate(const TuningConfiguration& c) { + lastBestConf = c; + if (FLAGS_tuner_print_best) { + CudaMappingOptions options( + CudaMappingOptions::makeSingleThreadCudaMappingOptions()); + lastBestConf.applyToCudaMappingOptions(options); + LOG(INFO) << "Best so far:\n" << options; } +} +void GeneticSearch::generateSelectionPool() { + dropInvalidConfigurations(population); + sortByRuntime(population); + updateBestCandidate( + population.size() > 0 ? population.front()->configuration : lastBestConf); + if (resetPopulationIfNotEnoughCandidates()) { + return; + } + selectionPool.clear(); + selectionPool.emplace_back(make_unique(lastBestConf)); breed(); - for (int i = kNumberElites; i < population.size(); ++i) { - mutate(*population[i], kMutationRate, kMutateIterations, rng); + for (size_t i = 1; i < selectionPool.size(); ++i) { + mutate(*selectionPool[i], kMutationRate, kMutateIterations, rng); + } +} + +void GeneticSearch::selectSurvivors() { + dropInvalidConfigurations(selectionPool); + sortByRuntime(selectionPool); + population.clear(); + std::transform( + selectionPool.begin(), + selectionPool.begin() + + std::min(selectionPool.size(), kMaxPopulationSize), + std::back_inserter(population), + [](const std::unique_ptr& c) { + CHECK(c); + return make_unique(*c); + }); + + if (selectionPool.size() < kMaxPopulationSize) { + auto numberMissing = kMaxPopulationSize - selectionPool.size(); + + for (size_t i = 0; i < numberMissing; ++i) { + selectionPool.emplace_back( + make_unique(lastBestConf)); + } + randomizePopulation( + selectionPool.end() - numberMissing, selectionPool.end(), rng); } } diff --git a/tc/autotuner/genetic_search.h b/tc/autotuner/genetic_search.h index 048670211..1c89ca5ee 100644 --- a/tc/autotuner/genetic_search.h +++ b/tc/autotuner/genetic_search.h @@ -70,7 +70,8 @@ class GeneticSearch { size_t n, uint8_t crossOverRate, uint8_t mutationRate, - size_t numberElites); + size_t matingPoolSize, + size_t selectionPoolSize); /** * confs are used to seed the first generation, the rest of the population is @@ -92,13 +93,22 @@ class GeneticSearch { size_t n, uint8_t crossOverRate, uint8_t mutationRate, - size_t numberElites); + size_t matingPoolSize, + size_t selectionPoolSize); - void updateParameters(); + void generateSelectionPool(); + void selectSurvivors(); private: + std::vector stochasticUniversalSampling( + const std::vector& fitness) const; + void breed(); + void updateBestCandidate(const TuningConfiguration& c); + + bool resetPopulationIfNotEnoughCandidates(); + TuningConfiguration crossover( TuningConfiguration&, TuningConfiguration&, @@ -111,11 +121,13 @@ class GeneticSearch { using Population = std::vector>; Population population; + Population selectionPool; TuningConfiguration lastBestConf; const size_t kMaxPopulationSize; + const size_t kMatingPoolSize; + const size_t kSelectionPoolSize; const uint8_t kCrossOverRate; const uint8_t kMutationRate; - const size_t kNumberElites; /* * c++11 seeding is (apparently) not of the highest quality: diff --git a/tc/autotuner/genetic_tuning_harness.cc b/tc/autotuner/genetic_tuning_harness.cc index 927adf627..df8922b9d 100644 --- a/tc/autotuner/genetic_tuning_harness.cc +++ b/tc/autotuner/genetic_tuning_harness.cc @@ -44,7 +44,8 @@ GeneticTunerHarness::GeneticTunerHarness( size_t n, uint8_t crossoverRate, uint8_t mutationRate, - size_t numberElites, + size_t matingPoolSize, + size_t selectionPoolSize, lang::TreeRef tc, std::string kernelName, const std::unordered_map>& inputs, @@ -55,7 +56,8 @@ GeneticTunerHarness::GeneticTunerHarness( : kMaxPopulationSize(n), kCrossOverRate(crossoverRate), kMutationRate(mutationRate), - kNumberElites(numberElites), + kMatingPoolSize(matingPoolSize), + kSelectionPoolSize(selectionPoolSize), bestCudaMappingOptions_(baseMapping), kTc_(std::move(tc)), kKernelName_(std::move(kernelName)), @@ -85,14 +87,16 @@ GeneticTunerHarness::GeneticTunerHarness( kMaxPopulationSize, kCrossOverRate, kMutationRate, - kNumberElites); + kMatingPoolSize, + kSelectionPoolSize); } else { tuner_ = make_unique( configuration, kMaxPopulationSize, kCrossOverRate, kMutationRate, - kNumberElites); + kMatingPoolSize, + kSelectionPoolSize); } } @@ -300,15 +304,18 @@ bool GeneticTunerHarness::warmupOrPrune( return false; } -template -void GeneticTunerHarness::doCompile(ExecutorType& engine) { +template +void GeneticTunerHarness::doCompile( + ExecutorType& engine, + Population& population) { // Atomically fetch and add the next job until there are no jobs left while (true) { auto current = currentCompilationJob_.fetch_add(1); - if (current >= tuner_->population.size()) { + if (current >= population.size()) { break; } - auto& pConf = tuner_->population.at(current); + + auto& pConf = population.at(current); auto options = makeOptions(*pConf); try { if (FLAGS_debug_tuner) { @@ -340,10 +347,44 @@ void GeneticTunerHarness::doCompile(ExecutorType& engine) { } } +namespace { +std::vector toConstDlpackTensors( + const std::vector& v) { + std::vector out(v.begin(), v.end()); + return out; +} +} // namespace + template +std::vector retrieveCachedRuntimes( + ExecutorType& engine, + const std::string& id, + const std::vector& inputs, + const std::vector& outputs, + const CudaMappingOptions& options) { + if (not OptionsCache::cacheEnabled()) { + return {}; + } + auto cache = OptionsCache::getCache(); + auto allResults = cache->retrieveOptionsAndRuntimes( + id, inputs, toConstDlpackTensors(outputs)); + auto wantedResult = std::find_if( + allResults.begin(), + allResults.end(), + [&options](const OptionsCache::RetrievalResult& r) { + return r.options == options; + }); + if (wantedResult == allResults.end()) { + return {}; + } + return wantedResult->recordedRuntimes; +} + +template void GeneticTunerHarness::doGpuWork( size_t gpu, ExecutorType& engine, + Population& population, Printer& printer) { WithDevice wd(gpu); CHECK_EQ(1, kInputs_.count(gpu)); @@ -367,7 +408,7 @@ void GeneticTunerHarness::doGpuWork( // Found work to do, increment number of evaluations performed numEvaluations_.fetch_add(1); } else { - if (numEvaluations_.load() >= tuner_->population.size()) { + if (numEvaluations_.load() >= population.size()) { // No more work can arrive, exit return; } @@ -376,7 +417,7 @@ void GeneticTunerHarness::doGpuWork( continue; } - auto& pConf = tuner_->population.at(current); + auto& pConf = population.at(current); if (pConf->invalid) { continue; } @@ -392,51 +433,56 @@ void GeneticTunerHarness::doGpuWork( LOG_LINE_BY_LINE(INFO, ssInfo); } - std::vector runtimes; - try { - size_t bestTimeSoFar; - { - std::lock_guard lock(bestTimeMtx_); - bestTimeSoFar = bestTime_; - } - auto prune = - warmupOrPrune(engine, outputs, inputs, handle, bestTimeSoFar); - if (prune) { + std::vector runtimes = + retrieveCachedRuntimes(engine, kKernelName_, inputs, outputs, options); + if (runtimes.empty()) { + try { + size_t bestTimeSoFar; + { + std::lock_guard lock(bestTimeMtx_); + bestTimeSoFar = bestTime_; + } + auto prune = + warmupOrPrune(engine, outputs, inputs, handle, bestTimeSoFar); + if (prune) { + pConf->invalid = true; + continue; + } else { + runtimes.reserve(kReducedBenchmarkIterations); + for (size_t i = 0; i < kReducedBenchmarkIterations; ++i) { + runtimes.push_back(engine.run(handle, inputs, outputs, true)); + } + engine.clear(handle); + } + } catch (std::exception& e) { + if (FLAGS_debug_tuner) { + LOG(WARNING) << "Runtime error gpu " << gpu << ": " << e.what(); + std::stringstream ssWarning; + CudaMappingOptionsCppPrinter warningPrinter(ssWarning); + warningPrinter << options; + LOG(WARNING) << "Aborted execution on gpu " << gpu; + LOG_LINE_BY_LINE(WARNING, ssWarning); + } + while (cudaGetLastError() != cudaSuccess) { + // In case of errors in the generated, we cannot rely on deviceReset + // to set the GPU in a clean state. So instead we just pop and discard + // all the errors accumulated on the GPU until we get to a clean slate + // (i.e. cudaSuccess). + ; + } + try { + // Some errors, such as illegal memory access, cannot be recovered + // from without a cudaDeviceReset (i.e. because user protection) In + // those cases we have no choice than to fail hard. + TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize()); + } catch (const std::exception& e) { + LOG(FATAL) << "[CUDA][FATAL] cuda error on gpu " << gpu << ": " + << e.what() << "\n" + << CudaMappingOptionsAsCpp(options); + } pConf->invalid = true; continue; - } else { - runtimes.reserve(kReducedBenchmarkIterations); - for (size_t i = 0; i < kReducedBenchmarkIterations; ++i) { - runtimes.push_back(engine.run(handle, inputs, outputs, true)); - } - engine.clear(handle); - } - } catch (std::exception& e) { - LOG(WARNING) << "Runtime error gpu " << gpu << ": " << e.what(); - std::stringstream ssWarning; - CudaMappingOptionsCppPrinter warningPrinter(ssWarning); - warningPrinter << options; - LOG(WARNING) << "Aborted execution on gpu " << gpu; - LOG_LINE_BY_LINE(WARNING, ssWarning); - while (cudaGetLastError() != cudaSuccess) { - // In case of errors in the generated, we cannot rely on deviceReset to - // set the GPU in a clean state. So instead we just pop and discard all - // the errors accumulated on the GPU until we get to a clean slate - // (i.e. cudaSuccess). - ; } - try { - // Some errors, such as illegal memory access, cannot be recovered from - // without a cudaDeviceReset (i.e. because user protection) - // In those cases we have no choice than to fail hard. - TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize()); - } catch (const std::exception& e) { - LOG(FATAL) << "[CUDA][FATAL] cuda error on gpu " << gpu << ": " - << e.what() << "\n" - << CudaMappingOptionsAsCpp(options); - } - pConf->invalid = true; - continue; } auto prof = median(runtimes); @@ -479,57 +525,85 @@ void GeneticTunerHarness::runOneGeneration(size_t generation) { tc::ExecutionEngine engine; engine.define({kTc_}); - { - // Initialize for this round - currentCompilationJob_.store(0); - numEvaluations_.store(0); - readyToEvaluate_.resize(0); - for (int i = 0; i < kMaxPopulationSize; ++i) { - readyToEvaluate_.emplace_back(); - readyToEvaluate_[i].store(false); + auto setUpJobsAndRun = [&](GeneticSearch::Population& population, + const std::string& printerText) { + // Most candidates should have been evaluated during the previous + // generation's selection phase. + // There are two exceptions: + // 1) the 1st generation + // 2) too many invalid configurations were previously encounted and the + // valid ones were not enough to form a new generation. + auto firstNew = std::partition( + population.begin(), + population.end(), + [](const std::unique_ptr& c) { + return c->runtime != Duration::zero(); + }); + if (std::distance(firstNew, population.end()) == 0) { + return; } - Printer printer( - generation, - readyToEvaluate_.size(), - currentCompilationJob_, - numEvaluations_); - auto logGenerations = FLAGS_tuner_gen_log_generations; - ScopeGuard sgPrinter([logGenerations, &printer]() { - printer.stop(); - if (logGenerations) { - printer.printAll(); - } - }); - - // Just spawn and join new threads for each generation - std::vector cpuCompilationThreads; - cpuCompilationThreads.reserve(FLAGS_tuner_threads); - ScopeGuard sgCompilationThreads([&cpuCompilationThreads]() { - for (auto& cpuCompilationThread : cpuCompilationThreads) { - cpuCompilationThread.join(); + GeneticSearch::Population newCandidates( + std::distance(firstNew, population.end())); + std::move(firstNew, population.end(), newCandidates.begin()); + { + // Initialize for this round + currentCompilationJob_.store(0); + numEvaluations_.store(0); + readyToEvaluate_.resize(0); + for (size_t i = 0; i < newCandidates.size(); ++i) { + readyToEvaluate_.emplace_back(); + readyToEvaluate_[i].store(false); } - }); - for (int i = 0; i < FLAGS_tuner_threads; ++i) { - cpuCompilationThreads.emplace_back( - [this, &engine]() { this->doCompile(engine); }); - } + Printer printer( + printerText, + readyToEvaluate_.size(), + currentCompilationJob_, + numEvaluations_); + auto logGenerations = FLAGS_tuner_gen_log_generations; + ScopeGuard sgPrinter([logGenerations, &printer]() { + printer.stop(); + if (logGenerations) { + printer.printAll(); + } + }); - // Just spawn and join new threads for each generation - std::vector gpuWorkerThreads; - gpuWorkerThreads.reserve(gpus.size()); - ScopeGuard sgGpuWorkerThreads([&gpuWorkerThreads]() { - for (auto& gpuWorkerThread : gpuWorkerThreads) { - gpuWorkerThread.join(); + // Just spawn and join new threads for each generation + std::vector cpuCompilationThreads; + cpuCompilationThreads.reserve(FLAGS_tuner_threads); + ScopeGuard sgCompilationThreads([&cpuCompilationThreads]() { + for (auto& cpuCompilationThread : cpuCompilationThreads) { + cpuCompilationThread.join(); + } + }); + for (int i = 0; i < FLAGS_tuner_threads; ++i) { + cpuCompilationThreads.emplace_back([this, &engine, &newCandidates]() { + this->doCompile(engine, newCandidates); + }); } - }); - for (auto gpu : gpus) { - gpuWorkerThreads.emplace_back([this, gpu, &engine, &printer]() { - this->doGpuWork(gpu, engine, printer); + + // Just spawn and join new threads for each generation + std::vector gpuWorkerThreads; + gpuWorkerThreads.reserve(gpus.size()); + ScopeGuard sgGpuWorkerThreads([&gpuWorkerThreads]() { + for (auto& gpuWorkerThread : gpuWorkerThreads) { + gpuWorkerThread.join(); + } }); + for (auto gpu : gpus) { + gpuWorkerThreads.emplace_back( + [this, gpu, &engine, &newCandidates, &printer]() { + this->doGpuWork(gpu, engine, newCandidates, printer); + }); + } } - } - - // At this point everything is synchronized because out of scope, done + // At this point everything is synchronized because out of scope, done + std::move(newCandidates.begin(), newCandidates.end(), firstNew); + }; + std::cout << "Generation " << generation << ':' << std::endl; + setUpJobsAndRun(tuner_->population, "New Candidates"); + tuner_->generateSelectionPool(); + setUpJobsAndRun(tuner_->selectionPool, "Selection Pool"); + tuner_->selectSurvivors(); if (FLAGS_debug_tuner) { LOG(INFO) << "[TUNER][GENERATION LOG] best option so far:"; @@ -538,7 +612,6 @@ void GeneticTunerHarness::runOneGeneration(size_t generation) { infoPrinter << bestMappingOption(); LOG_LINE_BY_LINE(INFO, ssInfo); } - tuner_->updateParameters(); } } // namespace detail diff --git a/tc/autotuner/genetic_tuning_harness.h b/tc/autotuner/genetic_tuning_harness.h index 0c7b476c0..f3e5b3495 100644 --- a/tc/autotuner/genetic_tuning_harness.h +++ b/tc/autotuner/genetic_tuning_harness.h @@ -38,7 +38,8 @@ class GeneticTunerHarness { size_t n, uint8_t crossoverRate, uint8_t mutationRate, - size_t numberElites, + size_t matingPoolSize, + size_t selectionPoolSize, lang::TreeRef tc, std::string kernelName, const std::unordered_map>& inputs, @@ -66,12 +67,16 @@ class GeneticTunerHarness { size_t bestTimeSoFar); /// Helper function to delegate compiling on the cpu to different threads - template - void doCompile(ExecutorType& engine); + template + void doCompile(ExecutorType& engine, Population& population); /// Helper function to delegate running on the gpu to different threads - template - void doGpuWork(size_t gpu, ExecutorType& engine, Printer& printer); + template + void doGpuWork( + size_t gpu, + ExecutorType& engine, + Population& population, + Printer& printer); /// Make options from conf tc::CudaMappingOptions makeOptions(const CandidateConfiguration& conf); @@ -90,7 +95,8 @@ class GeneticTunerHarness { const size_t kMaxPopulationSize; const uint8_t kCrossOverRate; const uint8_t kMutationRate; - const size_t kNumberElites; + const size_t kMatingPoolSize; + const size_t kSelectionPoolSize; TuningConfiguration configuration; diff --git a/tc/autotuner/utils/printer.cc b/tc/autotuner/utils/printer.cc index d6a39c545..3b9d218fe 100644 --- a/tc/autotuner/utils/printer.cc +++ b/tc/autotuner/utils/printer.cc @@ -41,8 +41,7 @@ void Printer::printLoop() { std::this_thread::sleep_for(std::chrono::seconds(1)); std::stringstream ss; - ss << "Generation " << generation_; - ss << "\tJobs(Compiled, GPU)/total (" + ss << prefix_ << "\tJobs(Compiled, GPU)/total (" << std::min(total_, currentCompilationJob_.load()) << ", " << std::min(total_, numEvaluations_.load()) << ")/" << total_; @@ -76,11 +75,11 @@ void Printer::printLoop() { } Printer::Printer( - size_t generation, + std::string prefix, size_t total, const std::atomic_size_t& currentCompilationJob, const std::atomic_size_t& numEvaluations) - : generation_(generation), + : prefix_(std::move(prefix)), printerThread_([this]() { printLoop(); }), total_(total), currentCompilationJob_(currentCompilationJob), diff --git a/tc/autotuner/utils/printer.h b/tc/autotuner/utils/printer.h index 4a2760564..b2ddff853 100644 --- a/tc/autotuner/utils/printer.h +++ b/tc/autotuner/utils/printer.h @@ -33,7 +33,7 @@ namespace autotune { class Printer { public: Printer( - size_t generation, + std::string prefix, size_t total, const std::atomic_size_t& currentCompilationJob, const std::atomic_size_t& numEvaluations); @@ -47,7 +47,7 @@ class Printer { private: void printLoop(); - size_t generation_; + std::string prefix_; std::vector runtimes_; mutable std::mutex runtimesMtx_; diff --git a/tc/autotuner/utils/utils.cc b/tc/autotuner/utils/utils.cc index e38e752cd..3efeeb17c 100644 --- a/tc/autotuner/utils/utils.cc +++ b/tc/autotuner/utils/utils.cc @@ -15,6 +15,7 @@ */ #include #include +#include #include "tc/aten/aten_compiler.h" #include "tc/autotuner/utils/utils.h" @@ -109,5 +110,38 @@ llvm::Optional getBestOptions( return llvm::Optional{}; } +double mean(std::vector& v) { + if (v.empty()) { + throw std::invalid_argument("Cannot compute the mean of an empty vector."); + } + auto sum = std::accumulate(v.begin(), v.end(), 0.0); + return sum / v.size(); +} + +double stdv(std::vector& v, double mean) { + std::vector diffs(v.size()); + std::transform(v.begin(), v.end(), diffs.begin(), [mean](double val) { + return val - mean; + }); + + auto squareSum = + std::inner_product(diffs.begin(), diffs.end(), diffs.begin(), 0.0); + return std::sqrt(squareSum / v.size()); +} + +void sigmaScale(std::vector& v) { + auto m = mean(v); + auto s = stdv(v, m); + std::transform(v.begin(), v.end(), v.begin(), [m, s](double val) { + return std::max(val - (m - 2 * s), 0.0); + }); +} + +void normalizeVector(std::vector& v) { + auto sum = std::accumulate(v.begin(), v.end(), 0.0); + std::transform( + v.begin(), v.end(), v.begin(), [sum](double v) { return v / sum; }); +} + } // namespace autotune } // namespace tc diff --git a/tc/autotuner/utils/utils.h b/tc/autotuner/utils/utils.h index 2e5367e92..33c659c93 100644 --- a/tc/autotuner/utils/utils.h +++ b/tc/autotuner/utils/utils.h @@ -59,6 +59,10 @@ struct OptionsWithMedianTime { std::vector getOptionsAndMedianRuntimes( const lang::CanonicalTcString& id, const std::vector& inputs); +double mean(std::vector& v); +double stdv(std::vector& v, double mean); +void normalizeVector(std::vector& v); +void sigmaScale(std::vector& v); } // namespace autotune } // namespace tc diff --git a/tc/core/flags.cc b/tc/core/flags.cc index e2f4f5a25..2e328da8a 100644 --- a/tc/core/flags.cc +++ b/tc/core/flags.cc @@ -59,6 +59,17 @@ DEFINE_uint32( tuner_gen_pop_size, 100, "Population size for genetic autotuning"); + +DEFINE_uint32( + tuner_gen_mating_pool_size, + 300, + "Mating pool size for genetic autotuning"); + +DEFINE_uint32( + tuner_gen_selection_pool_size, + 300, + "Selection pool size for genetic autotuning"); + DEFINE_uint32( tuner_gen_crossover_rate, 80, @@ -71,10 +82,6 @@ DEFINE_uint32( tuner_gen_generations, 25, "How many generations to run genetic tuning for"); -DEFINE_uint32( - tuner_gen_number_elites, - 10, - "The number of best candidates that are preserved intact between generations"); DEFINE_uint32(tuner_threads, 1, "Number of CPU threads to use when autotuning"); DEFINE_string( tuner_gpus, diff --git a/tc/core/flags.h b/tc/core/flags.h index 0324b00ad..c430dcb05 100644 --- a/tc/core/flags.h +++ b/tc/core/flags.h @@ -40,10 +40,11 @@ DECLARE_uint32(benchmark_iterations); // Used in autotuning DECLARE_uint32(tuner_gen_pop_size); +DECLARE_uint32(tuner_gen_mating_pool_size); +DECLARE_uint32(tuner_gen_selection_pool_size); DECLARE_uint32(tuner_gen_crossover_rate); DECLARE_uint32(tuner_gen_mutation_rate); DECLARE_uint32(tuner_gen_generations); -DECLARE_uint32(tuner_gen_number_elites); DECLARE_uint32(tuner_threads); DECLARE_string(tuner_gpus); DECLARE_bool(tuner_print_best); diff --git a/tc/examples/tensordot.cc b/tc/examples/tensordot.cc index a87e1606b..0e463e775 100644 --- a/tc/examples/tensordot.cc +++ b/tc/examples/tensordot.cc @@ -89,7 +89,8 @@ def tensordot(float(N, C1, C2, H, W) I0, // From root, run with: // ./build/examples/tensordot --tuner_threads=10 --tuner_gen_pop_size=10 -// --tuner_gen_generations=3 --tuner_gen_number_elites=4 +// --tuner_gen_generations=3 --tuner_gen_mating_pool_size=20 +// --tuner_gen_selection_pool_size=20 // --proto_path="/tmp/tensordot" int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); diff --git a/tensor_comprehensions/pybinds/pybind_autotuner.cc b/tensor_comprehensions/pybinds/pybind_autotuner.cc index 4a5337ca2..3a70ae8a6 100644 --- a/tensor_comprehensions/pybinds/pybind_autotuner.cc +++ b/tensor_comprehensions/pybinds/pybind_autotuner.cc @@ -70,10 +70,16 @@ PYBIND11_MODULE(autotuner, m) { tc::FLAGS_tuner_gen_generations = generations; }) .def( - "number_elites", + "mating_pool_size", [](tc::autotune::GeneticAutotunerATen& instance, - uint32_t& number_elites) { - tc::FLAGS_tuner_gen_number_elites = number_elites; + uint32_t& mating_pool_size) { + tc::FLAGS_tuner_gen_mating_pool_size = mating_pool_size; + }) + .def( + "selection_pool_size", + [](tc::autotune::GeneticAutotunerATen& instance, + uint32_t& selection_pool_size) { + tc::FLAGS_tuner_gen_selection_pool_size = selection_pool_size; }) .def( "threads", diff --git a/tensor_comprehensions/tc_unit.py b/tensor_comprehensions/tc_unit.py index f512fad30..8185a00cb 100644 --- a/tensor_comprehensions/tc_unit.py +++ b/tensor_comprehensions/tc_unit.py @@ -193,6 +193,9 @@ def __init__(self, tc_lang, **kwargs): def set_autotuner_parameters( self, pop_size=20, crossover_rate=80, mutation_rate=7, generations=10, number_elites=1, threads=8, gpus="0", restore_from_proto=False, + restore_number=10, log_generations=False + mating_pool_size=60, selection_pool_size=60, + threads=8, gpus="0", restore_from_proto=False, restore_number=10, log_generations=False, save_best_candidates_count=10, tuner_min_launch_total_threads=64, **kwargs ): @@ -200,7 +203,8 @@ def set_autotuner_parameters( self.autotuner.crossover_rate(crossover_rate) self.autotuner.mutation_rate(mutation_rate) self.autotuner.generations(generations) - self.autotuner.number_elites(number_elites) + self.autotuner.mating_pool_size(mating_pool_size) + self.autotuner.selection_pool_size(selection_pool_size) self.autotuner.threads(threads) self.autotuner.gpus(gpus) self.autotuner.restore_from_proto(restore_from_proto) @@ -575,9 +579,6 @@ def autotune(self, *inputs, **kwargs): mutation_rate (int): rate at which candidate options are randomly changed (mutated). Default 7 - number_elites (int): - number of best candidates that are preserved intact between generations (without any mutations). Default 10 - threads (int): The number of threads that are used to compile different candidates in parallel. Default 1 diff --git a/test/cuda/test_autotuner.cc b/test/cuda/test_autotuner.cc index 7de32a0c3..8871d7037 100644 --- a/test/cuda/test_autotuner.cc +++ b/test/cuda/test_autotuner.cc @@ -51,7 +51,9 @@ struct ATenCompilationUnitTest : public ::testing::Test { tc::FLAGS_tuner_gen_pop_size = 8; tc::FLAGS_tuner_gen_generations = 5; tc::FLAGS_tuner_threads = std::min(8u, tc::FLAGS_tuner_gen_pop_size); - tc::FLAGS_tuner_gen_number_elites = tc::FLAGS_tuner_gen_pop_size / 4; + tc::FLAGS_tuner_gen_mating_pool_size = tc::FLAGS_tuner_gen_pop_size * 3; + tc::FLAGS_tuner_gen_selection_pool_size = + tc::FLAGS_tuner_gen_pop_size * 3; } } diff --git a/test_python/common.py b/test_python/common.py index 0ccd9605f..18f887892 100644 --- a/test_python/common.py +++ b/test_python/common.py @@ -36,7 +36,7 @@ def assert_almost_equal(self, diff, inputs, operations, precision=1e-7): def autotune_store(self, cache_file, lang, tc_name, inputs, tc_type): tuner = TcAutotuner( - lang, threads=16, pop_size=10, number_elites=1, generations=1 + lang, threads=16, pop_size=10, generations=1 ) best_options = tuner.tune_and_store( tc_name, inputs, mapping_options=tc_type, cache_file=cache_file