Skip to content

Commit ee16b86

Browse files
simplified to_valid_nvrtc_gpu_arch_cc
1 parent 2ec9417 commit ee16b86

File tree

2 files changed

+3
-17
lines changed

2 files changed

+3
-17
lines changed

kernel_tuner/util.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -570,22 +570,11 @@ def get_total_timings(results, env, overhead_time):
570570
return env
571571

572572

573+
NVRTC_VALID_CC = np.array(['50', '52', '53', '60', '61', '62', '70', '72', '75', '80', '87', '89', '90', '90a'])
574+
573575
def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str:
574576
"""Returns a valid Compute Capability for NVRTC `--gpu-architecture=`, as per https://docs.nvidia.com/cuda/nvrtc/index.html#group__options."""
575-
valid_cc = ['50', '52', '53', '60', '61', '62', '70', '72', '75', '80', '87', '89', '90', '90a'] # must be in ascending order, when updating also update test_to_valid_nvrtc_gpu_arch_cc
576-
compute_capability = str(compute_capability)
577-
if len(compute_capability) < 2:
578-
raise ValueError(f"Compute capability '{compute_capability}' must be at least of length 2, is {len(compute_capability)}")
579-
if compute_capability in valid_cc:
580-
return compute_capability
581-
# if the compute capability does not match, scale down to the nearest matching
582-
subset_cc = [cc for cc in valid_cc if compute_capability[0] == cc[0]]
583-
if len(subset_cc) > 0:
584-
# get the next-highest valid CC
585-
highest_cc_index = max([i for i, cc in enumerate(subset_cc) if int(cc[1]) <= int(compute_capability[1])])
586-
return subset_cc[highest_cc_index]
587-
# if all else fails, return the default 52
588-
return '52'
577+
return max(NVRTC_VALID_CC[NVRTC_VALID_CC<=compute_capability], default='52')
589578

590579

591580
def print_config(config, tuning_options, runner):

test/test_util_functions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,6 @@ def test_to_valid_nvrtc_gpu_arch_cc():
154154
assert to_valid_nvrtc_gpu_arch_cc("90b") == "90a"
155155
assert to_valid_nvrtc_gpu_arch_cc("91c") == "90a"
156156
assert to_valid_nvrtc_gpu_arch_cc("1234") == "52"
157-
with pytest.raises(ValueError):
158-
assert to_valid_nvrtc_gpu_arch_cc("")
159-
assert to_valid_nvrtc_gpu_arch_cc("1")
160157

161158

162159
def test_prepare_kernel_string():

0 commit comments

Comments
 (0)