@@ -162,41 +162,167 @@ RewriterMap dpct::createUtilityFunctionsRewriterMap() {
162
162
MEMBER_CALL_FACTORY_ENTRY (" cub::RowMajorTid" , NDITEM, /* IsArrow=*/ false ,
163
163
" get_local_linear_id" )
164
164
// cub::LoadDirectBlocked
165
- HEADER_INSERT_FACTORY (
166
- HeaderType::HT_DPCT_GROUP_Utils,
167
- CALL_FACTORY_ENTRY (
168
- " cub::LoadDirectBlocked" ,
169
- CALL (PRETTY_TEMPLATED_CALLEE (MapNames::getDpctNamespace () +
170
- " group::load_direct_blocked" ,
171
- 0 , 1 , 2 ),
172
- NDITEM, ARG (1 ), ARG (2 ))))
165
+ CASE_FACTORY_ENTRY (
166
+ CASE (CheckArgCount (3 ),
167
+ HEADER_INSERT_FACTORY (
168
+ HeaderType::HT_DPCT_GROUP_Utils,
169
+ CALL_FACTORY_ENTRY (" cub::LoadDirectBlocked" ,
170
+ CALL (PRETTY_TEMPLATED_CALLEE (
171
+ MapNames::getDpctNamespace () +
172
+ " group::load_direct_blocked" ,
173
+ 0 , 1 , 2 ),
174
+ NDITEM, ARG (1 ), ARG (2 ))))),
175
+ CASE (CheckArgCount (4 ),
176
+ HEADER_INSERT_FACTORY (
177
+ HeaderType::HT_DPCT_GROUP_Utils,
178
+ CALL_FACTORY_ENTRY (" cub::LoadDirectBlocked" ,
179
+ CALL (PRETTY_TEMPLATED_CALLEE (
180
+ MapNames::getDpctNamespace () +
181
+ " group::load_direct_blocked" ,
182
+ 0 , 1 , 2 ),
183
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ))))),
184
+ CASE (CheckArgCount (5 ),
185
+ HEADER_INSERT_FACTORY (
186
+ HeaderType::HT_DPCT_GROUP_Utils,
187
+ CALL_FACTORY_ENTRY (" cub::LoadDirectBlocked" ,
188
+ CALL (PRETTY_TEMPLATED_CALLEE (
189
+ MapNames::getDpctNamespace () +
190
+ " group::load_direct_blocked" ,
191
+ 0 , 1 , 2 , 3 ),
192
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ),
193
+ ARG (4 ))))))
194
+
173
195
// cub::LoadDirectStriped
174
- HEADER_INSERT_FACTORY (
175
- HeaderType::HT_DPCT_GROUP_Utils,
176
- CALL_FACTORY_ENTRY (
177
- " cub::LoadDirectStriped" ,
178
- CALL (PRETTY_TEMPLATED_CALLEE (MapNames::getDpctNamespace () +
179
- " group::load_direct_striped" ,
180
- 1 , 2 , 3 ),
181
- NDITEM, ARG (1 ), ARG (2 ))))
196
+ CASE_FACTORY_ENTRY (
197
+ CASE (CheckArgCount (3 ),
198
+ HEADER_INSERT_FACTORY (
199
+ HeaderType::HT_DPCT_GROUP_Utils,
200
+ CALL_FACTORY_ENTRY (" cub::LoadDirectStriped" ,
201
+ CALL (PRETTY_TEMPLATED_CALLEE (
202
+ MapNames::getDpctNamespace () +
203
+ " group::load_direct_striped" ,
204
+ 1 , 2 , 3 ),
205
+ NDITEM, ARG (1 ), ARG (2 ))))),
206
+ CASE (CheckArgCount (4 ),
207
+ HEADER_INSERT_FACTORY (
208
+ HeaderType::HT_DPCT_GROUP_Utils,
209
+ CALL_FACTORY_ENTRY (" cub::LoadDirectStriped" ,
210
+ CALL (PRETTY_TEMPLATED_CALLEE (
211
+ MapNames::getDpctNamespace () +
212
+ " group::load_direct_striped" ,
213
+ 1 , 2 , 3 ),
214
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ))))),
215
+ CASE (CheckArgCount (5 ),
216
+ HEADER_INSERT_FACTORY (
217
+ HeaderType::HT_DPCT_GROUP_Utils,
218
+ CALL_FACTORY_ENTRY (" cub::LoadDirectStriped" ,
219
+ CALL (PRETTY_TEMPLATED_CALLEE (
220
+ MapNames::getDpctNamespace () +
221
+ " group::load_direct_striped" ,
222
+ 1 , 2 , 3 , 4 ),
223
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ),
224
+ ARG (4 ))))))
225
+ // cub::LoadDirectWarpStriped
226
+ CASE_FACTORY_ENTRY (
227
+ CASE (CheckArgCount (3 ),
228
+ HEADER_INSERT_FACTORY (
229
+ HeaderType::HT_DPCT_GROUP_Utils,
230
+ CALL_FACTORY_ENTRY (
231
+ " cub::LoadDirectWarpStriped" ,
232
+ CALL (PRETTY_TEMPLATED_CALLEE (
233
+ MapNames::getDpctNamespace () +
234
+ " group::load_direct_sub_group_striped" ,
235
+ 0 , 1 , 2 ),
236
+ NDITEM, ARG (1 ), ARG (2 ))))),
237
+ CASE (CheckArgCount (4 ),
238
+ HEADER_INSERT_FACTORY (
239
+ HeaderType::HT_DPCT_GROUP_Utils,
240
+ CALL_FACTORY_ENTRY (
241
+ " cub::LoadDirectWarpStriped" ,
242
+ CALL (PRETTY_TEMPLATED_CALLEE (
243
+ MapNames::getDpctNamespace () +
244
+ " group::load_direct_sub_group_striped" ,
245
+ 0 , 1 , 2 ),
246
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ))))),
247
+ CASE (CheckArgCount (5 ),
248
+ HEADER_INSERT_FACTORY (
249
+ HeaderType::HT_DPCT_GROUP_Utils,
250
+ CALL_FACTORY_ENTRY (
251
+ " cub::LoadDirectWarpStriped" ,
252
+ CALL (PRETTY_TEMPLATED_CALLEE (
253
+ MapNames::getDpctNamespace () +
254
+ " group::load_direct_sub_group_striped" ,
255
+ 0 , 1 , 2 , 3 ),
256
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ), ARG (4 ))))))
257
+
182
258
// cub::StoreDirectBlocked
183
- HEADER_INSERT_FACTORY (
184
- HeaderType::HT_DPCT_GROUP_Utils,
185
- CALL_FACTORY_ENTRY (
186
- " cub::StoreDirectBlocked" ,
187
- CALL (PRETTY_TEMPLATED_CALLEE (MapNames::getDpctNamespace () +
188
- " group::store_direct_blocked" ,
189
- 0 , 1 , 2 ),
190
- NDITEM, ARG (1 ), ARG (2 ))))
259
+ CASE_FACTORY_ENTRY (
260
+ CASE (
261
+ CheckArgCount (3 ),
262
+ HEADER_INSERT_FACTORY (
263
+ HeaderType::HT_DPCT_GROUP_Utils,
264
+ CALL_FACTORY_ENTRY (" cub::StoreDirectBlocked" ,
265
+ CALL (PRETTY_TEMPLATED_CALLEE (
266
+ MapNames::getDpctNamespace () +
267
+ " group::store_direct_blocked" ,
268
+ 0 , 1 , 2 ),
269
+ NDITEM, ARG (1 ), ARG (2 ))))),
270
+ CASE (
271
+ CheckArgCount (4 ),
272
+ HEADER_INSERT_FACTORY (
273
+ HeaderType::HT_DPCT_GROUP_Utils,
274
+ CALL_FACTORY_ENTRY (" cub::StoreDirectBlocked" ,
275
+ CALL (PRETTY_TEMPLATED_CALLEE (
276
+ MapNames::getDpctNamespace () +
277
+ " group::store_direct_blocked" ,
278
+ 0 , 1 , 2 ),
279
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ))))))
280
+
191
281
// cub::StoreDirectStriped
192
- HEADER_INSERT_FACTORY (
193
- HeaderType::HT_DPCT_GROUP_Utils,
194
- CALL_FACTORY_ENTRY (
195
- " cub::StoreDirectStriped" ,
196
- CALL (PRETTY_TEMPLATED_CALLEE (MapNames::getDpctNamespace () +
197
- " group::store_direct_striped" ,
198
- 1 , 2 , 3 ),
199
- NDITEM, ARG (1 ), ARG (2 ))))
282
+ CASE_FACTORY_ENTRY (
283
+ CASE (
284
+ CheckArgCount (3 ),
285
+ HEADER_INSERT_FACTORY (
286
+ HeaderType::HT_DPCT_GROUP_Utils,
287
+ CALL_FACTORY_ENTRY (" cub::StoreDirectStriped" ,
288
+ CALL (PRETTY_TEMPLATED_CALLEE (
289
+ MapNames::getDpctNamespace () +
290
+ " group::store_direct_striped" ,
291
+ 1 , 2 , 3 ),
292
+ NDITEM, ARG (1 ), ARG (2 ))))),
293
+ CASE (
294
+ CheckArgCount (4 ),
295
+ HEADER_INSERT_FACTORY (
296
+ HeaderType::HT_DPCT_GROUP_Utils,
297
+ CALL_FACTORY_ENTRY (" cub::StoreDirectStriped" ,
298
+ CALL (PRETTY_TEMPLATED_CALLEE (
299
+ MapNames::getDpctNamespace () +
300
+ " group::store_direct_striped" ,
301
+ 1 , 2 , 3 ),
302
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ))))))
303
+ // cub::StoreDirectWarpStriped
304
+ CASE_FACTORY_ENTRY (
305
+ CASE (CheckArgCount (3 ),
306
+ HEADER_INSERT_FACTORY (
307
+ HeaderType::HT_DPCT_GROUP_Utils,
308
+ CALL_FACTORY_ENTRY (
309
+ " cub::StoreDirectWarpStriped" ,
310
+ CALL (PRETTY_TEMPLATED_CALLEE (
311
+ MapNames::getDpctNamespace () +
312
+ " group::store_direct_sub_group_striped" ,
313
+ 0 , 1 , 2 ),
314
+ NDITEM, ARG (1 ), ARG (2 ))))),
315
+ CASE (CheckArgCount (4 ),
316
+ HEADER_INSERT_FACTORY (
317
+ HeaderType::HT_DPCT_GROUP_Utils,
318
+ CALL_FACTORY_ENTRY (
319
+ " cub::StoreDirectWarpStriped" ,
320
+ CALL (PRETTY_TEMPLATED_CALLEE (
321
+ MapNames::getDpctNamespace () +
322
+ " group::store_direct_sub_group_striped" ,
323
+ 0 , 1 , 2 ),
324
+ NDITEM, ARG (1 ), ARG (2 ), ARG (3 ))))))
325
+
200
326
// cub::ShuffleDown
201
327
SUBGROUPSIZE_FACTORY (
202
328
UINT_MAX,
0 commit comments