Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ff658b9

Browse files
Uniformize examples
This changeset uniformizes the flag usage of tensordot using blockdiagperm as an example
1 parent 7939778 commit ff658b9

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

examples/tensordot.cc

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,7 @@
3030

3131
#include "../test/test_harness_aten_cuda.h"
3232

33-
DEFINE_uint32(number_elites, 2, "Number of elites per generation");
34-
DEFINE_uint32(generations, 3, "Number of generations to tune for");
35-
DEFINE_uint32(pop_size, 10, "Population size to tune for");
36-
DEFINE_uint32(threads, 16, "Number of threads to tune with");
37-
DEFINE_string(gpus, "0", "List of gpus to evaluate on");
33+
DEFINE_string(tuner_proto, "", "Filename to load and store proto cache ");
3834

3935
TEST(TensorDot, SimpleAutotune) {
4036
// 1. Define and setup the TC compilation unit with CUDA memory
@@ -57,7 +53,7 @@ def tensordot(float(N, C1, C2, H, W) I0,
5753
auto naiveOptions = tc::CudaMappingOptions::makeNaiveCudaMappingOptions();
5854
tc::autotune::GeneticAutotunerATen geneticAutotuneATen(tc);
5955
auto bestOption = geneticAutotuneATen.tune(
60-
"/tmp/save_results", "tensordot", {I0, I1}, naiveOptions);
56+
FLAGS_tuner_proto, "tensordot", {I0, I1}, naiveOptions);
6157

6258
// 4. Compile and run the TC with the best option.
6359
// Outputs get allocated; could also be pre-allocated and passed.
@@ -91,15 +87,14 @@ def tensordot(float(N, C1, C2, H, W) I0,
9187
}
9288
}
9389

90+
// From root, run with:
91+
// ./build/examples/tensordot --tuner_threads=10 --tuner_gen_pop_size=10
92+
// --tuner_gen_generations=3 --tuner_gen_number_elites=4
93+
// --tuner_proto="/tmp/tensordot"
9494
int main(int argc, char** argv) {
9595
::testing::InitGoogleTest(&argc, argv);
9696
::gflags::ParseCommandLineFlags(&argc, &argv, true);
9797
::google::InitGoogleLogging(argv[0]);
9898
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
99-
tc::FLAGS_tuner_gen_number_elites = FLAGS_number_elites;
100-
tc::FLAGS_tuner_gen_generations = FLAGS_generations;
101-
tc::FLAGS_tuner_gen_pop_size = FLAGS_pop_size;
102-
tc::FLAGS_tuner_threads = FLAGS_threads;
103-
tc::FLAGS_tuner_gpus = FLAGS_gpus;
10499
return RUN_ALL_TESTS();
105100
}

0 commit comments

Comments
 (0)