Skip to content

Commit 90918fc

Browse files
add support for any case spelling of block size name defaults
1 parent 4c0a877 commit 90918fc

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

kernel_tuner/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,8 @@ def tune_kernel(
592592
# check for forbidden names in tune parameters
593593
util.check_tune_params_list(tune_params, observers, simulation_mode=simulation_mode)
594594

595-
# check whether block_size_names are used as expected
596-
util.check_block_size_params_names_list(block_size_names, tune_params)
595+
# check whether block_size_names are used
596+
block_size_names = util.check_block_size_params_names_list(block_size_names, tune_params)
597597

598598
# ensure there is always at least three names
599599
util.append_default_block_size_names(block_size_names)

kernel_tuner/util.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,22 @@ def check_block_size_params_names_list(block_size_names, tune_params):
237237
"Block size name " + name + " is not specified in the tunable parameters list!", UserWarning
238238
)
239239
else: # if default block size names are used
240-
if not any([k in default_block_size_names for k in tune_params.keys()]):
240+
if not any([k.lower() in default_block_size_names for k in tune_params.keys()]):
241241
warnings.warn(
242242
"None of the tunable parameters specify thread block dimensions!",
243243
UserWarning,
244244
)
245+
else:
246+
# check for alternative case spelling of defaults such as BLOCK_SIZE_X or block_Size_X etc
247+
result = []
248+
for k in tune_params.keys():
249+
if k.lower() in default_block_size_names and k not in default_block_size_names:
250+
result.append(k)
251+
# ensure order of block_size_names is correct regardless of case used
252+
block_size_names = sorted(result, key=str.casefold)
253+
254+
return block_size_names
255+
245256

246257
def check_restriction(restrict, params: dict) -> bool:
247258
"""Check whether a configuration meets a search space restriction."""

0 commit comments

Comments
 (0)