@@ -82,6 +82,20 @@ def line_contains_openacc_pragma_fortran(line: str) -> bool:
82
82
return "!$acc" in line
83
83
84
84
85
+ def openacc_pragma_contains_clause (line : str , clauses : list ) -> bool :
86
+ """Check if an OpenACC directive contains one clause from a list"""
87
+ for clause in clauses :
88
+ if clause in line :
89
+ return True
90
+ return False
91
+
92
+
93
+ def openacc_pragma_contains_data_clause (line : str ) -> bool :
94
+ """Check if an OpenACC directive contains one data clause"""
95
+ data_clauses = ["copy" , "copyin" , "copyout" , "create" , "no_create" , "present" , "device_ptr" , "attach" ]
96
+ return openacc_pragma_contains_clause (line , data_clauses )
97
+
98
+
85
99
def create_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
86
100
"""Create a data directive for a given language"""
87
101
if is_cxx (lang ):
@@ -244,23 +258,18 @@ def end_timing_cxx(code: str) -> str:
244
258
return code + "\n return elapsed_time.count();\n "
245
259
246
260
247
- def wrap_data (code : str , langs : Code , data : dict , preprocessor : list , user_dimensions : dict ) -> str :
261
+ def wrap_data (code : str , langs : Code , data : dict , preprocessor : list = None , user_dimensions : dict = None ) -> str :
248
262
"""Insert data directives before and after the timed code"""
249
263
intro = str ()
250
- for name in data .keys ():
251
- if "*" in data [name ][0 ]:
252
- size = parse_size (data [name ][1 ], preprocessor = preprocessor , dimensions = user_dimensions )
253
- if is_openacc (langs .directive ) and is_cxx (langs .language ):
254
- intro += create_data_directive_openacc_cxx (name , size )
255
- elif is_openacc (langs .directive ) and is_fortran (langs .language ):
256
- intro += create_data_directive_openacc_fortran (name , size )
257
264
outro = str ()
258
265
for name in data .keys ():
259
266
if "*" in data [name ][0 ]:
260
267
size = parse_size (data [name ][1 ], preprocessor = preprocessor , dimensions = user_dimensions )
261
268
if is_openacc (langs .directive ) and is_cxx (langs .language ):
269
+ intro += create_data_directive_openacc_cxx (name , size )
262
270
outro += exit_data_directive_openacc_cxx (name , size )
263
271
elif is_openacc (langs .directive ) and is_fortran (langs .language ):
272
+ intro += create_data_directive_openacc_fortran (name , size )
264
273
outro += exit_data_directive_openacc_fortran (name , size )
265
274
return intro + code + outro
266
275
@@ -439,6 +448,8 @@ def generate_directive_function(
439
448
code += "\n " + signature
440
449
if len (initialization ) > 1 :
441
450
code += initialization + "\n "
451
+ if data is not None :
452
+ body = add_present_openacc (body , langs , data , preprocessor , user_dimensions )
442
453
if is_cxx (langs .language ):
443
454
body = start_timing_cxx (body )
444
455
if data is not None :
@@ -499,3 +510,40 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
499
510
args .append (allocate_scalar (p_type , size ))
500
511
501
512
return args
513
+
514
+
515
+ def add_present_openacc (
516
+ code : str , langs : Code , data : dict , preprocessor : list = None , user_dimensions : dict = None
517
+ ) -> str :
518
+ """Add the present clause to OpenACC directive"""
519
+ new_body = ""
520
+ for line in code .replace ("\\ \n " , "" ).split ("\n " ):
521
+ if not line_contains_openacc_pragma (line , langs .language ):
522
+ new_body += line
523
+ else :
524
+ # The line contains an OpenACC directive
525
+ if openacc_pragma_contains_data_clause (line ):
526
+ # The OpenACC directive manages memory, do not interfere
527
+ return code
528
+ else :
529
+ new_line = line .replace ("\n " , "" )
530
+ present_clause = ""
531
+ for name in data .keys ():
532
+ if "*" in data [name ][0 ]:
533
+ size = parse_size (data [name ][1 ], preprocessor = preprocessor , dimensions = user_dimensions )
534
+ if is_cxx (langs .language ):
535
+ present_clause += add_present_openacc_cxx (name , size )
536
+ elif is_fortran (langs .language ):
537
+ present_clause += add_present_openacc_fortran (name , size )
538
+ new_body += new_line + present_clause .rstrip () + "\n "
539
+ return new_body
540
+
541
+
542
+ def add_present_openacc_cxx (name : str , size : int ) -> str :
543
+ """Create present clause for C++ OpenACC directive"""
544
+ return f" present({ name } [:{ size } ]) "
545
+
546
+
547
+ def add_present_openacc_fortran (name : str , size : int ) -> str :
548
+ """Create present clause for Fortran OpenACC directive"""
549
+ return f" present({ name } (:{ size } )) "
0 commit comments