Skip to content

Commit c82f27f

Browse files
committed
Automatically add the OpenACC present clause for tuning (if necessary).
1 parent c49b93a commit c82f27f

File tree

2 files changed

+78
-12
lines changed

2 files changed

+78
-12
lines changed

kernel_tuner/utils/directives.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ def line_contains_openacc_pragma_fortran(line: str) -> bool:
8282
return "!$acc" in line
8383

8484

85+
def openacc_pragma_contains_clause(line: str, clauses: list) -> bool:
    """Tell whether at least one of the given clauses occurs in an OpenACC directive.

    The check is a plain substring search of every entry of ``clauses`` in ``line``;
    an empty ``clauses`` list therefore always yields ``False``.
    """
    return any(candidate in line for candidate in clauses)
91+
92+
93+
def openacc_pragma_contains_data_clause(line: str) -> bool:
    """Check if an OpenACC directive contains at least one data clause

    The clause names are looked up in ``line`` through
    ``openacc_pragma_contains_clause``.
    """
    data_clauses = [
        "copy",
        "copyin",
        "copyout",
        "create",
        "no_create",
        "present",
        # The OpenACC clause is spelled "deviceptr"; the underscore variant is
        # kept only so lines that matched before still match.
        "deviceptr",
        "device_ptr",
        "attach",
    ]
    return openacc_pragma_contains_clause(line, data_clauses)
97+
98+
8599
def create_data_directive_openacc(name: str, size: int, lang: Language) -> str:
86100
"""Create a data directive for a given language"""
87101
if is_cxx(lang):
@@ -244,23 +258,18 @@ def end_timing_cxx(code: str) -> str:
244258
return code + "\nreturn elapsed_time.count();\n"
245259

246260

247-
def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, user_dimensions: dict = None) -> str:
    """Surround the timed code with data directives for every pointer argument

    Entries of ``data`` whose type string (element 0) contains ``*`` get a
    directive placed before ``code`` and a matching one placed after it; the
    size is computed by ``parse_size`` from element 1 of the entry.
    """
    prologue = ""
    epilogue = ""
    for name, spec in data.items():
        # Only pointer arguments need data directives
        if "*" not in spec[0]:
            continue
        size = parse_size(spec[1], preprocessor=preprocessor, dimensions=user_dimensions)
        openacc = is_openacc(langs.directive)
        if openacc and is_cxx(langs.language):
            prologue += create_data_directive_openacc_cxx(name, size)
            epilogue += exit_data_directive_openacc_cxx(name, size)
        elif openacc and is_fortran(langs.language):
            prologue += create_data_directive_openacc_fortran(name, size)
            epilogue += exit_data_directive_openacc_fortran(name, size)
    return prologue + code + epilogue
266275

