Skip to content

Commit 68f7ce6

Browse files
formatted with black
1 parent 5c4ba1a commit 68f7ce6

File tree

2 files changed

+142
-160
lines changed

2 files changed

+142
-160
lines changed

kernel_tuner/backends/hip.py

Lines changed: 29 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232

3333
hipSuccess = 0
3434

35+
3536
def hip_check(call_result):
37+
"""Helper function to check the return values of HIP calls."""
3638
err = call_result[0]
3739
result = call_result[1:]
3840
if len(result) == 1:
@@ -41,6 +43,7 @@ def hip_check(call_result):
4143
raise RuntimeError(str(err))
4244
return result
4345

46+
4447
class HipFunctions(GPUBackend):
4548
"""Class that groups the HIP functions and maintains state about the device."""
4649

@@ -59,7 +62,9 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
5962
:type iterations: int
6063
"""
6164
if not hip or not hiprtc:
62-
raise ImportError("Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python.")
65+
raise ImportError(
66+
"Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python."
67+
)
6368

6469
# embedded in try block to be able to generate documentation
6570
# and run tests without HIP Python installed
@@ -69,7 +74,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
6974
props = hip.hipDeviceProp_t()
7075
hip_check(hip.hipGetDeviceProperties(props, device))
7176

72-
self.name = props.name.decode('utf-8')
77+
self.name = props.name.decode("utf-8")
7378
self.max_threads = props.maxThreadsPerBlock
7479
self.device = device
7580
self.compiler_options = compiler_options or []
@@ -81,7 +86,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
8186
env["compiler_options"] = compiler_options
8287
self.env = env
8388

84-
# Create stream and events
89+
# Create stream and events
8590
self.stream = hip_check(hip.hipStreamCreate())
8691
self.start = hip_check(hip.hipEventCreate())
8792
self.end = hip_check(hip.hipEventCreate())
@@ -108,40 +113,34 @@ def ready_argument_list(self, arguments):
108113
"""
109114
logging.debug("HipFunction ready_argument_list called")
110115
prepared_args = []
111-
116+
112117
for arg in arguments:
113118
dtype_str = str(arg.dtype)
114-
119+
115120
# Handle numpy arrays
116121
if isinstance(arg, np.ndarray):
117122
if dtype_str in dtype_map.keys():
118123
# Allocate device memory
119124
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
120-
125+
121126
# Copy data to device using hipMemcpy
122-
hip_check(hip.hipMemcpy(
123-
device_ptr,
124-
arg,
125-
arg.nbytes,
126-
hip.hipMemcpyKind.hipMemcpyHostToDevice
127-
))
128-
127+
hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
128+
129129
prepared_args.append(device_ptr)
130130
else:
131131
raise TypeError(f"Unknown dtype {dtype_str} for ndarray")
132-
132+
133133
# Handle numpy scalar types
134134
elif isinstance(arg, np.generic):
135135
# Convert numpy scalar to corresponding ctypes
136136
ctype_arg = dtype_map[dtype_str](arg)
137137
prepared_args.append(ctype_arg)
138-
138+
139139
else:
140140
raise ValueError(f"Invalid argument type {type(arg)}, {arg}")
141141

142142
return prepared_args
143143

