Skip to content

Commit 942b632

Browse files
committed
Small refactor to reduce cognitive complexity.
1 parent 8e5f510 commit 942b632

File tree

1 file changed

+43
-32
lines changed

1 file changed

+43
-32
lines changed

kernel_tuner/utils/directives.py

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,22 @@ def extract_initialization_code(code: str, langs: Code) -> str:
293293
return ""
294294

295295

296+
def format_argument_fortran(p_type: str, p_size: int, p_name: str) -> str:
297+
if "float*" in p_type:
298+
return f"real (c_float), dimension({p_size}) :: {p_name}"
299+
elif "double*" in p_type:
300+
return f"real (c_double), dimension({p_size}) :: {p_name}"
301+
elif "int*" in p_type:
302+
return f"integer (c_int), dimension({p_size}) :: {p_name}"
303+
elif "float" in p_type:
304+
return f"real (c_float), value :: {p_name}"
305+
elif "double" in p_type:
306+
return f"real (c_double), value :: {p_name}"
307+
elif "int" in p_type:
308+
return f"integer (c_int), value :: {p_name}"
309+
return ""
310+
311+
296312
def extract_directive_signature(code: str, langs: Code, kernel_name: str = None) -> dict:
297313
"""Extract the user defined signature for directive sections"""
298314

@@ -341,18 +357,7 @@ def extract_directive_signature(code: str, langs: Code, kernel_name: str = None)
341357
p_type = param[1:-1]
342358
p_size = p_type.split(":")[1]
343359
p_type = p_type.split(":")[0]
344-
if "float*" in p_type:
345-
params.append(f"real (c_float), dimension({p_size}) :: {p_name}")
346-
elif "double*" in p_type:
347-
params.append(f"real (c_double), dimension({p_size}) :: {p_name}")
348-
elif "int*" in p_type:
349-
params.append(f"integer (c_int), dimension({p_size}) :: {p_name}")
350-
elif "float" in p_type:
351-
params.append(f"real (c_float), value :: {p_name}")
352-
elif "double" in p_type:
353-
params.append(f"real (c_double), value :: {p_name}")
354-
elif "int" in p_type:
355-
params.append(f"integer (c_int), value :: {p_name}")
360+
params.append(format_argument_fortran(p_type, p_size, p_name))
356361
signatures[name] += "\n".join(params) + "\n"
357362
signatures[
358363
name
@@ -448,6 +453,30 @@ def generate_directive_function(
448453
return code
449454

450455

456+
def allocate_array(p_type: str, size: int) -> np.ndarray:
457+
if p_type == "float*":
458+
return np.random.rand(size).astype(np.float32)
459+
elif p_type == "double*":
460+
return np.random.rand(size).astype(np.float64)
461+
elif p_type == "int*":
462+
return np.random.randint(max_int, size=size)
463+
else:
464+
# The parameter is an array of user defined types
465+
return np.random.rand(size).astype(np.byte)
466+
467+
468+
def allocate_scalar(p_type: str, size: int) -> np.number:
469+
if p_type == "float":
470+
return np.float32(size)
471+
elif p_type == "double":
472+
return np.float64(size)
473+
elif p_type == "int":
474+
return np.int32(size)
475+
else:
476+
# The parameter is some user defined type
477+
return np.byte(size)
478+
479+
451480
def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimensions: dict = None) -> list:
452481
"""Allocates the data needed by a kernel and returns the arguments array"""
453482
args = []
@@ -457,26 +486,8 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
457486
p_type = data[parameter][0]
458487
size = parse_size(data[parameter][1], preprocessor, user_dimensions)
459488
if "*" in p_type:
460-
# The parameter is an array
461-
if p_type == "float*":
462-
args.append(np.random.rand(size).astype(np.float32))
463-
elif p_type == "double*":
464-
args.append(np.random.rand(size).astype(np.float64))
465-
elif p_type == "int*":
466-
args.append(np.random.randint(max_int, size=size))
467-
else:
468-
# The parameter is an array of user defined types
469-
args.append(np.random.rand(size).astype(np.byte))
489+
args.append(allocate_array(p_type, size))
470490
else:
471-
# The parameter is a scalar
472-
if p_type == "float":
473-
args.append(np.float32(size))
474-
elif p_type == "double":
475-
args.append(np.float64(size))
476-
elif p_type == "int":
477-
args.append(np.int32(size))
478-
else:
479-
# The parameter is some user defined type
480-
args.append(np.byte(size))
491+
args.append(allocate_scalar(p_type, size))
481492

482493
return args

0 commit comments

Comments
 (0)