@@ -48,6 +48,47 @@ def __init__(self, directive: Directive, lang: Language):
48
48
self .language = lang
49
49
50
50
51
+ class ArraySize :
52
+ """Size of an array"""
53
+
54
+ def __init__ (self ):
55
+ self .size = list ()
56
+
57
+ def __iter__ (self ):
58
+ for i in self .size :
59
+ yield i
60
+
61
+ def __len__ (self ):
62
+ return len (self .size )
63
+
64
+ def clear (self ):
65
+ self .size .clear ()
66
+
67
+ def get (self ) -> int :
68
+ length = len (self .size )
69
+ if length == 0 :
70
+ return 0
71
+ elif length == 1 :
72
+ return self .size [0 ]
73
+ else :
74
+ product = 1
75
+ for i in self .size :
76
+ product *= i
77
+ return product
78
+
79
+ def add (self , dim : int ) -> None :
80
+ # Only allow adding valid dimensions
81
+ if dim >= 1 :
82
+ self .size .append (dim )
83
+
84
+
85
+ def fortran_md_size (size : ArraySize ) -> list :
86
+ md_size = list ()
87
+ for dim in size :
88
+ md_size .append (f":{ dim } " )
89
+ return md_size
90
+
91
+
51
92
def is_openacc (directive : Directive ) -> bool :
52
93
"""Check if a directive is OpenACC"""
53
94
return isinstance (directive , OpenACC )
@@ -120,7 +161,7 @@ def openacc_directive_contains_data_clause(line: str) -> bool:
120
161
return openacc_directive_contains_clause (line , data_clauses )
121
162
122
163
123
- def create_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
164
+ def create_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
124
165
"""Create a data directive for a given language"""
125
166
if is_cxx (lang ):
126
167
return create_data_directive_openacc_cxx (name , size )
@@ -129,17 +170,23 @@ def create_data_directive_openacc(name: str, size: int, lang: Language) -> str:
129
170
return ""
130
171
131
172
132
- def create_data_directive_openacc_cxx (name : str , size : int ) -> str :
173
+ def create_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
133
174
"""Create C++ OpenACC code to allocate and copy data"""
134
- return f"#pragma acc enter data create({ name } [:{ size } ])\n #pragma acc update device({ name } [:{ size } ])\n "
175
+ return f"#pragma acc enter data create({ name } [:{ size . get () } ])\n #pragma acc update device({ name } [:{ size . get () } ])\n "
135
176
136
177
137
- def create_data_directive_openacc_fortran (name : str , size : int ) -> str :
178
+ def create_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
138
179
"""Create Fortran OpenACC code to allocate and copy data"""
139
- return f"!$acc enter data create({ name } (:{ size } ))\n !$acc update device({ name } (:{ size } ))\n "
180
+ if len (size ) == 1 :
181
+ return f"!$acc enter data create({ name } (:{ size .get ()} ))\n !$acc update device({ name } (:{ size .get ()} ))\n "
182
+ else :
183
+ md_size = fortran_md_size (size )
184
+ return (
185
+ f"!$acc enter data create({ name } ({ ',' .join (md_size )} ))\n !$acc update device({ name } ({ ',' .join (md_size )} ))\n "
186
+ )
140
187
141
188
142
- def exit_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
189
+ def exit_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
143
190
"""Create code to copy data back for a given language"""
144
191
if is_cxx (lang ):
145
192
return exit_data_directive_openacc_cxx (name , size )
@@ -148,14 +195,18 @@ def exit_data_directive_openacc(name: str, size: int, lang: Language) -> str:
148
195
return ""
149
196
150
197
151
- def exit_data_directive_openacc_cxx (name : str , size : int ) -> str :
198
+ def exit_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
152
199
"""Create C++ OpenACC code to copy back data"""
153
- return f"#pragma acc exit data copyout({ name } [:{ size } ])\n "
200
+ return f"#pragma acc exit data copyout({ name } [:{ size . get () } ])\n "
154
201
155
202
156
- def exit_data_directive_openacc_fortran (name : str , size : int ) -> str :
203
+ def exit_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
157
204
"""Create Fortran OpenACC code to copy back data"""
158
- return f"!$acc exit data copyout({ name } (:{ size } ))\n "
205
+ if len (size ) == 1 :
206
+ return f"!$acc exit data copyout({ name } (:{ size .get ()} ))\n "
207
+ else :
208
+ md_size = fortran_md_size (size )
209
+ return f"!$acc exit data copyout({ name } ({ ',' .join (md_size )} ))\n "
159
210
160
211
161
212
def correct_kernel (kernel_name : str , line : str ) -> bool :
@@ -165,7 +216,7 @@ def correct_kernel(kernel_name: str, line: str) -> bool:
165
216
166
217
def find_size_in_preprocessor (dimension : str , preprocessor : list ) -> int :
167
218
"""Find the dimension of a directive defined value in the preprocessor"""
168
- ret_size = None
219
+ ret_size = 0
169
220
for line in preprocessor :
170
221
if f"#define { dimension } " in line :
171
222
try :
@@ -209,45 +260,43 @@ def extract_code(start: str, stop: str, code: str, langs: Code, kernel_name: str
209
260
return sections
210
261
211
262
212
- def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> int :
263
+ def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> ArraySize :
213
264
"""Converts an arbitrary object into an integer representing memory size"""
214
- ret_size = None
265
+ ret_size = ArraySize ()
215
266
if type (size ) is not int :
216
267
try :
217
268
# Try to convert the size to an integer
218
- ret_size = int (size )
269
+ ret_size . add ( int (size ) )
219
270
except ValueError :
220
271
# If size cannot be natively converted to an int, we try to derive it from the preprocessor
221
- if preprocessor is not None :
222
- try :
272
+ try :
273
+ if preprocessor is not None :
223
274
if "," in size :
224
- ret_size = 1
225
275
for dimension in size .split ("," ):
226
- ret_size *= find_size_in_preprocessor (dimension , preprocessor )
276
+ ret_size . add ( find_size_in_preprocessor (dimension , preprocessor ) )
227
277
else :
228
- ret_size = find_size_in_preprocessor (size , preprocessor )
229
- except TypeError :
230
- # preprocessor is available but does not contain the dimensions
231
- pass
278
+ ret_size . add ( find_size_in_preprocessor (size , preprocessor ) )
279
+ except TypeError :
280
+ # At least one of the dimension cannot be derived from the preprocessor
281
+ pass
232
282
# If size cannot be natively converted, nor retrieved from the preprocessor, we check user provided values
233
283
if dimensions is not None :
234
284
if size in dimensions .keys ():
235
285
try :
236
- ret_size = int (dimensions [size ])
286
+ ret_size . add ( int (dimensions [size ]) )
237
287
except ValueError :
238
288
# User error, no mitigation
239
289
return ret_size
240
290
elif "," in size :
241
- ret_size = 1
242
291
for dimension in size .split ("," ):
243
292
try :
244
- ret_size *= int (dimensions [dimension ])
293
+ ret_size . add ( int (dimensions [dimension ]) )
245
294
except ValueError :
246
295
# User error, no mitigation
247
- return None
296
+ return ret_size
248
297
else :
249
298
# size is already an int. no need for conversion
250
- ret_size = size
299
+ ret_size . add ( size )
251
300
252
301
return ret_size
253
302
@@ -297,8 +346,13 @@ def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, use
297
346
intro += create_data_directive_openacc_cxx (name , size )
298
347
outro += exit_data_directive_openacc_cxx (name , size )
299
348
elif is_openacc (langs .directive ) and is_fortran (langs .language ):
300
- intro += create_data_directive_openacc_fortran (name , size )
301
- outro += exit_data_directive_openacc_fortran (name , size )
349
+ if "," in data [name ][1 ]:
350
+ # Multi dimensional
351
+ pass
352
+ else :
353
+ # One dimensional
354
+ intro += create_data_directive_openacc_fortran (name , size )
355
+ outro += exit_data_directive_openacc_fortran (name , size )
302
356
return intro + code + outro
303
357
304
358
@@ -537,9 +591,9 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
537
591
p_type = data [parameter ][0 ]
538
592
size = parse_size (data [parameter ][1 ], preprocessor , user_dimensions )
539
593
if "*" in p_type :
540
- args .append (allocate_array (p_type , size ))
594
+ args .append (allocate_array (p_type , size . get () ))
541
595
else :
542
- args .append (allocate_scalar (p_type , size ))
596
+ args .append (allocate_scalar (p_type , size . get () ))
543
597
544
598
return args
545
599
@@ -579,11 +633,15 @@ def add_present_openacc(
579
633
return new_body
580
634
581
635
582
- def add_present_openacc_cxx (name : str , size : int ) -> str :
636
+ def add_present_openacc_cxx (name : str , size : ArraySize ) -> str :
583
637
"""Create present clause for C++ OpenACC directive"""
584
- return f" present({ name } [:{ size } ]) "
638
+ return f" present({ name } [:{ size . get () } ]) "
585
639
586
640
587
- def add_present_openacc_fortran (name : str , size : int ) -> str :
641
+ def add_present_openacc_fortran (name : str , size : ArraySize ) -> str :
588
642
"""Create present clause for Fortran OpenACC directive"""
589
- return f" present({ name } (:{ size } )) "
643
+ if len (size ) == 1 :
644
+ return f" present({ name } (:{ size .get ()} )) "
645
+ else :
646
+ md_size = fortran_md_size (size )
647
+ return f" present({ name } ({ ',' .join (md_size )} )) "
0 commit comments