@@ -439,6 +448,8 @@ def generate_directive_function(
439448
code += "\n" + signature
440449
if len(initialization) > 1:
441450
code += initialization + "\n"
451+
if data is not None:
452+
body = add_present_openacc(body, langs, data, preprocessor, user_dimensions)
442453
if is_cxx(langs.language):
443454
body = start_timing_cxx(body)
444455
if data is not None:
@@ -499,3 +510,40 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
499510
args.append(allocate_scalar(p_type, size))
500511

501512
return args
513+
514+
515+
def add_present_openacc(
    code: str, langs: Code, data: dict, preprocessor: list = None, user_dimensions: dict = None
) -> str:
    """Add the present clause to the OpenACC compute directives found in code

    Line continuations (backslash followed by a newline) are joined first, so
    every directive is examined as a single line. If any OpenACC directive
    already contains a data clause, ``code`` is returned untouched. Otherwise a
    ``present`` clause for every pointer entry of ``data`` (type string
    containing ``*``) is appended to each ``parallel``, ``kernels`` or
    ``serial`` directive. All other lines are emitted unchanged, keeping their
    line breaks.
    """
    # Directives that take a present clause; appending it to e.g. "loop" is invalid
    compute_constructs = ("parallel", "kernels", "serial")
    lines = code.replace("\\\n", "").split("\n")
    last = len(lines) - 1
    present_clause = None
    new_body = []
    for index, line in enumerate(lines):
        is_directive = line_contains_openacc_pragma(line, langs.language)
        if is_directive and openacc_pragma_contains_data_clause(line):
            # The OpenACC directive manages memory, do not interfere
            return code
        if not (is_directive and _openacc_directive_name(line) in compute_constructs):
            # Re-attach the newline removed by split(); the last piece never had one
            new_body.append(line if index == last else line + "\n")
            continue
        if present_clause is None:
            # Built on first use only, so parse_size is not called when no directive needs it
            clauses = []
            for name, spec in data.items():
                if "*" in spec[0]:
                    size = parse_size(spec[1], preprocessor=preprocessor, dimensions=user_dimensions)
                    if is_cxx(langs.language):
                        clauses.append(add_present_openacc_cxx(name, size))
                    elif is_fortran(langs.language):
                        clauses.append(add_present_openacc_fortran(name, size))
            present_clause = "".join(clauses).rstrip()
        new_body.append(line + present_clause + "\n")
    return "".join(new_body)


def _openacc_directive_name(line: str) -> str:
    """Return the lower-cased word following the ``acc`` sentinel of a directive line, or an empty string"""
    tokens = line.split()
    for position, token in enumerate(tokens[:-1]):
        if token in ("acc", "!$acc"):
            return tokens[position + 1].lower()
    return ""
540+
541+
542+
def add_present_openacc_cxx(name: str, size: int) -> str:
    """Build the present clause of a C++ OpenACC directive for ``size`` elements of ``name``

    The returned text is padded with one space on each side.
    """
    array_section = "{}[:{}]".format(name, size)
    return " present(" + array_section + ") "
545+
546+
547+
def add_present_openacc_fortran(name: str, size: int) -> str:
    """Build the present clause of a Fortran OpenACC directive for ``size`` elements of ``name``

    The returned text is padded with one space on each side.
    """
    array_section = "{}(:{})".format(name, size)
    return " present(" + array_section + ") "

test/utils/test_directives.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,19 @@ def test_line_contains_pragma():
2929
assert not line_contains_openacc_pragma(cxx_code, Fortran())
3030

3131

32+
def test_openacc_pragma_contains_data_clause():
    """A directive with a present clause is flagged, one without data clauses is not"""
    with_data_clause = "#pragma acc parallel present(A[:1089])"
    without_data_clause = "#pragma acc parallel for"
    assert openacc_pragma_contains_data_clause(with_data_clause)
    assert not openacc_pragma_contains_data_clause(without_data_clause)
35+
36+
3237
def test_create_data_directive():
    """The generated enter-data and update directives match the C++ and Fortran syntax"""
    expected_cxx = "#pragma acc enter data create(array[:1024])\n#pragma acc update device(array[:1024])\n"
    assert create_data_directive_openacc("array", 1024, Cxx()) == expected_cxx
    expected_f90 = "!$acc enter data create(matrix(:35))\n!$acc update device(matrix(:35))\n"
    assert create_data_directive_openacc("matrix", 35, Fortran()) == expected_f90
4146

4247

@@ -291,3 +296,16 @@ def test_extract_initialization_code():
291296
code_f90 = "!$tuner initialize\ninteger :: value\n!$tuner stop\n"
292297
assert extract_initialization_code(code_cpp, Code(OpenACC(), Cxx())) == "const int value = 42;\n"
293298
assert extract_initialization_code(code_f90, Code(OpenACC(), Fortran())) == "integer :: value\n"
299+
300+
301+
def test_add_present_openacc():
    """A present clause for the pointer argument is appended to C++ and Fortran parallel directives"""
    acc_cxx = Code(OpenACC(), Cxx())
    acc_f90 = Code(OpenACC(), Fortran())
    data = {"array": ["int*", "size"]}
    preprocessor = ["#define size 42"]
    # (language pair, input directive, directive expected after the clause is added)
    cases = [
        (
            acc_cxx,
            "#pragma acc parallel num_gangs(32)\n",
            "#pragma acc parallel num_gangs(32) present(array[:42])\n",
        ),
        (
            acc_f90,
            "!$acc parallel async num_workers(16)\n",
            "!$acc parallel async num_workers(16) present(array(:42))\n",
        ),
    ]
    for langs, directive, expected in cases:
        assert add_present_openacc(directive, langs, data, preprocessor, None) == expected

0 commit comments

Comments
 (0)