This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit e8a7e04

Simplify and update TunerConfig
TunerConfig acts like a Python ContextManager (with RAII-style effects): it saves the relevant global flags on entry and restores them on exit. We can therefore simplify the user-facing API and just construct a TunerConfig, instead of calling a series of methods that set global flags to non-default values and never set them back. Additionally, add the previously missing TunerConfig fields to the user-facing API.
1 parent: fb0d365
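For illustration, a minimal before/after sketch of the user-facing change (the TunerConfig calls are taken from the diff below; the import path and variable name are assumptions):

# Hypothetical import; the examples below obtain TunerConfig from the tclib pybind module.
from tensor_comprehensions.tclib import TunerConfig

# Old API (removed by this commit): every knob passed to the constructor as a keyword argument.
#   config = TunerConfig(threads = 8, pop_size = 25, generations = 3, devices = "0")

# New API: default-construct, then chain setters for the fields you want to override;
# anything left unset keeps its gflags default.
config = TunerConfig().threads(8).pop_size(25).generations(3).devices("0")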

2 files changed: 95 additions & 106 deletions

python/examples/tc_pybind_example.py

Lines changed: 5 additions & 17 deletions
@@ -137,7 +137,7 @@ def matmul_grad(float(M,N) A, float(N,K) B, float(M,K) d_O) -> (d_A, d_B) {
         "matmul",
         (mat1, mat2),
         MappingOptions('naive'),
-        TunerConfig(threads = 8, pop_size = 25, generations = 3, devices = "0"))
+        TunerConfig().threads(8).pop_size(25).generations(3).devices("0"))
     cache = MappingOptionsCache(unique_filename)
     top10 = cache.load(mm, "matmul", (mat1, mat2), 10)
     assert top1.__str__() == top10[0].__str__()
@@ -227,10 +227,7 @@ def compileOrTune(self, name = "", force_reinforcement_tuning = False, inputs =
     tcb = TcBuilder(
         tc = mm,
         tuner_cache_file = "/tmp/some_cache_file_we_reuse_for_perf_reinforcement",
-        tuner_config = TunerConfig(threads = 8,
-                                   pop_size = 25,
-                                   generations = 3,
-                                   devices = "0"))
+        tuner_config = TunerConfig().threads(8).pop_size(25).generations(3).devices("0"))

     tcb.compileOrTune(name = "matmul", inputs = (mat1, mat2))
     time_tc(100,
@@ -310,10 +307,7 @@ def backward(ctx, *gradients):
         backward_force_reinforcement_tuning = False,
         check_output_shapes = False,
         tuner_cache_file = "/tmp/some_cache_file_we_reuse_for_perf_reinforcement",
-        tuner_config = TunerConfig(threads = 8,
-                                   pop_size = 25,
-                                   generations = 3,
-                                   devices = "0"),
+        tuner_config = TunerConfig().threads(8).pop_size(25).generations(3).devices("0"),
     )

     time_tc(100,
@@ -431,10 +425,7 @@ def compileOrTune(self, name = "", force_reinforcement_tuning = False, inputs =
     tcb = MultiTcBuilder(
         tc = mm,
         tuner_cache_file = "/tmp/some_cache_file_we_reuse_for_perf_reinforcement",
-        tuner_config = TunerConfig(threads = 8,
-                                   pop_size = 25,
-                                   generations = 3,
-                                   devices = "0"))
+        tuner_config = TunerConfig().threads(8).pop_size(25).generations(3).devices("0"))

     tcb.compileOrTune(name = "matmul", inputs = (mat1, mat2))
     time_tc(100,
@@ -535,10 +526,7 @@ def backward(ctx, *gradients):
         backward_force_reinforcement_tunings = (False, ),
         check_output_shapes = False,
         tuner_cache_file = "/tmp/some_cache_file_we_reuse_for_perf_reinforcement",
-        tuner_config = TunerConfig(threads = 8,
-                                   pop_size = 25,
-                                   generations = 3,
-                                   devices = "0"),
+        tuner_config = TunerConfig().threads(8).pop_size(25).generations(3).devices("0"),
     )

     time_tc(100,

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 90 additions & 89 deletions
@@ -274,40 +274,90 @@ struct TcExecutor {

 class TunerConfig {
  public:
-  TunerConfig(
-      uint32_t generations,
-      uint32_t populationSize,
-      uint32_t threads,
-      std::string devices,
-      bool logtostderr,
-      uint32_t stderrthreshold) {
-    generations_ = generations;
-    populationSize_ = populationSize;
-    threads_ = threads;
-    devices_ = devices;
-    logtostderr_ = logtostderr;
-    stderrthreshold_ = stderrthreshold;
+  TunerConfig()
+      : generations_(tc::FLAGS_tuner_gen_generations),
+        populationSize_(tc::FLAGS_tuner_gen_pop_size),
+        crossoverRate_(tc::FLAGS_tuner_gen_crossover_rate),
+        mutationRate_(tc::FLAGS_tuner_gen_mutation_rate),
+        numberElites_(tc::FLAGS_tuner_gen_number_elites),
+        tunerMinLaunchTotalThreads_(tc::FLAGS_tuner_min_launch_total_threads),
+        threads_(tc::FLAGS_tuner_threads),
+        devices_(tc::FLAGS_tuner_devices),
+        logtostderr_(false),
+        // Suppress non-FATAL errors from the python user by default
+        stderrthreshold_(google::FATAL) {}
+
+  TunerConfig& generations(uint32_t val) {
+    generations_ = val;
+    return *this;
   }
-  // __enter__ / __exit__ in case we want to use a ContextManager in Python in
-  // the future. In any case, RAII and Python GC can just never work together.
-  void __enter__() const {
+  TunerConfig& populationSize(uint32_t val) {
+    populationSize_ = val;
+    return *this;
+  }
+  TunerConfig& crossoverRate(uint32_t val) {
+    crossoverRate_ = val;
+    return *this;
+  }
+  TunerConfig& mutationRate(uint32_t val) {
+    mutationRate_ = val;
+    return *this;
+  }
+  TunerConfig& numberElites(uint32_t val) {
+    numberElites_ = val;
+    return *this;
+  }
+  TunerConfig& tunerMinLaunchTotalThreads(uint32_t val) {
+    tunerMinLaunchTotalThreads_ = val;
+    return *this;
+  }
+  TunerConfig& threads(uint32_t val) {
+    threads_ = val;
+    return *this;
+  }
+  TunerConfig& devices(const std::string& val) {
+    devices_ = val;
+    return *this;
+  }
+  TunerConfig& logtostderr(bool val) {
+    logtostderr_ = val;
+    return *this;
+  }
+  TunerConfig& stderrthreshold(uint32_t val) {
+    stderrthreshold_ = val;
+    return *this;
+  }
+
+  void enter() const {
     savedGenerations_ = tc::FLAGS_tuner_gen_generations;
     savedPopulationSize_ = tc::FLAGS_tuner_gen_pop_size;
+    savedCrossoverRate_ = tc::FLAGS_tuner_gen_crossover_rate;
+    savedMutationRate_ = tc::FLAGS_tuner_gen_mutation_rate;
+    savedNumberElites_ = tc::FLAGS_tuner_gen_number_elites;
+    savedTunerMinLaunchTotalThreads_ = tc::FLAGS_tuner_min_launch_total_threads;
     savedThreads_ = tc::FLAGS_tuner_threads;
     savedDevices_ = tc::FLAGS_tuner_devices;
     savedLogtostderr_ = FLAGS_logtostderr;
     savedStderrthreshold_ = FLAGS_stderrthreshold;

     tc::FLAGS_tuner_gen_generations = generations_;
     tc::FLAGS_tuner_gen_pop_size = populationSize_;
+    tc::FLAGS_tuner_gen_crossover_rate = crossoverRate_;
+    tc::FLAGS_tuner_gen_mutation_rate = mutationRate_;
+    tc::FLAGS_tuner_gen_number_elites = numberElites_;
+    tc::FLAGS_tuner_min_launch_total_threads = tunerMinLaunchTotalThreads_;
     tc::FLAGS_tuner_threads = threads_;
     tc::FLAGS_tuner_devices = devices_;
     FLAGS_logtostderr = logtostderr_;
     FLAGS_stderrthreshold = stderrthreshold_;
   }
-  void __exit__() const {
+  void exit() const {
     tc::FLAGS_tuner_gen_generations = savedGenerations_;
     tc::FLAGS_tuner_gen_pop_size = savedPopulationSize_;
+    tc::FLAGS_tuner_gen_crossover_rate = savedCrossoverRate_;
+    tc::FLAGS_tuner_gen_mutation_rate = savedMutationRate_;
+    tc::FLAGS_tuner_gen_number_elites = savedNumberElites_;
+    tc::FLAGS_tuner_min_launch_total_threads = savedTunerMinLaunchTotalThreads_;
     tc::FLAGS_tuner_threads = savedThreads_;
     tc::FLAGS_tuner_devices = savedDevices_;
     FLAGS_logtostderr = savedLogtostderr_;
@@ -317,12 +367,20 @@ class TunerConfig {
  private:
   uint32_t generations_;
   uint32_t populationSize_;
+  uint32_t crossoverRate_;
+  uint32_t mutationRate_;
+  uint32_t numberElites_;
+  uint32_t tunerMinLaunchTotalThreads_;
   uint32_t threads_;
   std::string devices_;
   bool logtostderr_;
   uint32_t stderrthreshold_;
   mutable uint32_t savedGenerations_;
   mutable uint32_t savedPopulationSize_;
+  mutable uint32_t savedCrossoverRate_;
+  mutable uint32_t savedMutationRate_;
+  mutable uint32_t savedNumberElites_;
+  mutable uint32_t savedTunerMinLaunchTotalThreads_;
   mutable uint32_t savedThreads_;
   mutable std::string savedDevices_;
   mutable bool savedLogtostderr_;
@@ -390,91 +448,34 @@ PYBIND11_MODULE(tclib, m) {
         return TcExecutor{tc, entryPoint, std::move(execUPtr)};
       });

+  // A TunerConfig object can be passed to configure a tuning run
   py::class_<TunerConfig>(m, "TunerConfig", py::module_local())
+      .def(py::init<>())
+      .def("generations", &TunerConfig::generations)
+      .def("pop_size", &TunerConfig::populationSize)
+      .def("crossover_rate", &TunerConfig::crossoverRate)
+      .def("mutation_rate", &TunerConfig::mutationRate)
+      .def("number_elites", &TunerConfig::numberElites)
       .def(
-          py::init<uint32_t, uint32_t, uint32_t, std::string, bool, uint32_t>(),
-          py::arg("generations") = tc::FLAGS_tuner_gen_generations,
-          py::arg("pop_size") = tc::FLAGS_tuner_gen_pop_size,
-          py::arg("threads") = tc::FLAGS_tuner_threads,
-          py::arg("devices") = tc::FLAGS_tuner_devices,
-          py::arg("logtostderr") = false,
-          // Suppress non-FATAL errors from the python user
-          py::arg("stderrthreshold") = google::FATAL);
+          "tuner_min_launch_total_threads",
+          &TunerConfig::tunerMinLaunchTotalThreads)
+      .def("threads", &TunerConfig::threads)
+      .def("devices", &TunerConfig::devices)
+      .def("logtostderr", &TunerConfig::logtostderr)
+      .def("stderrthreshold", &TunerConfig::stderrthreshold);

   py::class_<Tuner>(m, "Tuner", py::module_local())
       .def(py::init<std::string>())
       .def(py::init<std::string, std::string>())
-      .def(
-          "pop_size",
-          [](Tuner& instance, uint32_t& pop_size) {
-            tc::FLAGS_tuner_gen_pop_size = pop_size;
-          })
-      .def(
-          "crossover_rate",
-          [](Tuner& instance, uint32_t& crossover_rate) {
-            tc::FLAGS_tuner_gen_crossover_rate = crossover_rate;
-          })
-      .def(
-          "mutation_rate",
-          [](Tuner& instance, uint32_t& mutation_rate) {
-            tc::FLAGS_tuner_gen_mutation_rate = mutation_rate;
-          })
-      .def(
-          "generations",
-          [](Tuner& instance, uint32_t& generations) {
-            tc::FLAGS_tuner_gen_generations = generations;
-          })
-      .def(
-          "number_elites",
-          [](Tuner& instance, uint32_t& number_elites) {
-            tc::FLAGS_tuner_gen_number_elites = number_elites;
-          })
-      .def(
-          "threads",
-          [](Tuner& instance, uint32_t& threads) {
-            tc::FLAGS_tuner_threads = threads;
-          })
-      .def(
-          "gpus",
-          [](Tuner& instance, std::string& gpus) {
-            tc::FLAGS_tuner_devices = gpus;
-          })
-      .def(
-          "restore_from_proto",
-          [](Tuner& instance, bool restore_from_proto) {
-            tc::FLAGS_tuner_gen_restore_from_proto = restore_from_proto;
-          })
-      .def(
-          "restore_number",
-          [](Tuner& instance, uint32_t& restore_number) {
-            tc::FLAGS_tuner_gen_restore_number = restore_number;
-          })
-      .def(
-          "log_generations",
-          [](Tuner& instance, bool log_generations) {
-            tc::FLAGS_tuner_gen_log_generations = log_generations;
-          })
-      .def(
-          "tuner_min_launch_total_threads",
-          [](Tuner& instance, bool tuner_min_launch_total_threads) {
-            tc::FLAGS_tuner_min_launch_total_threads =
-                tuner_min_launch_total_threads;
-          })
-      .def(
-          "save_best_candidates_count",
-          [](Tuner& instance, bool save_best_candidates_count) {
-            tc::FLAGS_tuner_save_best_candidates_count =
-                save_best_candidates_count;
-          })
       .def(
           "tune",
           [](Tuner& instance,
             const std::string& entryPoint,
             const py::tuple& inputs,
             tc::CudaMappingOptions& baseMapping,
             const TunerConfig& config) {
-            config.__enter__();
-            ScopeGuard sg([&config]() { config.__exit__(); });
+            config.enter();
+            ScopeGuard sg([&config]() { config.exit(); });
             std::vector<at::Tensor> atInputs = getATenTensors(inputs);
             auto bestOptions =
                 instance.tune(entryPoint, atInputs, {baseMapping});
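
For completeness, a hedged sketch of the fields the user-facing API previously lacked, now settable through the same chained interface (method names are taken from the bindings above; the values are arbitrary illustrations):

# Assumes TunerConfig is imported from the tclib pybind module as in the example file.
config = (TunerConfig()
          .generations(3)
          .pop_size(25)
          .crossover_rate(80)                  # illustrative value
          .mutation_rate(7)                    # illustrative value
          .number_elites(10)                   # illustrative value
          .tuner_min_launch_total_threads(64)  # illustrative value
          .threads(8)
          .devices("0")
          .logtostderr(False)
          .stderrthreshold(2))                 # glog severity; the binding defaults to google::FATAL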
