@@ -33,11 +33,8 @@ void randomizeParameter(Parameter& param, RNG& rng) {
33
33
param.selectOption (paramIndex);
34
34
}
35
35
36
- template <typename RNG>
37
- void randomizePopulation (
38
- GeneticSearch::Population::iterator begin,
39
- GeneticSearch::Population::iterator end,
40
- RNG& rng) {
36
+ template <typename RNG, typename Iterator>
37
+ void randomizePopulation (Iterator begin, Iterator end, RNG& rng) {
41
38
for (auto candidate = begin; candidate != end; ++candidate) {
42
39
auto & conf = (*candidate)->configuration ;
43
40
do {
@@ -160,7 +157,8 @@ void dropInvalidConfigurations(GeneticSearch::Population& population) {
160
157
} // namespace
161
158
162
159
#define VALIDATE () \
163
- CHECK_LT (numberElites, maxPopulationSize); \
160
+ CHECK_LT (maxPopulationSize, matingPoolSize); \
161
+ CHECK_LT (maxPopulationSize, selectionPoolSize); \
164
162
CHECK (mutationRate >= 0 and mutationRate <= 100 ) \
165
163
<< " the mutation rate (" << mutationRate \
166
164
<< " ) should be in the [0,100] interval" ; \
@@ -188,15 +186,16 @@ GeneticSearch::GeneticSearch(
188
186
size_t populationSize,
189
187
uint8_t crossOverRate,
190
188
uint8_t mutationRate,
191
- size_t numberElites)
189
+ size_t matingPoolSize,
190
+ size_t selectionPoolSize)
192
191
: population(),
193
192
lastBestConf (confs[0 ]),
194
193
numGenerations(numGenerations),
195
194
maxPopulationSize(populationSize),
196
- matingPoolSize(populationSize * 3 ),
195
+ matingPoolSize(matingPoolSize),
196
+ selectionPoolSize(selectionPoolSize),
197
197
crossOverRate(crossOverRate),
198
198
mutationRate(mutationRate),
199
- numberElites(numberElites),
200
199
rng{std::random_device{}()} {
201
200
restoreRngState (rng);
202
201
VALIDATE ();
@@ -276,13 +275,6 @@ void GeneticSearch::breed() {
276
275
auto matingPool =
277
276
stochasticUniversalSampling (computeAccumulatedFitness (population));
278
277
279
- Population new_population;
280
- new_population.reserve (matingPoolSize);
281
- for (size_t c = 0 ; c < numberElites; ++c) {
282
- new_population.push_back (
283
- make_unique<CandidateConfiguration>(population.at (c)->configuration ));
284
- }
285
-
286
278
auto select = [&]() -> TuningConfiguration& {
287
279
auto idx = std::uniform_int_distribution<size_t >{
288
280
size_t (0 ), matingPool.size () - 1 }(rng);
@@ -298,39 +290,20 @@ void GeneticSearch::breed() {
298
290
return dist (rng);
299
291
};
300
292
301
- while (new_population .size () < maxPopulationSize ) {
293
+ while (selectionPool .size () < selectionPoolSize ) {
302
294
if (shouldCrossOver ()) {
303
295
auto parent1 = select ();
304
296
auto parent2 = select ();
305
297
auto parent3 = select ();
306
- new_population .emplace_back (make_unique<CandidateConfiguration>(
298
+ selectionPool .emplace_back (make_unique<CandidateConfiguration>(
307
299
crossover (parent1, parent2, parent3)));
308
300
} else {
309
- new_population.emplace_back (
310
- make_unique<CandidateConfiguration>(select ()));
301
+ selectionPool.emplace_back (make_unique<CandidateConfiguration>(select ()));
311
302
}
312
303
}
313
- population = std::move (new_population);
314
304
}
315
305
316
- void GeneticSearch::updateParameters () {
317
- dropInvalidConfigurations (population);
318
-
319
- // Sort population before taking any decision
320
- std::sort (
321
- population.begin (),
322
- population.end (),
323
- [](const std::unique_ptr<CandidateConfiguration>& a,
324
- const std::unique_ptr<CandidateConfiguration>& b) {
325
- checkRuntimeRecorded (a->runtime );
326
- checkRuntimeRecorded (b->runtime );
327
- return a->runtime < b->runtime ;
328
- });
329
-
330
- // Update failsafe lastBestConf
331
- lastBestConf =
332
- population.size () > 0 ? population.front ()->configuration : lastBestConf;
333
-
306
+ bool GeneticSearch::resetPopulationIfNotEnoughCandidates () {
334
307
if (population.size () < minCandidatesForBreeding) {
335
308
LOG_IF (ERROR, FLAGS_debug_tuner)
336
309
<< population.size () << " out of " << maxPopulationSize
@@ -341,30 +314,94 @@ void GeneticSearch::updateParameters() {
341
314
" --tuner_min_launch_total_threads=1. This is mostly relevant "
342
315
" when autotuning a TC operating on small tensors. The next "
343
316
" generation will be randomly initialized." ;
344
- population. resize ( 0 );
345
- for (size_t i = 0 ; i < maxPopulationSize ; ++i) {
346
- population .emplace_back (
317
+ selectionPool. clear ( );
318
+ for (size_t i = 0 ; i < selectionPoolSize ; ++i) {
319
+ selectionPool .emplace_back (
347
320
make_unique<CandidateConfiguration>(lastBestConf));
348
321
}
349
322
// Don't lose the first one which was the best from before
350
- CHECK_LT (0u , population.size ());
351
- randomizePopulation (population.begin () + 1 , population.end (), rng);
352
- return ;
323
+ randomizePopulation (selectionPool.begin () + 1 , selectionPool.end (), rng);
324
+ return true ;
353
325
}
326
+ return false ;
327
+ }
328
+
329
+ namespace {
330
+ void sortByRuntime (GeneticSearch::Population& population) {
331
+ std::sort (
332
+ population.begin (),
333
+ population.end (),
334
+ [](const std::unique_ptr<CandidateConfiguration>& a,
335
+ const std::unique_ptr<CandidateConfiguration>& b) {
336
+ checkRuntimeRecorded (a->runtime );
337
+ checkRuntimeRecorded (b->runtime );
338
+ return a->runtime < b->runtime ;
339
+ });
340
+ }
341
+ } // namespace
354
342
343
+ void GeneticSearch::generateSelectionPool () {
344
+ dropInvalidConfigurations (population);
345
+ sortByRuntime (population);
346
+ lastBestConf =
347
+ population.size () > 0 ? population.front ()->configuration : lastBestConf;
348
+ if (resetPopulationIfNotEnoughCandidates ()) {
349
+ return ;
350
+ }
355
351
breed ();
356
- for (size_t i = numberElites; i < population.size (); ++i) {
357
- mutate (*population[i], mutationRate, mutateIterations, rng);
352
+ selectionPool.clear ();
353
+ selectionPool.emplace_back (make_unique<CandidateConfiguration>(lastBestConf));
354
+ breed ();
355
+ for (size_t i = 1 ; i < selectionPool.size (); ++i) {
356
+ mutate (*selectionPool[i], mutationRate, mutateIterations, rng);
357
+ }
358
+ }
359
+
360
+ void GeneticSearch::selectSurvivors () {
361
+ dropInvalidConfigurations (selectionPool);
362
+ sortByRuntime (selectionPool);
363
+ population.clear ();
364
+ std::transform (
365
+ selectionPool.begin (),
366
+ selectionPool.begin () + std::min (selectionPool.size (), maxPopulationSize),
367
+ std::back_inserter (population),
368
+ [](const std::unique_ptr<CandidateConfiguration>& c) {
369
+ return make_unique<CandidateConfiguration>(c->configuration );
370
+ });
371
+
372
+ if (selectionPool.size () < maxPopulationSize) {
373
+ auto numberMissing = maxPopulationSize - selectionPool.size ();
374
+
375
+ for (size_t i = 0 ; i < numberMissing; ++i) {
376
+ selectionPool.emplace_back (
377
+ make_unique<CandidateConfiguration>(lastBestConf));
378
+ }
379
+ randomizePopulation (
380
+ selectionPool.rbegin (), selectionPool.rbegin () + numberMissing, rng);
358
381
}
359
382
}
360
383
361
384
// Returns the candidate container associated with `step`:
//   step 0 -> population
//   step 1 -> selectionPool
// Any other step throws std::invalid_argument.
GeneticSearch::Population& GeneticSearch::candidatesOfStep(uint64_t step) {
  switch (step) {
    case 0:
      return population;
    case 1:
      return selectionPool;
    default:
      throw std::invalid_argument(" GeneticSearch has only 2 steps.");
  }
}
367
394
395
+ void GeneticSearch::finishStep (uint64_t step) {
396
+ if (step > 1 ) {
397
+ throw std::invalid_argument (" GeneticSearch has only 2 steps." );
398
+ }
399
+ if (step == 0 ) {
400
+ generateSelectionPool ();
401
+ } else {
402
+ selectSurvivors ();
403
+ }
404
+ }
368
405
} // namespace autotune
369
406
} // namespace tc
370
407
0 commit comments