Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 379a450

Browse files
author
Nicolas Vasilache
committed
Update caffe2_benchmak.py to latest python API
Tested internally
1 parent 553203c commit 379a450

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

python/benchmarks/caffe2_benchmark.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import os
2121
import tensor_comprehensions as tc
22+
import torch
2223

2324
from caffe2.python import core, dyndep, workspace, utils
2425

@@ -85,18 +86,22 @@ def main():
8586

8687
@utils.debug
8788
def tune(args):
88-
fc = tc.define(FC_LANG, name="func_fc")
89-
options = fc.autotune(
90-
(args.batch_size, args.input_dim),
91-
(args.output_dim, args.input_dim),
92-
(args.output_dim,),
93-
cache = args.tuner_cache_file,
94-
threads = args.tuner_threads,
95-
generations = args.tuner_gen_generations,
96-
pop_size = args.tuner_gen_pop_size,
97-
)
98-
print(options.toString())
99-
return options
89+
tuner_config = (
90+
tc.TunerConfig()
91+
.generations(args.tuner_gen_generations)
92+
.devices(args.tuner_devices)
93+
.threads(args.tuner_threads)
94+
.pop_size(args.tuner_gen_pop_size))
95+
return tc.autotune(
96+
FC_LANG,
97+
'func_fc',
98+
torch.randn(args.batch_size, args.input_dim, device='cuda'),
99+
torch.randn(args.output_dim, args.input_dim, device='cuda'),
100+
torch.randn(args.output_dim, device='cuda'),
101+
starting_options = tc.MappingOptions('naive'),
102+
tuner_config = tuner_config,
103+
cache_filename = args.tuner_cache_file,
104+
store_to_cache = True)
100105

101106

102107
@utils.debug

0 commit comments

Comments
 (0)