@@ -48,6 +48,48 @@ def __init__(self, directive: Directive, lang: Language):
48
48
self .language = lang
49
49
50
50
51
+ class ArraySize (object ):
52
+ """Size of an array"""
53
+
54
+ def __init__ (self ):
55
+ self .size = list ()
56
+
57
+ def __iter__ (self ):
58
+ for i in self .size :
59
+ yield i
60
+
61
+ def __len__ (self ):
62
+ return len (self .size )
63
+
64
+ def clear (self ):
65
+ self .size .clear ()
66
+
67
+ def get (self ) -> int :
68
+ length = len (self .size )
69
+ if length == 0 :
70
+ return 0
71
+ elif length == 1 :
72
+ return self .size [0 ]
73
+ else :
74
+ product = 1
75
+ for i in self .size :
76
+ product *= i
77
+ return product
78
+
79
+ def add (self , dim : int ) -> None :
80
+ # Only allow adding valid dimensions
81
+ if dim >= 1 :
82
+ self .size .append (dim )
83
+
84
+
85
+ def fortran_md_size (size : ArraySize ) -> list :
86
+ """Format a multidimensional size into the correct Fortran string"""
87
+ md_size = list ()
88
+ for dim in size :
89
+ md_size .append (f":{ dim } " )
90
+ return md_size
91
+
92
+
51
93
def is_openacc (directive : Directive ) -> bool :
52
94
"""Check if a directive is OpenACC"""
53
95
return isinstance (directive , OpenACC )
@@ -120,7 +162,7 @@ def openacc_directive_contains_data_clause(line: str) -> bool:
120
162
return openacc_directive_contains_clause (line , data_clauses )
121
163
122
164
123
- def create_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
165
+ def create_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
124
166
"""Create a data directive for a given language"""
125
167
if is_cxx (lang ):
126
168
return create_data_directive_openacc_cxx (name , size )
@@ -129,17 +171,23 @@ def create_data_directive_openacc(name: str, size: int, lang: Language) -> str:
129
171
return ""
130
172
131
173
132
- def create_data_directive_openacc_cxx (name : str , size : int ) -> str :
174
+ def create_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
133
175
"""Create C++ OpenACC code to allocate and copy data"""
134
- return f"#pragma acc enter data create({ name } [:{ size } ])\n #pragma acc update device({ name } [:{ size } ])\n "
176
+ return f"#pragma acc enter data create({ name } [:{ size . get () } ])\n #pragma acc update device({ name } [:{ size . get () } ])\n "
135
177
136
178
137
- def create_data_directive_openacc_fortran (name : str , size : int ) -> str :
179
+ def create_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
138
180
"""Create Fortran OpenACC code to allocate and copy data"""
139
- return f"!$acc enter data create({ name } (:{ size } ))\n !$acc update device({ name } (:{ size } ))\n "
181
+ if len (size ) == 1 :
182
+ return f"!$acc enter data create({ name } (:{ size .get ()} ))\n !$acc update device({ name } (:{ size .get ()} ))\n "
183
+ else :
184
+ md_size = fortran_md_size (size )
185
+ return (
186
+ f"!$acc enter data create({ name } ({ ',' .join (md_size )} ))\n !$acc update device({ name } ({ ',' .join (md_size )} ))\n "
187
+ )
140
188
141
189
142
- def exit_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
190
+ def exit_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
143
191
"""Create code to copy data back for a given language"""
144
192
if is_cxx (lang ):
145
193
return exit_data_directive_openacc_cxx (name , size )
@@ -148,14 +196,18 @@ def exit_data_directive_openacc(name: str, size: int, lang: Language) -> str:
148
196
return ""
149
197
150
198
151
- def exit_data_directive_openacc_cxx (name : str , size : int ) -> str :
199
+ def exit_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
152
200
"""Create C++ OpenACC code to copy back data"""
153
- return f"#pragma acc exit data copyout({ name } [:{ size } ])\n "
201
+ return f"#pragma acc exit data copyout({ name } [:{ size . get () } ])\n "
154
202
155
203
156
- def exit_data_directive_openacc_fortran (name : str , size : int ) -> str :
204
+ def exit_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
157
205
"""Create Fortran OpenACC code to copy back data"""
158
- return f"!$acc exit data copyout({ name } (:{ size } ))\n "
206
+ if len (size ) == 1 :
207
+ return f"!$acc exit data copyout({ name } (:{ size .get ()} ))\n "
208
+ else :
209
+ md_size = fortran_md_size (size )
210
+ return f"!$acc exit data copyout({ name } ({ ',' .join (md_size )} ))\n "
159
211
160
212
161
213
def correct_kernel (kernel_name : str , line : str ) -> bool :
@@ -165,7 +217,7 @@ def correct_kernel(kernel_name: str, line: str) -> bool:
165
217
166
218
def find_size_in_preprocessor (dimension : str , preprocessor : list ) -> int :
167
219
"""Find the dimension of a directive defined value in the preprocessor"""
168
- ret_size = None
220
+ ret_size = 0
169
221
for line in preprocessor :
170
222
if f"#define { dimension } " in line :
171
223
try :
@@ -209,45 +261,43 @@ def extract_code(start: str, stop: str, code: str, langs: Code, kernel_name: str
209
261
return sections
210
262
211
263
212
- def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> int :
264
+ def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> ArraySize :
213
265
"""Converts an arbitrary object into an integer representing memory size"""
214
- ret_size = None
266
+ ret_size = ArraySize ()
215
267
if type (size ) is not int :
216
268
try :
217
269
# Try to convert the size to an integer
218
- ret_size = int (size )
270
+ ret_size . add ( int (size ) )
219
271
except ValueError :
220
272
# If size cannot be natively converted to an int, we try to derive it from the preprocessor
221
- if preprocessor is not None :
222
- try :
273
+ try :
274
+ if preprocessor is not None :
223
275
if "," in size :
224
- ret_size = 1
225
276
for dimension in size .split ("," ):
226
- ret_size *= find_size_in_preprocessor (dimension , preprocessor )
277
+ ret_size . add ( find_size_in_preprocessor (dimension , preprocessor ) )
227
278
else :
228
- ret_size = find_size_in_preprocessor (size , preprocessor )
229
- except TypeError :
230
- # preprocessor is available but does not contain the dimensions
231
- pass
279
+ ret_size . add ( find_size_in_preprocessor (size , preprocessor ) )
280
+ except TypeError :
281
+ # At least one of the dimension cannot be derived from the preprocessor
282
+ pass
232
283
# If size cannot be natively converted, nor retrieved from the preprocessor, we check user provided values
233
284
if dimensions is not None :
234
285
if size in dimensions .keys ():
235
286
try :
236
- ret_size = int (dimensions [size ])
287
+ ret_size . add ( int (dimensions [size ]) )
237
288
except ValueError :
238
289
# User error, no mitigation
239
290
return ret_size
240
291
elif "," in size :
241
- ret_size = 1
242
292
for dimension in size .split ("," ):
243
293
try :
244
- ret_size *= int (dimensions [dimension ])
294
+ ret_size . add ( int (dimensions [dimension ]) )
245
295
except ValueError :
246
296
# User error, no mitigation
247
- return None
297
+ return ret_size
248
298
else :
249
299
# size is already an int. no need for conversion
250
- ret_size = size
300
+ ret_size . add ( size )
251
301
252
302
return ret_size
253
303
@@ -297,8 +347,13 @@ def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, use
297
347
intro += create_data_directive_openacc_cxx (name , size )
298
348
outro += exit_data_directive_openacc_cxx (name , size )
299
349
elif is_openacc (langs .directive ) and is_fortran (langs .language ):
300
- intro += create_data_directive_openacc_fortran (name , size )
301
- outro += exit_data_directive_openacc_fortran (name , size )
350
+ if "," in data [name ][1 ]:
351
+ # Multi dimensional
352
+ pass
353
+ else :
354
+ # One dimensional
355
+ intro += create_data_directive_openacc_fortran (name , size )
356
+ outro += exit_data_directive_openacc_fortran (name , size )
302
357
return intro + code + outro
303
358
304
359
@@ -537,9 +592,9 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
537
592
p_type = data [parameter ][0 ]
538
593
size = parse_size (data [parameter ][1 ], preprocessor , user_dimensions )
539
594
if "*" in p_type :
540
- args .append (allocate_array (p_type , size ))
595
+ args .append (allocate_array (p_type , size . get () ))
541
596
else :
542
- args .append (allocate_scalar (p_type , size ))
597
+ args .append (allocate_scalar (p_type , size . get () ))
543
598
544
599
return args
545
600
@@ -579,11 +634,15 @@ def add_present_openacc(
579
634
return new_body
580
635
581
636
582
- def add_present_openacc_cxx (name : str , size : int ) -> str :
637
+ def add_present_openacc_cxx (name : str , size : ArraySize ) -> str :
583
638
"""Create present clause for C++ OpenACC directive"""
584
- return f" present({ name } [:{ size } ]) "
639
+ return f" present({ name } [:{ size . get () } ]) "
585
640
586
641
587
- def add_present_openacc_fortran (name : str , size : int ) -> str :
642
+ def add_present_openacc_fortran (name : str , size : ArraySize ) -> str :
588
643
"""Create present clause for Fortran OpenACC directive"""
589
- return f" present({ name } (:{ size } )) "
644
+ if len (size ) == 1 :
645
+ return f" present({ name } (:{ size .get ()} )) "
646
+ else :
647
+ md_size = fortran_md_size (size )
648
+ return f" present({ name } ({ ',' .join (md_size )} )) "
0 commit comments