Skip to content

Commit 199fd99

Browse files
Merge pull request #277 from KernelTuner/support_case_insensitive_block_names
add support for any case spelling of block size name defaults
2 parents 6875ae4 + 22b4bdf commit 199fd99

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

kernel_tuner/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,8 @@ def tune_kernel(
592592
# check for forbidden names in tune parameters
593593
util.check_tune_params_list(tune_params, observers, simulation_mode=simulation_mode)
594594

595-
# check whether block_size_names are used as expected
596-
util.check_block_size_params_names_list(block_size_names, tune_params)
595+
# check whether block_size_names are used
596+
block_size_names = util.check_block_size_params_names_list(block_size_names, tune_params)
597597

598598
# ensure there is always at least three names
599599
util.append_default_block_size_names(block_size_names)

kernel_tuner/util.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,22 @@ def check_block_size_params_names_list(block_size_names, tune_params):
235235
"Block size name " + name + " is not specified in the tunable parameters list!", UserWarning
236236
)
237237
else: # if default block size names are used
238-
if not any([k in default_block_size_names for k in tune_params.keys()]):
238+
if not any([k.lower() in default_block_size_names for k in tune_params.keys()]):
239239
warnings.warn(
240240
"None of the tunable parameters specify thread block dimensions!",
241241
UserWarning,
242242
)
243+
else:
244+
# check for alternative case spelling of defaults such as BLOCK_SIZE_X or block_Size_X etc
245+
result = []
246+
for k in tune_params.keys():
247+
if k.lower() in default_block_size_names and k not in default_block_size_names:
248+
result.append(k)
249+
# ensure order of block_size_names is correct regardless of case used
250+
block_size_names = sorted(result, key=str.casefold)
251+
252+
return block_size_names
253+
243254

244255

245256
def check_restriction(restrict, params: dict) -> bool:

0 commit comments

Comments
 (0)