|
19 | 19 | import numpy as np
|
20 | 20 | import os
|
21 | 21 | import tensor_comprehensions as tc
|
| 22 | +import torch |
22 | 23 |
|
23 | 24 | from caffe2.python import core, dyndep, workspace, utils
|
24 | 25 |
|
@@ -85,18 +86,22 @@ def main():
|
85 | 86 |
|
@utils.debug
def tune(args):
    """Autotune the ``func_fc`` Tensor Comprehensions kernel and return the result.

    Builds a tuner configuration from the command-line arguments, then runs
    ``tc.autotune`` on randomly generated CUDA tensors shaped like the FC
    operator's inputs. The tuning result is also persisted to the cache file.

    Args:
        args: Parsed command-line namespace. Reads ``batch_size``,
            ``input_dim``, ``output_dim``, ``tuner_devices``,
            ``tuner_threads``, ``tuner_gen_generations``,
            ``tuner_gen_pop_size``, and ``tuner_cache_file``.

    Returns:
        The value returned by ``tc.autotune`` — presumably the best
        ``MappingOptions`` found (NOTE(review): return type not visible
        here; confirm against the tensor_comprehensions API).
    """
    # Genetic-search tuner parameters, all taken from the CLI flags.
    tuner_config = (
        tc.TunerConfig()
        .generations(args.tuner_gen_generations)
        .devices(args.tuner_devices)
        .threads(args.tuner_threads)
        .pop_size(args.tuner_gen_pop_size))
    # Random CUDA inputs matching the FC signature:
    # input (batch_size, input_dim), weight (output_dim, input_dim),
    # bias (output_dim,). Values are irrelevant for tuning; only shapes matter.
    return tc.autotune(
        FC_LANG,
        'func_fc',
        torch.randn(args.batch_size, args.input_dim, device='cuda'),
        torch.randn(args.output_dim, args.input_dim, device='cuda'),
        torch.randn(args.output_dim, device='cuda'),
        # PEP 8: no spaces around '=' in keyword arguments (fixed).
        starting_options=tc.MappingOptions('naive'),
        tuner_config=tuner_config,
        cache_filename=args.tuner_cache_file,
        store_to_cache=True)
100 | 105 |
|
101 | 106 |
|
102 | 107 | @utils.debug
|
|
0 commit comments