@@ -274,40 +274,90 @@ struct TcExecutor {
274
274
275
275
class TunerConfig {
276
276
public:
277
- TunerConfig (
278
- uint32_t generations,
279
- uint32_t populationSize,
280
- uint32_t threads,
281
- std::string devices,
282
- bool logtostderr,
283
- uint32_t stderrthreshold) {
284
- generations_ = generations;
285
- populationSize_ = populationSize;
286
- threads_ = threads;
287
- devices_ = devices;
288
- logtostderr_ = logtostderr;
289
- stderrthreshold_ = stderrthreshold;
277
+ TunerConfig ()
278
+ : generations_(tc::FLAGS_tuner_gen_generations),
279
+ populationSize_ (tc::FLAGS_tuner_gen_pop_size),
280
+ crossoverRate_(tc::FLAGS_tuner_gen_crossover_rate),
281
+ mutationRate_(tc::FLAGS_tuner_gen_mutation_rate),
282
+ numberElites_(tc::FLAGS_tuner_gen_number_elites),
283
+ tunerMinLaunchTotalThreads_(tc::FLAGS_tuner_min_launch_total_threads),
284
+ threads_(tc::FLAGS_tuner_threads),
285
+ devices_(tc::FLAGS_tuner_devices),
286
+ logtostderr_(false ),
287
+ // Suppress non-FATAL errors from the python user by default
288
+ stderrthreshold_(google::FATAL) {}
289
+
290
+ TunerConfig& generations (uint32_t val) {
291
+ generations_ = val;
292
+ return *this ;
290
293
}
291
- // __enter__ / __exit__ in case we want to use a ContextManager in Python in
292
- // the future. In any case, RAII and Python GC can just never work together.
293
- void __enter__ () const {
294
+ TunerConfig& populationSize (uint32_t val) {
295
+ populationSize_ = val;
296
+ return *this ;
297
+ }
298
+ TunerConfig& crossoverRate (uint32_t val) {
299
+ crossoverRate_ = val;
300
+ return *this ;
301
+ }
302
+ TunerConfig& mutationRate (uint32_t val) {
303
+ mutationRate_ = val;
304
+ return *this ;
305
+ }
306
+ TunerConfig& numberElites (uint32_t val) {
307
+ numberElites_ = val;
308
+ return *this ;
309
+ }
310
+ TunerConfig& tunerMinLaunchTotalThreads (uint32_t val) {
311
+ tunerMinLaunchTotalThreads_ = val;
312
+ return *this ;
313
+ }
314
+ TunerConfig& threads (uint32_t val) {
315
+ threads_ = val;
316
+ return *this ;
317
+ }
318
+ TunerConfig& devices (const std::string& val) {
319
+ devices_ = val;
320
+ return *this ;
321
+ }
322
+ TunerConfig& logtostderr (bool val) {
323
+ logtostderr_ = val;
324
+ return *this ;
325
+ }
326
+ TunerConfig& stderrthreshold (uint32_t val) {
327
+ stderrthreshold_ = val;
328
+ return *this ;
329
+ }
330
+
331
+ void enter () const {
294
332
savedGenerations_ = tc::FLAGS_tuner_gen_generations;
295
333
savedPopulationSize_ = tc::FLAGS_tuner_gen_pop_size;
334
+ savedCrossoverRate_ = tc::FLAGS_tuner_gen_crossover_rate;
335
+ savedMutationRate_ = tc::FLAGS_tuner_gen_mutation_rate;
336
+ savedNumberElites_ = tc::FLAGS_tuner_gen_number_elites;
337
+ savedTunerMinLaunchTotalThreads_ = tc::FLAGS_tuner_min_launch_total_threads;
296
338
savedThreads_ = tc::FLAGS_tuner_threads;
297
339
savedDevices_ = tc::FLAGS_tuner_devices;
298
340
savedLogtostderr_ = FLAGS_logtostderr;
299
341
savedStderrthreshold_ = FLAGS_stderrthreshold;
300
342
301
343
tc::FLAGS_tuner_gen_generations = generations_;
302
344
tc::FLAGS_tuner_gen_pop_size = populationSize_;
345
+ tc::FLAGS_tuner_gen_crossover_rate = crossoverRate_;
346
+ tc::FLAGS_tuner_gen_mutation_rate = mutationRate_;
347
+ tc::FLAGS_tuner_gen_number_elites = numberElites_;
348
+ tc::FLAGS_tuner_min_launch_total_threads = tunerMinLaunchTotalThreads_;
303
349
tc::FLAGS_tuner_threads = threads_;
304
350
tc::FLAGS_tuner_devices = devices_;
305
351
FLAGS_logtostderr = logtostderr_;
306
352
FLAGS_stderrthreshold = stderrthreshold_;
307
353
}
308
- void __exit__ () const {
354
+ void exit () const {
309
355
tc::FLAGS_tuner_gen_generations = savedGenerations_;
310
356
tc::FLAGS_tuner_gen_pop_size = savedPopulationSize_;
357
+ tc::FLAGS_tuner_gen_crossover_rate = savedCrossoverRate_;
358
+ tc::FLAGS_tuner_gen_mutation_rate = savedMutationRate_;
359
+ tc::FLAGS_tuner_gen_number_elites = savedNumberElites_;
360
+ tc::FLAGS_tuner_min_launch_total_threads = savedTunerMinLaunchTotalThreads_;
311
361
tc::FLAGS_tuner_threads = savedThreads_;
312
362
tc::FLAGS_tuner_devices = savedDevices_;
313
363
FLAGS_logtostderr = savedLogtostderr_;
@@ -317,12 +367,20 @@ class TunerConfig {
317
367
private:
318
368
uint32_t generations_;
319
369
uint32_t populationSize_;
370
+ uint32_t crossoverRate_;
371
+ uint32_t mutationRate_;
372
+ uint32_t numberElites_;
373
+ uint32_t tunerMinLaunchTotalThreads_;
320
374
uint32_t threads_;
321
375
std::string devices_;
322
376
bool logtostderr_;
323
377
uint32_t stderrthreshold_;
324
378
mutable uint32_t savedGenerations_;
325
379
mutable uint32_t savedPopulationSize_;
380
+ mutable uint32_t savedCrossoverRate_;
381
+ mutable uint32_t savedMutationRate_;
382
+ mutable uint32_t savedNumberElites_;
383
+ mutable uint32_t savedTunerMinLaunchTotalThreads_;
326
384
mutable uint32_t savedThreads_;
327
385
mutable std::string savedDevices_;
328
386
mutable bool savedLogtostderr_;
@@ -390,91 +448,34 @@ PYBIND11_MODULE(tclib, m) {
390
448
return TcExecutor{tc, entryPoint, std::move (execUPtr)};
391
449
});
392
450
451
+ // A TunerConfig object can be passed to configure a tuning run
393
452
py::class_<TunerConfig>(m, " TunerConfig" , py::module_local ())
453
+ .def (py::init<>())
454
+ .def (" generations" , &TunerConfig::generations)
455
+ .def (" pop_size" , &TunerConfig::populationSize)
456
+ .def (" crossover_rate" , &TunerConfig::crossoverRate)
457
+ .def (" mutation_rate" , &TunerConfig::mutationRate)
458
+ .def (" number_elites" , &TunerConfig::numberElites)
394
459
.def (
395
- py::init<uint32_t , uint32_t , uint32_t , std::string, bool , uint32_t >(),
396
- py::arg (" generations" ) = tc::FLAGS_tuner_gen_generations,
397
- py::arg (" pop_size" ) = tc::FLAGS_tuner_gen_pop_size,
398
- py::arg (" threads" ) = tc::FLAGS_tuner_threads,
399
- py::arg (" devices" ) = tc::FLAGS_tuner_devices,
400
- py::arg (" logtostderr" ) = false ,
401
- // Suppress non-FATAL errors from the python user
402
- py::arg (" stderrthreshold" ) = google::FATAL);
460
+ " tuner_min_launch_total_threads" ,
461
+ &TunerConfig::tunerMinLaunchTotalThreads)
462
+ .def (" threads" , &TunerConfig::threads)
463
+ .def (" devices" , &TunerConfig::devices)
464
+ .def (" logtostderr" , &TunerConfig::logtostderr)
465
+ .def (" stderrthreshold" , &TunerConfig::stderrthreshold);
403
466
404
467
py::class_<Tuner>(m, " Tuner" , py::module_local ())
405
468
.def (py::init<std::string>())
406
469
.def (py::init<std::string, std::string>())
407
- .def (
408
- " pop_size" ,
409
- [](Tuner& instance, uint32_t & pop_size) {
410
- tc::FLAGS_tuner_gen_pop_size = pop_size;
411
- })
412
- .def (
413
- " crossover_rate" ,
414
- [](Tuner& instance, uint32_t & crossover_rate) {
415
- tc::FLAGS_tuner_gen_crossover_rate = crossover_rate;
416
- })
417
- .def (
418
- " mutation_rate" ,
419
- [](Tuner& instance, uint32_t & mutation_rate) {
420
- tc::FLAGS_tuner_gen_mutation_rate = mutation_rate;
421
- })
422
- .def (
423
- " generations" ,
424
- [](Tuner& instance, uint32_t & generations) {
425
- tc::FLAGS_tuner_gen_generations = generations;
426
- })
427
- .def (
428
- " number_elites" ,
429
- [](Tuner& instance, uint32_t & number_elites) {
430
- tc::FLAGS_tuner_gen_number_elites = number_elites;
431
- })
432
- .def (
433
- " threads" ,
434
- [](Tuner& instance, uint32_t & threads) {
435
- tc::FLAGS_tuner_threads = threads;
436
- })
437
- .def (
438
- " gpus" ,
439
- [](Tuner& instance, std::string& gpus) {
440
- tc::FLAGS_tuner_devices = gpus;
441
- })
442
- .def (
443
- " restore_from_proto" ,
444
- [](Tuner& instance, bool restore_from_proto) {
445
- tc::FLAGS_tuner_gen_restore_from_proto = restore_from_proto;
446
- })
447
- .def (
448
- " restore_number" ,
449
- [](Tuner& instance, uint32_t & restore_number) {
450
- tc::FLAGS_tuner_gen_restore_number = restore_number;
451
- })
452
- .def (
453
- " log_generations" ,
454
- [](Tuner& instance, bool log_generations) {
455
- tc::FLAGS_tuner_gen_log_generations = log_generations;
456
- })
457
- .def (
458
- " tuner_min_launch_total_threads" ,
459
- [](Tuner& instance, bool tuner_min_launch_total_threads) {
460
- tc::FLAGS_tuner_min_launch_total_threads =
461
- tuner_min_launch_total_threads;
462
- })
463
- .def (
464
- " save_best_candidates_count" ,
465
- [](Tuner& instance, bool save_best_candidates_count) {
466
- tc::FLAGS_tuner_save_best_candidates_count =
467
- save_best_candidates_count;
468
- })
469
470
.def (
470
471
" tune" ,
471
472
[](Tuner& instance,
472
473
const std::string& entryPoint,
473
474
const py::tuple& inputs,
474
475
tc::CudaMappingOptions& baseMapping,
475
476
const TunerConfig& config) {
476
- config.__enter__ ();
477
- ScopeGuard sg ([&config]() { config.__exit__ (); });
477
+ config.enter ();
478
+ ScopeGuard sg ([&config]() { config.exit (); });
478
479
std::vector<at::Tensor> atInputs = getATenTensors (inputs);
479
480
auto bestOptions =
480
481
instance.tune (entryPoint, atInputs, {baseMapping});
0 commit comments