Skip to content

Commit 691da11

Browse files
committed
feat: kt now accepts HIP Python DeviceArray type
1 parent 81ef2ac commit 691da11

File tree

4 files changed

+107
-30
lines changed

4 files changed

+107
-30
lines changed

kernel_tuner/backends/compiler.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@
2626
except ImportError:
2727
cp = None
2828

29+
try:
30+
from hip import hip
31+
except ImportError:
32+
hip = None
33+
34+
try:
35+
from hip._util.types import DeviceArray
36+
except ImportError:
37+
Pointer = Exception # using Exception here as a type that will never be among kernel arguments
38+
DeviceArray = Exception
39+
2940

3041
def is_cupy_array(array):
3142
"""Check if something is a cupy array.
@@ -145,9 +156,9 @@ def ready_argument_list(self, arguments):
145156
ctype_args = [None for _ in arguments]
146157

147158
for i, arg in enumerate(arguments):
148-
if not (isinstance(arg, (np.ndarray, np.number)) or is_cupy_array(arg)):
149-
raise TypeError(f"Argument is not numpy or cupy ndarray or numpy scalar but a {type(arg)}")
150-
dtype_str = str(arg.dtype)
159+
if not (isinstance(arg, (np.ndarray, np.number, DeviceArray)) or is_cupy_array(arg)):
160+
raise TypeError(f"Argument is not numpy or cupy ndarray or numpy scalar or HIP Python DeviceArray but a {type(arg)}")
161+
dtype_str = arg.typestr if isinstance(arg, DeviceArray) else str(arg.dtype)
151162
if isinstance(arg, np.ndarray):
152163
if dtype_str in dtype_map.keys():
153164
# In numpy <= 1.15, ndarray.ctypes.data_as does not itself keep a reference
@@ -156,13 +167,20 @@ def ready_argument_list(self, arguments):
156167
# (This changed in numpy > 1.15.)
157168
# data_ctypes = data.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
158169
data_ctypes = arg.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
170+
numpy_arg = arg
159171
else:
160172
raise TypeError("unknown dtype for ndarray")
161173
elif isinstance(arg, np.generic):
162174
data_ctypes = dtype_map[dtype_str](arg)
175+
numpy_arg = arg
163176
elif is_cupy_array(arg):
164177
data_ctypes = C.c_void_p(arg.data.ptr)
165-
ctype_args[i] = Argument(numpy=arg, ctypes=data_ctypes)
178+
numpy_arg = arg
179+
elif isinstance(arg, DeviceArray):
180+
data_ctypes = arg.as_c_void_p()
181+
numpy_arg = None
182+
183+
ctype_args[i] = Argument(numpy=numpy_arg, ctypes=data_ctypes)
166184
return ctype_args
167185

168186
def compile(self, kernel_instance):
@@ -380,6 +398,12 @@ def memcpy_dtoh(self, dest, src):
380398
:param src: An Argument for some memory allocation
381399
:type src: Argument
382400
"""
401+
# If src.numpy is None, it means we're dealing with a HIP Python DeviceArray
402+
if src.numpy is None:
403+
# Skip memory copies for HIP Python DeviceArray
404+
# This is because DeviceArray manages its own memory and doesn't need
405+
# explicit copies like numpy arrays do
406+
return
383407
if isinstance(dest, np.ndarray) and is_cupy_array(src.numpy):
384408
# Implicit conversion to a NumPy array is not allowed.
385409
value = src.numpy.get()
@@ -397,6 +421,12 @@ def memcpy_htod(self, dest, src):
397421
:param src: A numpy or cupy array containing the source data
398422
:type src: np.ndarray or cupy.ndarray
399423
"""
424+
# If dest.numpy is None, it means we're dealing with a HIP Python DeviceArray
425+
if dest.numpy is None:
426+
# Skip memory copies for HIP Python DeviceArray
427+
# This is because DeviceArray manages its own memory and doesn't need
428+
# explicit copies like numpy arrays do
429+
return
400430
if isinstance(dest.numpy, np.ndarray) and is_cupy_array(src):
401431
# Implicit conversion to a NumPy array is not allowed.
402432
value = src.get()

kernel_tuner/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
except ImportError:
3030
torch = util.TorchPlaceHolder()
3131

32+
try:
33+
from hip._util.types import DeviceArray
34+
except ImportError:
35+
DeviceArray = Exception # using Exception here as a type that will never be among kernel arguments
36+
3237
_KernelInstance = namedtuple(
3338
"_KernelInstance",
3439
[
@@ -495,7 +500,7 @@ def check_kernel_output(
495500

496501
should_sync = [answer[i] is not None for i, arg in enumerate(instance.arguments)]
497502
else:
498-
should_sync = [isinstance(arg, (np.ndarray, cp.ndarray, torch.Tensor)) for arg in instance.arguments]
503+
should_sync = [isinstance(arg, (np.ndarray, cp.ndarray, torch.Tensor, DeviceArray)) for arg in instance.arguments]
499504

500505
# re-copy original contents of output arguments to GPU memory, to overwrite any changes
501506
# by earlier kernel runs
@@ -659,7 +664,7 @@ def compile_kernel(self, instance, verbose):
659664
f"skipping config {util.get_instance_string(instance.params)} reason: too much shared memory used"
660665
)
661666
else:
662-
logging.debug("compile_kernel failed due to error: " + str(e))
667+
print("compile_kernel failed due to error: " + error_message)
663668
print("Error while compiling:", instance.name)
664669
raise e
665670
return func

kernel_tuner/util.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ class StopCriterionReached(Exception):
102102
"block_size_z",
103103
]
104104

105+
try:
106+
from hip._util.types import DeviceArray
107+
except ImportError:
108+
DeviceArray = Exception # using Exception here as a type that will never be among kernel arguments
105109

106110
def check_argument_type(dtype, kernel_argument):
107111
"""Check if the numpy.dtype matches the type used in the code."""
@@ -125,59 +129,59 @@ def check_argument_type(dtype, kernel_argument):
125129
return any([substr in kernel_argument for substr in types_map[dtype]])
126130
return False # unknown dtype. do not throw exception to still allow kernel to run.
127131

128-
129132
def check_argument_list(kernel_name, kernel_string, args):
130-
"""Raise an exception if a kernel arguments do not match host arguments."""
133+
"""Raise an exception if kernel arguments do not match host arguments."""
131134
kernel_arguments = list()
132135
collected_errors = list()
136+
133137
for iterator in re.finditer(kernel_name + "[ \n\t]*" + r"\(", kernel_string):
134138
kernel_start = iterator.end()
135139
kernel_end = kernel_string.find(")", kernel_start)
136140
if kernel_start != 0:
137141
kernel_arguments.append(kernel_string[kernel_start:kernel_end].split(","))
142+
138143
for arguments_set, arguments in enumerate(kernel_arguments):
139144
collected_errors.append(list())
140145
if len(arguments) != len(args):
141146
collected_errors[arguments_set].append(
142147
"Kernel and host argument lists do not match in size."
143148
)
144149
continue
150+
145151
for i, arg in enumerate(args):
146152
kernel_argument = arguments[i]
147153

148-
# Fix to deal with tunable arguments
154+
# Handle tunable arguments
149155
if isinstance(arg, Tunable):
150156
continue
151-
152-
if not isinstance(arg, (np.ndarray, np.generic, cp.ndarray, torch.Tensor)):
157+
158+
# Handle numpy arrays and other array types
159+
if not isinstance(arg, (np.ndarray, np.generic, cp.ndarray, torch.Tensor, DeviceArray)):
153160
raise TypeError(
154-
"Argument at position "
155-
+ str(i)
156-
+ " of type: "
157-
+ str(type(arg))
158-
+ " should be of type np.ndarray or numpy scalar"
161+
f"Argument at position {i} of type: {type(arg)} should be of type "
162+
"np.ndarray, numpy scalar, or HIP Python DeviceArray type"
159163
)
160164

161165
correct = True
162-
if isinstance(arg, np.ndarray) and "*" not in kernel_argument:
163-
correct = False # array is passed to non-pointer kernel argument
166+
if isinstance(arg, np.ndarray):
167+
if "*" not in kernel_argument:
168+
correct = False
169+
170+
if isinstance(arg, DeviceArray):
171+
str_dtype = str(np.dtype(arg.typestr))
172+
else:
173+
str_dtype = str(arg.dtype)
164174

165-
if correct and check_argument_type(str(arg.dtype), kernel_argument):
175+
if correct and check_argument_type(str_dtype, kernel_argument):
166176
continue
167-
177+
168178
collected_errors[arguments_set].append(
169-
"Argument at position "
170-
+ str(i)
171-
+ " of dtype: "
172-
+ str(arg.dtype)
173-
+ " does not match "
174-
+ kernel_argument
175-
+ "."
179+
f"Argument at position {i} of dtype: {str_dtype} does not match {kernel_argument}."
176180
)
181+
177182
if not collected_errors[arguments_set]:
178-
# We assume that if there is a possible list of arguments that matches with the provided one
179-
# it is the right one
180183
return
184+
181185
for errors in collected_errors:
182186
warnings.warn(errors[0], UserWarning)
183187

test/test_file_utils.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import json
22

33
import pytest
4+
import ctypes
45
from jsonschema import validate
6+
import numpy as np
7+
import warnings
8+
try:
9+
from hip import hip
10+
except:
11+
hip = None
512

613
from kernel_tuner.file_utils import output_file_schema, store_metadata_file, store_output_file
7-
from kernel_tuner.util import delete_temp_file
14+
from kernel_tuner.util import delete_temp_file, check_argument_list
15+
from .context import skip_if_no_hip
816

917
from .test_runners import cache_filename, env, tune_kernel # noqa: F401
1018

@@ -55,3 +63,33 @@ def test_store_metadata_file():
5563
finally:
5664
# clean up
5765
delete_temp_file(filename)
66+
67+
def hip_check(call_result):
68+
err = call_result[0]
69+
result = call_result[1:]
70+
if len(result) == 1:
71+
result = result[0]
72+
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
73+
raise RuntimeError(str(err))
74+
return result
75+
76+
@skip_if_no_hip
77+
def test_check_argument_list_device_array():
78+
"""Test check_argument_list with DeviceArray"""
79+
float_kernel = """
80+
__global__ void simple_kernel(float* input) {
81+
// kernel code
82+
}
83+
"""
84+
host_array = np.ones((100,), dtype=np.float32)
85+
num_bytes = host_array.size * host_array.itemsize
86+
device_array = hip_check(hip.hipMalloc(num_bytes))
87+
device_array.configure(
88+
typestr="float32",
89+
shape=host_array.shape,
90+
itemsize=host_array.itemsize
91+
)
92+
93+
with warnings.catch_warnings():
94+
warnings.simplefilter("error")
95+
check_argument_list("simple_kernel", float_kernel, [device_array])

0 commit comments

Comments
 (0)