1
- from typing import Any
1
+ from typing import Any , Tuple
2
2
from abc import ABC , abstractmethod
3
3
import numpy as np
4
4
5
+ # Function templates
6
+ cpp_template : str = """
7
+ <!?PREPROCESSOR?!>
8
+ <!?USER_DEFINES?!>
9
+ #include <chrono>
10
+
11
+ extern "C" <!?SIGNATURE?!> {
12
+ <!?INITIALIZATION?!>
13
+ <!?BODY?!>
14
+ <!?DEINITIALIZATION?!>
15
+ }
16
+ """
17
+
18
+ f90_template : str = """
19
+ <!?PREPROCESSOR?!>
20
+ <!?USER_DEFINES?!>
21
+
22
+ module kt
23
+ use iso_c_binding
24
+ contains
25
+
26
+ <!?SIGNATURE?!>
27
+ <!?INITIALIZATION?!>
28
+ <!?BODY?!>
29
+ <!?DEINITIALIZATION?!>
30
+ end function <!?NAME?!>
31
+
32
+ end module kt
33
+ """
34
+
5
35
6
36
class Directive (ABC ):
7
37
"""Base class for all directives"""
@@ -339,7 +369,7 @@ def wrap_timing_fortran(code: str) -> str:
339
369
340
370
def end_timing_cxx (code : str ) -> str :
341
371
"""In C++ we need to return the measured time"""
342
- return code + " \n return elapsed_time.count();\n "
372
+ return " \n " . join ([ code , "return elapsed_time.count();\n "])
343
373
344
374
345
375
def wrap_data (code : str , langs : Code , data : dict , preprocessor : list = None , user_dimensions : dict = None ) -> str :
@@ -355,7 +385,7 @@ def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, use
355
385
elif is_openacc (langs .directive ) and is_fortran (langs .language ):
356
386
intro += create_data_directive_openacc_fortran (name , size )
357
387
outro += exit_data_directive_openacc_fortran (name , size )
358
- return intro + code + outro
388
+ return " \n " . join ([ intro , code , outro ])
359
389
360
390
361
391
def extract_directive_code (code : str , langs : Code , kernel_name : str = None ) -> dict :
@@ -529,42 +559,34 @@ def generate_directive_function(
529
559
) -> str :
530
560
"""Generate tunable function for one directive"""
531
561
532
- code = "\n " .join (preprocessor ) + "\n "
533
- if user_dimensions is not None :
534
- # add user dimensions to preprocessor
535
- for key , value in user_dimensions .items ():
536
- code += f"#define { key } { value } \n "
537
- if is_cxx (langs .language ) and "#include <chrono>" not in preprocessor :
538
- code += "\n #include <chrono>\n "
539
- if is_cxx (langs .language ):
540
- code += 'extern "C" ' + signature + "{\n "
541
- elif is_fortran (langs .language ):
542
- code += "\n module kt\n use iso_c_binding\n contains\n "
543
- code += "\n " + signature
544
- if len (initialization ) > 1 :
545
- code += initialization + "\n "
546
- if data is not None :
547
- body = add_present_openacc (body , langs , data , preprocessor , user_dimensions )
548
562
if is_cxx (langs .language ):
563
+ code = cpp_template
549
564
body = start_timing_cxx (body )
550
565
if data is not None :
551
- code += wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
552
- else :
553
- code += body
554
- code = end_timing_cxx (code )
555
- if len (deinitialization ) > 1 :
556
- code += deinitialization + "\n "
557
- code += "\n }"
566
+ body = wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
567
+ body = end_timing_cxx (body )
558
568
elif is_fortran (langs .language ):
569
+ code = f90_template
559
570
body = wrap_timing (body , langs .language )
560
571
if data is not None :
561
- code += wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
562
- else :
563
- code += body + "\n "
564
- if len (deinitialization ) > 1 :
565
- code += deinitialization + "\n "
572
+ body = wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
566
573
name = signature .split (" " )[1 ].split ("(" )[0 ]
567
- code += f"\n end function { name } \n end module kt\n "
574
+ code = code .replace ("<!?NAME?!>" , name )
575
+ code = code .replace ("<!?PREPROCESSOR?!>" , "\n " .join (preprocessor ))
576
+ # if present, add user specific dimensions as defines
577
+ if user_dimensions is not None :
578
+ user_defines = ""
579
+ for key , value in user_dimensions .items ():
580
+ user_defines += f"#define { key } { value } \n "
581
+ code = code .replace ("<!?USER_DEFINES?!>" , user_defines )
582
+ else :
583
+ code = code .replace ("<!?USER_DEFINES?!>" , "" )
584
+ code = code .replace ("<!?SIGNATURE?!>" , signature )
585
+ code = code .replace ("<!?INITIALIZATION?!>" , initialization )
586
+ code = code .replace ("<!?DEINITIALIZATION?!>" , deinitialization )
587
+ if data is not None :
588
+ body = add_present_openacc (body , langs , data , preprocessor , user_dimensions )
589
+ code = code .replace ("<!?BODY?!>" , body )
568
590
569
591
return code
570
592
@@ -662,3 +684,21 @@ def add_present_openacc_fortran(name: str, size: ArraySize) -> str:
662
684
else :
663
685
md_size = fortran_md_size (size )
664
686
return f" present({ name } ({ ',' .join (md_size )} )) "
687
+
688
+
689
+ def process_directives (langs : Code , source : str , user_dimensions : dict = None ) -> Tuple [dict , dict ]:
690
+ """Helper functions to process all the directives in the code and create tunable functions"""
691
+ kernel_strings = dict ()
692
+ kernel_args = dict ()
693
+ preprocessor = extract_preprocessor (source )
694
+ signatures = extract_directive_signature (source , langs )
695
+ bodies = extract_directive_code (source , langs )
696
+ data = extract_directive_data (source , langs )
697
+ init = extract_initialization_code (source , langs )
698
+ deinit = extract_deinitialization_code (source , langs )
699
+ for kernel in signatures .keys ():
700
+ kernel_strings [kernel ] = generate_directive_function (
701
+ preprocessor , signatures [kernel ], bodies [kernel ], langs , data [kernel ], init , deinit , user_dimensions
702
+ )
703
+ kernel_args [kernel ] = allocate_signature_memory (data [kernel ], preprocessor , user_dimensions )
704
+ return (kernel_strings , kernel_args )
0 commit comments