12
12
using namespace clang ::dpct;
13
13
using namespace clang ::ast_matchers;
14
14
15
+ #define STRINGIZE (x ) #x
16
+ #define EXPAND_AND_STRINGIZE (x ) STRINGIZE(x)
17
+
18
+ #define FOR_ALL_STANDARD_RMA_TYPES (PREFIX, POSTFIX ) \
19
+ EXPAND_AND_STRINGIZE (nvshmem##PREFIX##_##float ##_##POSTFIX), \
20
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##double ##_##POSTFIX), \
21
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##char ##_##POSTFIX), \
22
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##schar##_##POSTFIX), \
23
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##short ##_##POSTFIX), \
24
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int ##_##POSTFIX), \
25
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##long ##_##POSTFIX), \
26
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##longlong##_##POSTFIX), \
27
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uchar##_##POSTFIX), \
28
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ushort##_##POSTFIX), \
29
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint##_##POSTFIX), \
30
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ulong##_##POSTFIX), \
31
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ulonglong##_##POSTFIX), \
32
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int8##_##POSTFIX), \
33
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int16##_##POSTFIX), \
34
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int32##_##POSTFIX), \
35
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int64##_##POSTFIX), \
36
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint8##_##POSTFIX), \
37
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint16##_##POSTFIX), \
38
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint32##_##POSTFIX), \
39
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint64##_##POSTFIX), \
40
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##size##_##POSTFIX), \
41
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ptrdiff##_##POSTFIX)
42
+
43
+ #define FOR_ALL_SIZES (PREFIX, OP, POSTFIX ) \
44
+ EXPAND_AND_STRINGIZE (nvshmem##PREFIX##_##OP##8 ##POSTFIX), \
45
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##16 ##POSTFIX), \
46
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##32 ##POSTFIX), \
47
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##64 ##POSTFIX), \
48
+ EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##128 ##POSTFIX)
49
+
15
50
void clang::dpct::NVSHMEMRule::registerMatcher(ast_matchers::MatchFinder &MF) {
16
51
auto NvshmemAPI = [&]() {
17
52
return hasAnyName (
@@ -25,8 +60,88 @@ void clang::dpct::NVSHMEMRule::registerMatcher(ast_matchers::MatchFinder &MF) {
25
60
" nvshmem_team_my_pe" , " nvshmem_team_n_pes" , " nvshmem_team_get_config" ,
26
61
" nvshmem_team_translate_pe" , " nvshmem_team_split_strided" ,
27
62
" nvshmem_team_split_2d" , " nvshmem_team_destroy" ,
63
+ // RMA
64
+ FOR_ALL_STANDARD_RMA_TYPES (, put) /* nvshmem_TYPENAME_put*/ ,
65
+ FOR_ALL_STANDARD_RMA_TYPES (
66
+ x, put_on_stream) /* nvshmemx_TYPENAME_put_on_stream*/ ,
67
+ FOR_ALL_STANDARD_RMA_TYPES (x,
68
+ put_block) /* nvshmemx_TYPENAME_put_block*/ ,
69
+ FOR_ALL_STANDARD_RMA_TYPES (x, put_warp) /* nvshmemx_TYPENAME_put_warp*/ ,
70
+ FOR_ALL_SIZES (, put, ) /* nvshmem_putSIZE*/ ,
71
+ FOR_ALL_SIZES (x, put, _on_stream) /* nvshmemx_putSIZE_on_stream*/ ,
72
+ FOR_ALL_SIZES (x, put, _block) /* nvshmem_putSIZE_block*/ ,
73
+ FOR_ALL_SIZES (x, put, _warp) /* nvshmem_putSIZE_warp*/ ,
74
+ FOR_ALL_STANDARD_RMA_TYPES (, iput) /* nvshmem_TYPENAME_iput*/ ,
75
+ FOR_ALL_STANDARD_RMA_TYPES (
76
+ x, iput_on_stream) /* nvshmemx_TYPENAME_iput_on_stream*/ ,
77
+ FOR_ALL_STANDARD_RMA_TYPES (x,
78
+ iput_block) /* nvshmemx_TYPENAME_iput_block*/ ,
79
+ FOR_ALL_STANDARD_RMA_TYPES (x,
80
+ iput_warp) /* nvshmemx_TYPENAME_iput_warp*/ ,
81
+ FOR_ALL_SIZES (, iput, ) /* nvshmem_iputSIZE*/ ,
82
+ FOR_ALL_SIZES (x, iput, _on_stream) /* nvshmem_iputSIZE_on_stream*/ ,
83
+ FOR_ALL_SIZES (x, iput, _block) /* nvshmem_iputSIZE_block*/ ,
84
+ FOR_ALL_SIZES (x, iput, _warp) /* nvshmem_iputSIZE_warp*/ ,
85
+ " nvshmem_putmem" , " nvshmemx_putmem_on_stream" , " nvshmemx_putmem_block" ,
86
+ " nvshmemx_putmem_warp" ,
87
+ FOR_ALL_STANDARD_RMA_TYPES (, p) /* nvshmem_TYPENAME_p*/ ,
88
+ FOR_ALL_STANDARD_RMA_TYPES (, get) /* nvshmem_TYPENAME_get*/ ,
89
+ FOR_ALL_STANDARD_RMA_TYPES (
90
+ x, get_on_stream) /* nvshmemx_TYPENAME_get_on_stream*/ ,
91
+ FOR_ALL_STANDARD_RMA_TYPES (x,
92
+ get_block) /* nvshmemx_TYPENAME_get_block*/ ,
93
+ FOR_ALL_STANDARD_RMA_TYPES (x, get_warp) /* nvshmemx_TYPENAME_get_warp*/ ,
94
+ FOR_ALL_SIZES (, get, ) /* nvshmem_getSIZE*/ ,
95
+ FOR_ALL_SIZES (x, get, _on_stream) /* nvshmem_getSIZE_on_stream*/ ,
96
+ FOR_ALL_SIZES (x, get, _block) /* nvshmem_getSIZE_block*/ ,
97
+ FOR_ALL_SIZES (x, get, _warp) /* nvshmem_getSIZE_warp*/ ,
98
+ FOR_ALL_STANDARD_RMA_TYPES (, iget) /* nvshmem_TYPENAME_iget*/ ,
99
+ FOR_ALL_STANDARD_RMA_TYPES (
100
+ x, iget_on_stream) /* nvshmemx_TYPENAME_iget_on_stream*/ ,
101
+ FOR_ALL_STANDARD_RMA_TYPES (x,
102
+ iget_block) /* nvshmemx_TYPENAME_iget_block*/ ,
103
+ FOR_ALL_STANDARD_RMA_TYPES (x,
104
+ iget_warp) /* nvshmemx_TYPENAME_iget_warp*/ ,
105
+ FOR_ALL_SIZES (, iget, ) /* nvshmem_igetSIZE*/ ,
106
+ FOR_ALL_SIZES (x, iget, _on_stream) /* nvshmem_igetSIZE_on_stream*/ ,
107
+ FOR_ALL_SIZES (x, iget, _block) /* nvshmem_igetSIZE_block*/ ,
108
+ FOR_ALL_SIZES (x, iget, _warp) /* nvshmem_igetSIZE_warp*/ ,
109
+ " nvshmem_getmem" , " nvshmemx_getmem_on_stream" , " nvshmemx_getmem_block" ,
110
+ " nvshmemx_getmem_warp" ,
111
+ FOR_ALL_STANDARD_RMA_TYPES (, g) /* nvshmem_TYPENAME_g*/ ,
28
112
// Nonblocking RMA
29
- " nvshmem_putmem_nbi" ,
113
+ FOR_ALL_STANDARD_RMA_TYPES (, put_nbi) /* nvshmem_TYPENAME_put_nbi*/ ,
114
+ FOR_ALL_STANDARD_RMA_TYPES (
115
+ x, put_nbi_on_stream) /* nvshmemx_TYPENAME_put_nbi_on_stream*/ ,
116
+ FOR_ALL_STANDARD_RMA_TYPES (
117
+ x, put_nbi_block) /* nvshmemx_TYPENAME_put_nbi_block*/ ,
118
+ FOR_ALL_STANDARD_RMA_TYPES (
119
+ x, put_nbi_warp) /* nvshmemx_TYPENAME_put_nbi_warp*/ ,
120
+ FOR_ALL_SIZES (, put, _nbi) /* nvshmem_putSIZE_nbi*/ ,
121
+ FOR_ALL_SIZES (x, put,
122
+ _nbi_on_stream) /* nvshmemx_putSIZE_nbi_on_stream*/ ,
123
+ FOR_ALL_SIZES (x, put, _nbi_block) /* nvshmem_putSIZE_nbi_block*/ ,
124
+ FOR_ALL_SIZES (x, put, _nbi_warp) /* nvshmem_putSIZE_nbi_warp*/ ,
125
+ " nvshmem_putmem_nbi" , " nvshmemx_putmem_nbi_on_stream" ,
126
+ " nvshmemx_putmem_nbi_block" , " nvshmemx_putmem_nbi_warp" ,
127
+ FOR_ALL_STANDARD_RMA_TYPES (, get_nbi) /* nvshmem_TYPENAME_get_nbi*/ ,
128
+ FOR_ALL_STANDARD_RMA_TYPES (
129
+ x, get_nbi_on_stream) /* nvshmemx_TYPENAME_get_nbi_on_stream*/ ,
130
+ FOR_ALL_STANDARD_RMA_TYPES (
131
+ x, get_nbi_block) /* nvshmemx_TYPENAME_get_nbi_block*/ ,
132
+ FOR_ALL_STANDARD_RMA_TYPES (
133
+ x, get_nbi_warp) /* nvshmemx_TYPENAME_get_nbi_warp*/ ,
134
+ FOR_ALL_SIZES (, get, _nbi) /* nvshmem_getSIZE_nbi*/ ,
135
+ FOR_ALL_SIZES (x, get,
136
+ _nbi_on_stream) /* nvshmemx_getSIZE_nbi_on_stream*/ ,
137
+ FOR_ALL_SIZES (x, get, _nbi_block) /* nvshmem_getSIZE_nbi_block*/ ,
138
+ FOR_ALL_SIZES (x, get, _nbi_warp) /* nvshmem_getSIZE_nbi_warp*/ ,
139
+ " nvshmem_getmem_nbi" , " nvshmemx_getmem_nbi_on_stream" ,
140
+ " nvshmemx_getmem_nbi_block" , " nvshmemx_getmem_nbi_warp" ,
141
+ // Memory Ordering
142
+ " nvshmem_fence" , " nvshmem_quiet" , " nvshmemx_quiet_on_stream" ,
143
+ // Collective Operations
144
+ " nvshmemx_barrier_all_on_stream" ,
30
145
// Signalling Operations
31
146
" nvshmemx_signal_op" , " nvshmem_signal_wait_until" ,
32
147
" nvshmem_putmem_signal_nbi" );
@@ -112,3 +227,8 @@ void clang::dpct::NVSHMEMRule::runRule(
112
227
emplaceTransformation (EA.getReplacement ());
113
228
EA.applyAllSubExprRepl ();
114
229
}
230
+
231
+ #undef STRINGIZE
232
+ #undef EXPAND_AND_STRINGIZE
233
+ #undef FOR_ALL_STANDARD_RMA_TYPES
234
+ #undef FOR_ALL_SIZES
0 commit comments