Skip to content

Commit 93dce7d

Browse files
committed
Solves various issues with searchspaces and simulation mode
1 parent 1c96693 commit 93dce7d

File tree

5 files changed

+47
-12
lines changed

5 files changed

+47
-12
lines changed

kernel_tuner/interface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ def tune_kernel(
670670

671671
# create search space
672672
searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads)
673+
restrictions = searchspace._modified_restrictions
674+
tuning_options.restrictions = restrictions
673675
if verbose:
674676
print(f"Searchspace has {searchspace.size} configurations after restrictions.")
675677

kernel_tuner/runners/simulation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ def run(self, parameter_space, tuning_options):
128128
continue
129129

130130
# if the element is not in the cache, raise an error
131-
logging.debug(f"kernel configuration {element} not in cache")
132-
raise ValueError(f"Kernel configuration {element} not in cache - in simulation mode, all configurations must be present in the cache")
131+
check = util.check_restrictions(tuning_options.restrictions, dict(zip(tuning_options['tune_params'].keys(), element)), True)
132+
err_string = f"kernel configuration {element} not in cache, does {'' if check else 'not '}pass extra restriction check ({check})"
133+
logging.debug(err_string)
134+
raise ValueError(f"{err_string} - in simulation mode, all configurations must be present in the cache")
133135

134136
return results

kernel_tuner/searchspace.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(
5252
restrictions = restrictions if restrictions is not None else []
5353
self.tune_params = tune_params
5454
self.restrictions = restrictions
55+
self._modified_restrictions = restrictions # the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
5556
self.param_names = list(self.tune_params.keys())
5657
self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
5758
self.params_values_indices = None
@@ -166,6 +167,10 @@ def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: in
166167
block_size_restriction_unspaced = f"{'*'.join(used_block_size_names)} <= {max_threads}"
167168
if block_size_restriction_spaced not in restrictions and block_size_restriction_unspaced not in restrictions:
168169
restrictions.append(block_size_restriction_spaced)
170+
if isinstance(self._modified_restrictions, list) and block_size_restriction_spaced not in self._modified_restrictions:
171+
self._modified_restrictions.append(block_size_restriction_spaced)
172+
if isinstance(self.restrictions, list):
173+
self.restrictions.append(block_size_restriction_spaced)
169174

170175
# check for search space restrictions
171176
if restrictions is not None:
@@ -293,6 +298,11 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
293298
)
294299
if len(valid_block_size_names) > 0:
295300
parameter_space.addConstraint(MaxProdConstraint(max_threads), valid_block_size_names)
301+
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
302+
if isinstance(self._modified_restrictions, list) and max_block_size_product not in self._modified_restrictions:
303+
self._modified_restrictions.append(max_block_size_product)
304+
if isinstance(self.restrictions, list):
305+
self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names))
296306

297307
# construct the parameter space with the constraints applied
298308
return parameter_space.getSolutionsAsListDict(order=self.param_names)
@@ -302,10 +312,14 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
302312
if isinstance(self.restrictions, list):
303313
for restriction in self.restrictions:
304314
required_params = self.param_names
315+
316+
# convert to a Constraint type if necessary
305317
if isinstance(restriction, tuple):
306318
restriction, required_params = restriction
307319
if callable(restriction) and not isinstance(restriction, Constraint):
308320
restriction = FunctionConstraint(restriction)
321+
322+
# add the Constraint
309323
if isinstance(restriction, FunctionConstraint):
310324
parameter_space.addConstraint(restriction, required_params)
311325
elif isinstance(restriction, Constraint):

kernel_tuner/strategies/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __call__(self, x, check_restrictions=True):
7272
# check if max_fevals is reached or time limit is exceeded
7373
util.check_stop_criterion(self.tuning_options)
7474

75-
# snap values in x to nearest actual value for each parameter unscale x if needed
75+
# snap values in x to nearest actual value for each parameter, unscale x if needed
7676
if self.snap:
7777
if self.scaling:
7878
params = unscale_and_snap_to_nearest(x, self.searchspace.tune_params, self.tuning_options.eps)

kernel_tuner/util.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ class StopCriterionReached(Exception):
102102
"block_size_x",
103103
"block_size_y",
104104
"block_size_z",
105-
"ngangs",
106-
"nworkers",
107-
"vlength",
105+
# "ngangs",
106+
# "nworkers",
107+
# "vlength",
108108
]
109109

110110

@@ -263,7 +263,7 @@ def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
263263
# if it's a python-constraint, convert to function and execute
264264
if isinstance(restrict, Constraint):
265265
restrict = convert_constraint_restriction(restrict)
266-
if not restrict(params.values()):
266+
if not restrict(list(params.values())):
267267
valid = False
268268
break
269269
# if it's a string, fill in the parameters and evaluate
@@ -284,9 +284,15 @@ def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
284284
# look up the selected parameters and their value
285285
selected_params = dict((key, params[key]) for key in selected_params)
286286
# call the restriction
287-
if not restrict(**selected_params):
288-
valid = False
289-
break
287+
if isinstance(restrict, Constraint):
288+
restrict = convert_constraint_restriction(restrict)
289+
if not restrict(list(selected_params.values())):
290+
valid = False
291+
break
292+
else:
293+
if not restrict(**selected_params):
294+
valid = False
295+
break
290296
# otherwise, raise an error
291297
else:
292298
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")
@@ -884,6 +890,7 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]:
884890
except Exception:
885891
# it's not a solvable subexpression, return None
886892
return None
893+
887894

888895
# either the left or right side of the equation must evaluate to a constant number
889896
left_num = is_or_evals_to_number(left)
@@ -998,6 +1005,9 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio
9981005
params_used = list(params_used)
9991006
finalized_constraint = None
10001007
if try_to_constraint and " or " not in res and " and " not in res:
1008+
# if applicable, strip the outermost round brackets
1009+
while parsed_restriction[0] == '(' and parsed_restriction[-1] == ')' and '(' not in parsed_restriction[1:] and ')' not in parsed_restriction[:1]:
1010+
parsed_restriction = parsed_restriction[1:-1]
10011011
# check if we can turn this into the built-in numeric comparison constraint
10021012
finalized_constraint = to_numeric_constraint(parsed_restriction, params_used)
10031013
if finalized_constraint is None:
@@ -1052,8 +1062,15 @@ def compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal
10521062
# return the restrictions and used parameters
10531063
if len(restrictions_ignore) == 0:
10541064
return compiled_restrictions
1055-
restrictions_ignore = list(zip(restrictions_ignore, (() for _ in restrictions_ignore)))
1056-
return restrictions_ignore + compiled_restrictions
1065+
1066+
# use the required parameters or add an empty tuple for unknown parameters of ignored restrictions
1067+
noncompiled_restrictions = []
1068+
for r in restrictions_ignore:
1069+
if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)):
1070+
noncompiled_restrictions.append(r)
1071+
else:
1072+
noncompiled_restrictions.append((r, ()))
1073+
return noncompiled_restrictions + compiled_restrictions
10571074

10581075

10591076
def process_cache(cache, kernel_options, tuning_options, runner):

0 commit comments

Comments
 (0)