Skip to content

Commit 8360487

Browse files
committed
Only add present to parallel directive.
1 parent c82f27f commit 8360487

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

kernel_tuner/utils/directives.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,23 +63,47 @@ def is_fortran(lang: Language) -> bool:
6363
return isinstance(lang, Fortran)
6464

6565

66-
def line_contains_openacc_pragma(line: str, lang: Language) -> bool:
67-
"""Check if line contains OpenACC pragma or not"""
66+
def line_contains_openacc_directive(line: str, lang: Language) -> bool:
67+
"""Check if line contains an OpenACC pragma or not"""
6868
if is_cxx(lang):
69-
return line_contains_openacc_pragma_cxx(line)
69+
return line_contains_openacc_directive_cxx(line)
7070
elif is_fortran(lang):
71-
return line_contains_openacc_pragma_fortran(line)
71+
return line_contains_openacc_directive_fortran(line)
7272
return False
7373

7474

75-
def line_contains_openacc_pragma_cxx(line: str) -> bool:
76-
"""Check if a line of code contains a C++ OpenACC pragma or not"""
77-
return "#pragma acc" in line
75+
def line_contains_openacc_directive_cxx(line: str) -> bool:
76+
"""Check if a line of code contains a C++ OpenACC directive or not"""
77+
return line_contains(line, "#pragma acc")
7878

7979

80-
def line_contains_openacc_pragma_fortran(line: str) -> bool:
81-
"""Check if a line of code contains a Fortran OpenACC pragma or not"""
82-
return "!$acc" in line
80+
def line_contains_openacc_directive_fortran(line: str) -> bool:
81+
"""Check if a line of code contains a Fortran OpenACC directive or not"""
82+
return line_contains(line, "!$acc")
83+
84+
85+
def line_contains_openacc_parallel_directive(line: str, lang: Language) -> bool:
86+
"""Check if line contains an OpenACC parallel directive or not"""
87+
if is_cxx(lang):
88+
return line_contains_openacc_parallel_directive_cxx(line)
89+
elif is_fortran(lang):
90+
return line_contains_openacc_parallel_directive_fortran(line)
91+
return False
92+
93+
94+
def line_contains_openacc_parallel_directive_cxx(line: str) -> bool:
95+
"""Check if a line of code contains a C++ OpenACC parallel directive or not"""
96+
return line_contains(line, "#pragma acc parallel")
97+
98+
99+
def line_contains_openacc_parallel_directive_fortran(line: str) -> bool:
100+
"""Check if a line of code contains a Fortran OpenACC parallel directive or not"""
101+
return line_contains(line, "!$acc parallel")
102+
103+
104+
def line_contains(line: str, target: str) -> bool:
105+
"""Generic helper to check if a line contains the target"""
106+
return target in line
83107

84108

85109
def openacc_pragma_contains_clause(line: str, clauses: list) -> bool:
@@ -518,7 +542,7 @@ def add_present_openacc(
518542
"""Add the present clause to OpenACC directive"""
519543
new_body = ""
520544
for line in code.replace("\\\n", "").split("\n"):
521-
if not line_contains_openacc_pragma(line, langs.language):
545+
if not line_contains_openacc_parallel_directive(line, langs.language):
522546
new_body += line
523547
else:
524548
# The line contains an OpenACC directive

test/utils/test_directives.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,20 @@ def test_is_fortran():
2020
assert not is_fortran(None)
2121

2222

23-
def test_line_contains_pragma():
23+
def test_line_contains_openacc_directive():
2424
cxx_code = "int main(void) {\n#pragma acc parallel}"
2525
f90_code = "!$acc parallel"
26-
assert line_contains_openacc_pragma(cxx_code, Cxx())
27-
assert not line_contains_openacc_pragma(f90_code, Cxx())
28-
assert line_contains_openacc_pragma(f90_code, Fortran())
29-
assert not line_contains_openacc_pragma(cxx_code, Fortran())
26+
assert line_contains_openacc_directive(cxx_code, Cxx())
27+
assert not line_contains_openacc_directive(f90_code, Cxx())
28+
assert line_contains_openacc_directive(f90_code, Fortran())
29+
assert not line_contains_openacc_directive(cxx_code, Fortran())
30+
31+
32+
def test_line_contains_openacc_parallel_directive():
33+
assert line_contains_openacc_parallel_directive("#pragma acc parallel wait", Cxx())
34+
assert line_contains_openacc_parallel_directive("!$acc parallel", Fortran())
35+
assert not line_contains_openacc_parallel_directive("#pragma acc for", Cxx())
36+
assert not line_contains_openacc_parallel_directive("!$acc for", Fortran())
3037

3138

3239
def test_openacc_pragma_contains_data_clause():

0 commit comments

Comments
 (0)