144-
145144
def compile(self, kernel_instance):
146145
"""Call the HIP compiler to compile the kernel, return the function.
147146
@@ -159,28 +158,22 @@ def compile(self, kernel_instance):
159158
kernel_name = kernel_instance.name
160159
if 'extern "C"' not in kernel_string:
161160
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"
162-
161+
163162
# Create program
164-
prog = hip_check(hiprtc.hiprtcCreateProgram(
165-
kernel_string.encode(),
166-
kernel_name.encode(),
167-
0,
168-
[],
169-
[]
170-
))
163+
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), kernel_name.encode(), 0, [], []))
171164

172165
try:
173166
# Get device properties
174167
props = hip.hipDeviceProp_t()
175168
hip_check(hip.hipGetDeviceProperties(props, 0))
176-
169+
177170
# Setup compilation options
178171
arch = props.gcnArchName
179172
cflags = [b"--offload-arch=" + arch]
180173
cflags.extend([opt.encode() if isinstance(opt, str) else opt for opt in self.compiler_options])
181174

182175
# Compile program
183-
err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
176+
(err,) = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
184177
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
185178
# Get compilation log if there's an error
186179
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
@@ -208,19 +201,19 @@ def compile(self, kernel_instance):
208201
def start_event(self):
209202
"""Records the event that marks the start of a measurement."""
210203
logging.debug("HipFunction start_event called")
211-
204+
212205
hip_check(hip.hipEventRecord(self.start, self.stream))
213206

214207
def stop_event(self):
215208
"""Records the event that marks the end of a measurement."""
216209
logging.debug("HipFunction stop_event called")
217-
210+
218211
hip_check(hip.hipEventRecord(self.end, self.stream))
219212

220213
def kernel_finished(self):
221214
"""Returns True if the kernel has finished, False otherwise."""
222215
logging.debug("HipFunction kernel_finished called")
223-
216+
224217
# ROCm HIP returns (hipError_t, bool) for hipEventQuery
225218
status = hip.hipEventQuery(self.end)
226219
if status[0] == hip.hipError_t.hipSuccess:
@@ -233,7 +226,7 @@ def kernel_finished(self):
233226
def synchronize(self):
234227
"""Halts execution until device has finished its tasks."""
235228
logging.debug("HipFunction synchronize called")
236-
229+
237230
hip_check(hip.hipDeviceSynchronize())
238231

239232
def run_kernel(self, func, gpu_args, threads, grid, stream=None):
@@ -242,7 +235,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
242235
:param func: A HIP kernel compiled for this specific kernel configuration
243236
:type func: hipFunction_t
244237
245-
:param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray
238+
:param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray
246239
objects or ctypes values
247240
:type gpu_args: list
248241
@@ -272,7 +265,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
272265
sharedMemBytes=self.smem_size,
273266
stream=stream,
274267
kernelParams=None,
275-
extra=tuple(gpu_args)
268+
extra=tuple(gpu_args),
276269
)
277270
)
278271

@@ -303,12 +296,7 @@ def memcpy_dtoh(self, dest, src):
303296
"""
304297
logging.debug("HipFunction memcpy_dtoh called")
305298

306-
hip_check(hip.hipMemcpy(
307-
dest,
308-
src,
309-
dest.nbytes,
310-
hip.hipMemcpyKind.hipMemcpyDeviceToHost
311-
))
299+
hip_check(hip.hipMemcpy(dest, src, dest.nbytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
312300

313301
def memcpy_htod(self, dest, src):
314302
"""Perform a host to device memory copy.
@@ -321,12 +309,7 @@ def memcpy_htod(self, dest, src):
321309
"""
322310
logging.debug("HipFunction memcpy_htod called")
323311

324-
hip_check(hip.hipMemcpy(
325-
dest,
326-
src,
327-
src.nbytes,
328-
hip.hipMemcpyKind.hipMemcpyHostToDevice
329-
))
312+
hip_check(hip.hipMemcpy(dest, src, src.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
330313

331314
def copy_constant_memory_args(self, cmem_args):
332315
"""Adds constant memory arguments to the most recently compiled module.
@@ -343,18 +326,10 @@ def copy_constant_memory_args(self, cmem_args):
343326
# Iterate over dictionary
344327
for symbol_name, data in cmem_args.items():
345328
# Get symbol pointer and size using hipModuleGetGlobal
346-
dptr, _ = hip_check(hip.hipModuleGetGlobal(
347-
self.current_module,
348-
symbol_name.encode()
349-
))
329+
dptr, _ = hip_check(hip.hipModuleGetGlobal(self.current_module, symbol_name.encode()))
350330

351331
# Copy data to the global memory location
352-
hip_check(hip.hipMemcpy(
353-
dptr,
354-
data,
355-
data.nbytes,
356-
hip.hipMemcpyKind.hipMemcpyHostToDevice
357-
))
332+
hip_check(hip.hipMemcpy(dptr, data, data.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
358333

359334
def copy_shared_memory_args(self, smem_args):
360335
"""Add shared memory arguments to the kernel."""

0 commit comments

Comments
 (0)