Skip to content

Commit 0fc4ad2

Browse files
Merge pull request #251 from KernelTuner/simulation-searchspace-improvements
Small improvements to searchspaces and simulation mode
2 parents 9e8a59a + ea7ca58 commit 0fc4ad2

File tree

5 files changed

+71
-41
lines changed

5 files changed

+71
-41
lines changed

kernel_tuner/interface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ def tune_kernel(
670670

671671
# create search space
672672
searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads)
673+
restrictions = searchspace._modified_restrictions
674+
tuning_options.restrictions = restrictions
673675
if verbose:
674676
print(f"Searchspace has {searchspace.size} configurations after restrictions.")
675677

kernel_tuner/runners/simulation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ def run(self, parameter_space, tuning_options):
128128
continue
129129

130130
# if the element is not in the cache, raise an error
131-
logging.debug(f"kernel configuration {element} not in cache")
132-
raise ValueError(f"Kernel configuration {element} not in cache - in simulation mode, all configurations must be present in the cache")
131+
check = util.check_restrictions(tuning_options.restrictions, dict(zip(tuning_options['tune_params'].keys(), element)), True)
132+
err_string = f"kernel configuration {element} not in cache, does {'' if check else 'not '}pass extra restriction check ({check})"
133+
logging.debug(err_string)
134+
raise ValueError(f"{err_string} - in simulation mode, all configurations must be present in the cache")
133135

134136
return results

kernel_tuner/searchspace.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __init__(
5252
restrictions = restrictions if restrictions is not None else []
5353
self.tune_params = tune_params
5454
self.restrictions = restrictions
55+
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
56+
self._modified_restrictions = restrictions
5557
self.param_names = list(self.tune_params.keys())
5658
self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
5759
self.params_values_indices = None
@@ -166,6 +168,10 @@ def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: in
166168
block_size_restriction_unspaced = f"{'*'.join(used_block_size_names)} <= {max_threads}"
167169
if block_size_restriction_spaced not in restrictions and block_size_restriction_unspaced not in restrictions:
168170
restrictions.append(block_size_restriction_spaced)
171+
if isinstance(self._modified_restrictions, list) and block_size_restriction_spaced not in self._modified_restrictions:
172+
self._modified_restrictions.append(block_size_restriction_spaced)
173+
if isinstance(self.restrictions, list):
174+
self.restrictions.append(block_size_restriction_spaced)
169175

170176
# check for search space restrictions
171177
if restrictions is not None:
@@ -293,6 +299,11 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
293299
)
294300
if len(valid_block_size_names) > 0:
295301
parameter_space.addConstraint(MaxProdConstraint(max_threads), valid_block_size_names)
302+
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
303+
if isinstance(self._modified_restrictions, list) and max_block_size_product not in self._modified_restrictions:
304+
self._modified_restrictions.append(max_block_size_product)
305+
if isinstance(self.restrictions, list):
306+
self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names))
296307

297308
# construct the parameter space with the constraints applied
298309
return parameter_space.getSolutionsAsListDict(order=self.param_names)
@@ -302,10 +313,14 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
302313
if isinstance(self.restrictions, list):
303314
for restriction in self.restrictions:
304315
required_params = self.param_names
316+
317+
# convert to a Constraint type if necessary
305318
if isinstance(restriction, tuple):
306319
restriction, required_params = restriction
307320
if callable(restriction) and not isinstance(restriction, Constraint):
308321
restriction = FunctionConstraint(restriction)
322+
323+
# add the Constraint
309324
if isinstance(restriction, FunctionConstraint):
310325
parameter_space.addConstraint(restriction, required_params)
311326
elif isinstance(restriction, Constraint):

kernel_tuner/strategies/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __call__(self, x, check_restrictions=True):
7272
# check if max_fevals is reached or time limit is exceeded
7373
util.check_stop_criterion(self.tuning_options)
7474

75-
# snap values in x to nearest actual value for each parameter unscale x if needed
75+
# snap values in x to nearest actual value for each parameter, unscale x if needed
7676
if self.snap:
7777
if self.scaling:
7878
params = unscale_and_snap_to_nearest(x, self.searchspace.tune_params, self.tuning_options.eps)

kernel_tuner/util.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,6 @@ class StopCriterionReached(Exception):
102102
"block_size_x",
103103
"block_size_y",
104104
"block_size_z",
105-
"ngangs",
106-
"nworkers",
107-
"vlength",
108105
]
109106

110107

@@ -248,9 +245,37 @@ def check_block_size_params_names_list(block_size_names, tune_params):
248245
UserWarning,
249246
)
250247

