Skip to content

Commit b13f972

Browse files
committed
Added pyATF as searchspace builder
1 parent 41ae1d2 commit b13f972

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

kernel_tuner/searchspace.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
Optionally sort the searchspace by the order in which the parameter values were specified. By default, sort goes from first to last parameter, to reverse this use sort_last_param_first.
4848
"""
4949
# set the object attributes using the arguments
50+
framework_l = framework.lower()
5051
restrictions = restrictions if restrictions is not None else []
5152
self.tune_params = tune_params
5253
self.restrictions = restrictions
@@ -66,21 +67,23 @@ def __init__(
6667
if (
6768
len(restrictions) > 0
6869
and any(isinstance(restriction, str) for restriction in restrictions)
69-
and not (framework.lower() == "pysmt" or framework.lower() == "bruteforce")
70+
and not (framework_l == "pysmt" or framework_l == "pyatf" or framework_l == "bruteforce")
7071
):
7172
self.restrictions = compile_restrictions(
72-
restrictions, tune_params, monolithic=False, try_to_constraint=framework.lower() == "pythonconstraint"
73+
restrictions, tune_params, monolithic=False, try_to_constraint=framework_l == "pythonconstraint"
7374
)
7475

7576
# get the framework given the framework argument
76-
if framework.lower() == "pythonconstraint":
77+
if framework_l == "pythonconstraint":
7778
searchspace_builder = self.__build_searchspace
78-
elif framework.lower() == "pysmt":
79+
elif framework_l == "pysmt":
7980
searchspace_builder = self.__build_searchspace_pysmt
80-
elif framework.lower() == "atf_cache":
81+
elif framework_l == "pyatf":
82+
searchspace_builder = self.__build_searchspace_pyATF
83+
elif framework_l == "atf_cache":
8184
searchspace_builder = self.__build_searchspace_ATF_cache
8285
self.path_to_ATF_cache = path_to_ATF_cache
83-
elif framework.lower() == "bruteforce":
86+
elif framework_l == "bruteforce":
8487
searchspace_builder = self.__build_searchspace_bruteforce
8588
else:
8689
raise ValueError(f"Invalid framework parameter {framework}")
@@ -247,6 +250,28 @@ def all_smt(formula, keys) -> list:
247250

248251
return self.__parameter_space_list_to_lookup_and_return_type(parameter_space_list)
249252

253+
def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, solver: Solver):
254+
"""Builds the searchspace using pyATF."""
255+
from pyatf import TP, Set, Tuner
256+
from pyatf.cost_functions.generic import CostFunction
257+
from pyatf.search_techniques import Exhaustive
258+
259+
costfunc = CostFunction("echo 'hello'")
260+
261+
def get_params():
262+
params = List()
263+
for key, values in self.tune_params.items():
264+
TP(key, Set(values))
265+
return params
266+
267+
tuning_result = (
268+
Tuner()
269+
.tuning_parameters(*get_params())
270+
.search_technique(Exhaustive())
271+
.tune(costfunc)
272+
)
273+
return tuning_result
274+
250275
def __build_searchspace_ATF_cache(self, block_size_names: list, max_threads: int, solver: Solver):
251276
"""Imports the valid configurations from an ATF CSV file, returns the searchspace, a dict of the searchspace for fast lookups and the size."""
252277
if block_size_names != default_block_size_names or max_threads != 1024:

0 commit comments

Comments
 (0)