248+
def check_restriction(restrict, params: dict) -> bool:
249+
"""Check whether a configuration meets a search space restriction."""
250+
# if it's a python-constraint, convert to function and execute
251+
if isinstance(restrict, Constraint):
252+
restrict = convert_constraint_restriction(restrict)
253+
return restrict(list(params.values()))
254+
# if it's a string, fill in the parameters and evaluate
255+
elif isinstance(restrict, str):
256+
return eval(replace_param_occurrences(restrict, params))
257+
# if it's a function, call it
258+
elif callable(restrict):
259+
return restrict(**params)
260+
# if it's a tuple, use only the parameters in the second argument to call the restriction
261+
elif (isinstance(restrict, tuple) and len(restrict) == 2
262+
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
263+
# unpack the tuple
264+
restrict, selected_params = restrict
265+
# look up the selected parameters and their value
266+
selected_params = dict((key, params[key]) for key in selected_params)
267+
# call the restriction
268+
if isinstance(restrict, Constraint):
269+
restrict = convert_constraint_restriction(restrict)
270+
return restrict(list(selected_params.values()))
271+
else:
272+
return restrict(**selected_params)
273+
# otherwise, raise an error
274+
else:
275+
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")
251276

252277
def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
253-
"""Check whether a specific configuration meets the search space restrictions."""
278+
"""Check whether a configuration meets the search space restrictions."""
254279
if callable(restrictions):
255280
valid = restrictions(params)
256281
if not valid and verbose is True:
@@ -260,40 +285,13 @@ def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
260285
for restrict in restrictions:
261286
# Check the type of each restriction and validate accordingly. Re-implement as a switch when Python >= 3.10.
262287
try:
263-
# if it's a python-constraint, convert to function and execute
264-
if isinstance(restrict, Constraint):
265-
restrict = convert_constraint_restriction(restrict)
266-
if not restrict(params.values()):
267-
valid = False
268-
break
269-
# if it's a string, fill in the parameters and evaluate
270-
elif isinstance(restrict, str):
271-
if not eval(replace_param_occurrences(restrict, params)):
272-
valid = False
273-
break
274-
# if it's a function, call it
275-
elif callable(restrict):
276-
if not restrict(**params):
277-
valid = False
278-
break
279-
# if it's a tuple, use only the parameters in the second argument to call the restriction
280-
elif (isinstance(restrict, tuple) and len(restrict) == 2
281-
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
282-
# unpack the tuple
283-
restrict, selected_params = restrict
284-
# look up the selected parameters and their value
285-
selected_params = dict((key, params[key]) for key in selected_params)
286-
# call the restriction
287-
if not restrict(**selected_params):
288-
valid = False
289-
break
290-
# otherwise, raise an error
291-
else:
292-
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")
288+
valid = check_restriction(restrict, params)
289+
if not valid:
290+
break
293291
except ZeroDivisionError:
294292
logging.debug(f"Restriction {restrict} with configuration {get_instance_string(params)} divides by zero.")
295293
if not valid and verbose is True:
296-
print(f"skipping config {get_instance_string(params)}, reason: config fails restriction")
294+
print(f"skipping config {get_instance_string(params)}, reason: config fails restriction {restrict}")
297295
return valid
298296

299297

@@ -311,6 +309,9 @@ def f_restrict(p):
311309
elif isinstance(restrict, MaxProdConstraint):
312310
def f_restrict(p):
313311
return np.prod(p) <= restrict._maxprod
312+
elif isinstance(restrict, MinProdConstraint):
313+
def f_restrict(p):
314+
return np.prod(p) >= restrict._minprod
314315
elif isinstance(restrict, MaxSumConstraint):
315316
def f_restrict(p):
316317
return sum(p) <= restrict._maxsum
@@ -1005,6 +1006,9 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio
10051006
params_used = list(params_used)
10061007
finalized_constraint = None
10071008
if try_to_constraint and " or " not in res and " and " not in res:
1009+
# if applicable, strip the outermost round brackets
1010+
while parsed_restriction[0] == '(' and parsed_restriction[-1] == ')' and '(' not in parsed_restriction[1:] and ')' not in parsed_restriction[:1]:
1011+
parsed_restriction = parsed_restriction[1:-1]
10081012
# check if we can turn this into the built-in numeric comparison constraint
10091013
finalized_constraint = to_numeric_constraint(parsed_restriction, params_used)
10101014
if finalized_constraint is None:
@@ -1059,8 +1063,15 @@ def compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal
10591063
# return the restrictions and used parameters
10601064
if len(restrictions_ignore) == 0:
10611065
return compiled_restrictions
1062-
restrictions_ignore = list(zip(restrictions_ignore, (() for _ in restrictions_ignore)))
1063-
return restrictions_ignore + compiled_restrictions
1066+
1067+
# use the required parameters or add an empty tuple for unknown parameters of ignored restrictions
1068+
noncompiled_restrictions = []
1069+
for r in restrictions_ignore:
1070+
if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)):
1071+
noncompiled_restrictions.append(r)
1072+
else:
1073+
noncompiled_restrictions.append((r, ()))
1074+
return noncompiled_restrictions + compiled_restrictions
10641075

10651076

10661077
def process_cache(cache, kernel_options, tuning_options, runner):
@@ -1181,7 +1192,7 @@ def correct_open_cache(cache, open_cache=True):
11811192
filestr = cachefile.read().strip()
11821193

11831194
# if file was not properly closed, pretend it was properly closed
1184-
if len(filestr) > 0 and not filestr[-3:] == "}\n}":
1195+
if len(filestr) > 0 and not filestr[-3:] in ["}\n}", "}}}"]:
11851196
# remove the trailing comma if any, and append closing brackets
11861197
if filestr[-1] == ",":
11871198
filestr = filestr[:-1]

0 commit comments

Comments
 (0)