@@ -125,8 +125,9 @@ int get_mem_idx(GroupTy g, int vec_or_array_idx) {
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_char.html
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_long.html
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_short.html
-// Reads require 4-byte alignment, writes 16-byte alignment. Supported
-// sizes:
+// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_local_block_io.html
+// Reads require 4-byte alignment for global pointers and 16-byte alignment
+// for local pointers; writes require 16-byte alignment. Supported sizes:
//
// +------------+-------------+
// | block type | # of blocks |
@@ -156,6 +157,21 @@ struct BlockInfo {
      (num_blocks <= 8 || (num_blocks == 16 && block_size <= 2));
};

+enum class operation_type { load, store };
+
+template <operation_type OpType, access::address_space Space>
+struct RequiredAlignment {};
+
+template <operation_type OpType>
+struct RequiredAlignment<OpType, access::address_space::global_space> {
+  static constexpr int value = (OpType == operation_type::load) ? 4 : 16;
+};
+
+template <operation_type OpType>
+struct RequiredAlignment<OpType, access::address_space::local_space> {
+  static constexpr int value = 16;
+};
+
template <typename BlockInfoTy> struct BlockTypeInfo;

template <typename IteratorT, std::size_t ElementsPerWorkItem, bool Blocked>
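For illustration, a minimal compile-time check of the `RequiredAlignment` trait introduced above (a sketch, not part of the patch; it assumes the enclosing `detail` namespace scope of this header):

    // Global-space loads are the only combination that gets away with 4-byte
    // alignment; stores and all local-space operations need 16 bytes.
    static_assert(RequiredAlignment<operation_type::load,
                                    access::address_space::global_space>::value == 4);
    static_assert(RequiredAlignment<operation_type::store,
                                    access::address_space::global_space>::value == 16);
    static_assert(RequiredAlignment<operation_type::load,
                                    access::address_space::local_space>::value == 16);
    static_assert(RequiredAlignment<operation_type::store,
                                    access::address_space::local_space>::value == 16);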
@@ -186,11 +202,10 @@ struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
// aren't satisfied. If the deduced address space is generic then the returned
// pointer will have the generic address space and has to be dynamically cast
// to the global or local space before being used in a builtin.
-template <int RequiredAlign, std::size_t ElementsPerWorkItem,
-          typename IteratorT, typename Properties>
-auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
-  using value_type =
-      remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;
+template <std::size_t ElementsPerWorkItem, typename IteratorT,
+          typename Properties>
+constexpr auto get_block_op_ptr(IteratorT iter,
+                                [[maybe_unused]] Properties props) {
  using iter_no_cv = std::remove_cv_t<IteratorT>;

  constexpr bool blocked = detail::isBlocked(props);
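Since `get_block_op_ptr` now encodes "block path not possible" in its return type rather than taking the alignment as a template parameter, callers can branch at compile time. A hedged sketch of that contract (the `dispatch_example` wrapper and the element count are hypothetical, not part of the patch):

    // `ptr` has type std::nullptr_t whenever the block builtins cannot be
    // used, so the naive fallback is selected with `if constexpr`.
    template <typename IteratorT, typename Properties>
    void dispatch_example(IteratorT in_ptr, Properties props) {
      auto ptr = detail::get_block_op_ptr</*ElementsPerWorkItem=*/4>(in_ptr, props);
      if constexpr (std::is_same_v<std::nullptr_t, decltype(ptr)>) {
        // Naive element-wise path.
      } else {
        // Block path; a generic-space pointer still has to be dynamically
        // cast to the local or global space before reaching a builtin.
      }
    }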
@@ -208,39 +223,46 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
  } else if constexpr (!props.template has_property<full_group_key>()) {
    return nullptr;
  } else if constexpr (detail::is_multi_ptr_v<IteratorT>) {
-    return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
-        iter.get_decorated(), props);
+    return get_block_op_ptr<ElementsPerWorkItem>(iter.get_decorated(), props);
  } else if constexpr (!std::is_pointer_v<iter_no_cv>) {
    if constexpr (props.template has_property<contiguous_memory_key>())
-      return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(&*iter,
-                                                                  props);
+      return get_block_op_ptr<ElementsPerWorkItem>(&*iter, props);
    else
      return nullptr;
  } else {
    // Load/store to/from nullptr would be UB; this assume allows the
    // compiler to optimize the IR further.
    __builtin_assume(iter != nullptr);

-    // No early return as that would mess up return type deduction.
-    bool is_aligned = alignof(value_type) >= RequiredAlign ||
-                      reinterpret_cast<uintptr_t>(iter) % RequiredAlign == 0;
-
    using block_pointer_type =
        typename BlockTypeInfo<BlkInfo>::block_pointer_type;

-    static constexpr auto deduced_address_space =
+    constexpr auto deduced_address_space =
        BlockTypeInfo<BlkInfo>::deduced_address_space;
+
    if constexpr (deduced_address_space ==
                      access::address_space::generic_space ||
                  deduced_address_space ==
                      access::address_space::global_space ||
-                  deduced_address_space == access::address_space::local_space) {
-      return is_aligned ? reinterpret_cast<block_pointer_type>(iter) : nullptr;
+                  (deduced_address_space ==
+                       access::address_space::local_space &&
+                   props.template has_property<
+                       detail::native_local_block_io_key>())) {
+      return reinterpret_cast<block_pointer_type>(iter);
    } else {
      return nullptr;
    }
  }
}
+
+template <int RequiredAlign, typename IteratorType>
+bool is_aligned(IteratorType iter) {
+  using value_type = remove_decoration_t<
+      typename std::iterator_traits<IteratorType>::value_type>;
+  return alignof(value_type) >= RequiredAlign ||
+         reinterpret_cast<uintptr_t>(&*iter) % RequiredAlign == 0;
+}
+
} // namespace detail

// Load API span overload.
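The free-standing `is_aligned` helper added above replaces the `is_aligned` local that used to live inside `get_block_op_ptr`, so the check can now run after the address space (and hence the actual requirement) is known. A hedged usage sketch (the wrapper name and the 16-byte requirement are just examples):

    // Accepts either statically sufficient alignment of the element type or a
    // runtime address that happens to be suitably aligned.
    inline bool can_use_block_store(const int *p) {
      return detail::is_aligned</*RequiredAlign=*/16>(p);
    }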
@@ -266,78 +288,72 @@ group_load(Group g, InputIteratorT in_ptr,
  } else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
    return group_load(g, in_ptr, out, use_naive{});
  } else {
-    auto ptr =
-        detail::get_block_op_ptr<4 /* load align */, ElementsPerWorkItem>(
-            in_ptr, props);
-    if (!ptr)
-      return group_load(g, in_ptr, out, use_naive{});
+    auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(in_ptr, props);
+    static constexpr auto deduced_address_space =
+        detail::deduce_AS<std::remove_cv_t<decltype(ptr)>>::value;

    if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
-      // Do optimized load.
-      using value_type = remove_decoration_t<
-          typename std::iterator_traits<InputIteratorT>::value_type>;
-      using block_info = typename detail::BlockTypeInfo<
-          detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
-      static constexpr auto deduced_address_space =
-          block_info::deduced_address_space;
-      using block_op_type = typename block_info::block_op_type;
-
-      if constexpr (deduced_address_space ==
-                        access::address_space::local_space &&
-                    !props.template has_property<
-                        detail::native_local_block_io_key>())
-        return group_load(g, in_ptr, out, use_naive{});
-
-      block_op_type load;
      if constexpr (deduced_address_space ==
                    access::address_space::generic_space) {
        if (auto local_ptr = detail::dynamic_address_cast<
                access::address_space::local_space>(ptr)) {
-          if constexpr (props.template has_property<
-                            detail::native_local_block_io_key>())
-            load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
-          else
-            return group_load(g, in_ptr, out, use_naive{});
+          return group_load(g, local_ptr, out, props);
        } else if (auto global_ptr = detail::dynamic_address_cast<
                       access::address_space::global_space>(ptr)) {
-          load = __spirv_SubgroupBlockReadINTEL<block_op_type>(global_ptr);
+          return group_load(g, global_ptr, out, props);
        } else {
          return group_load(g, in_ptr, out, use_naive{});
        }
      } else {
-        load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
-      }
+        using value_type = remove_decoration_t<
+            typename std::iterator_traits<InputIteratorT>::value_type>;
+        using block_info = typename detail::BlockTypeInfo<
+            detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
+        using block_op_type = typename block_info::block_op_type;
+        // Alignment checks of the pointer.
+        constexpr int ReqAlign =
+            detail::RequiredAlignment<detail::operation_type::load,
+                                      deduced_address_space>::value;
+        if (!detail::is_aligned<ReqAlign>(in_ptr))
+          return group_load(g, in_ptr, out, use_naive{});

-      // TODO: accessor_iterator's value_type is weird, so we need
-      // `std::remove_const_t` below:
-      //
-      // static_assert(
-      //     std::is_same_v<
-      //         typename std::iterator_traits<
-      //             sycl::detail::accessor_iterator<const int, 1>>::value_type,
-      //         const int>);
-      //
-      // yet
-      //
-      // static_assert(
-      //     std::is_same_v<
-      //         typename std::iterator_traits<const int *>::value_type, int>);
-
-      if constexpr (std::is_same_v<std::remove_const_t<value_type>, OutputT>) {
-        static_assert(sizeof(load) == out.size_bytes());
-        sycl::detail::memcpy_no_adl(out.begin(), &load, out.size_bytes());
-      } else {
-        std::remove_const_t<value_type> values[ElementsPerWorkItem];
-        static_assert(sizeof(load) == sizeof(values));
-        sycl::detail::memcpy_no_adl(values, &load, sizeof(values));
-
-        // Note: can't `memcpy` directly into `out` because that might bypass
-        // an implicit conversion required by the specification.
-        for (int i = 0; i < ElementsPerWorkItem; ++i)
-          out[i] = values[i];
+        // We know the pointer is aligned and the address space is known. Do
+        // the optimized load.
+        auto load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
+
+        // TODO: accessor_iterator's value_type is weird, so we need
+        // `std::remove_const_t` below:
+        //
+        // static_assert(
+        //     std::is_same_v<
+        //         typename std::iterator_traits<
+        //             sycl::detail::accessor_iterator<const int,
+        //             1>>::value_type,
+        //         const int>);
+        //
+        // yet
+        //
+        // static_assert(
+        //     std::is_same_v<
+        //         typename std::iterator_traits<const int *>::value_type,
+        //         int>);
+        if constexpr (std::is_same_v<std::remove_const_t<value_type>,
+                                     OutputT>) {
+          static_assert(sizeof(load) == out.size_bytes());
+          sycl::detail::memcpy_no_adl(out.begin(), &load, out.size_bytes());
+        } else {
+          std::remove_const_t<value_type> values[ElementsPerWorkItem];
+          static_assert(sizeof(load) == sizeof(values));
+          sycl::detail::memcpy_no_adl(values, &load, sizeof(values));
+
+          // Note: can't `memcpy` directly into `out` because that might
+          // bypass an implicit conversion required by the specification.
+          for (int i = 0; i < ElementsPerWorkItem; ++i)
+            out[i] = values[i];
+        }
      }
-
-      return;
+    } else {
+      return group_load(g, in_ptr, out, use_naive{});
    }
  }
}
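For context, a hedged sketch of how user code reaches the span overload rewritten above (the queue/pointer scaffolding and the blocked-placement property name are assumptions based on the sycl_ext_oneapi_group_load_store extension, not part of this patch):

    namespace sycl_exp = sycl::ext::oneapi::experimental;

    q.parallel_for(sycl::nd_range<1>{global_size, local_size},
                   [=](sycl::nd_item<1> it) {
                     int values[4];
                     // Takes the block-read path when the sub-group, address
                     // space and alignment checks above all pass; otherwise
                     // falls back to use_naive{}.
                     sycl_exp::group_load(
                         it.get_sub_group(), in_ptr, sycl::span<int, 4>{values},
                         sycl_exp::properties{sycl_exp::data_placement_blocked});
                   });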
@@ -365,55 +381,50 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
  } else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
    return group_store(g, in, out_ptr, use_naive{});
  } else {
-    auto ptr =
-        detail::get_block_op_ptr<16 /* store align */, ElementsPerWorkItem>(
-            out_ptr, props);
-    if (!ptr)
-      return group_store(g, in, out_ptr, use_naive{});
+    auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(out_ptr, props);

    if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
-      using block_info = typename detail::BlockTypeInfo<
-          detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
      static constexpr auto deduced_address_space =
-          block_info::deduced_address_space;
-      if constexpr (deduced_address_space ==
-                        access::address_space::local_space &&
-                    !props.template has_property<
-                        detail::native_local_block_io_key>())
-        return group_store(g, in, out_ptr, use_naive{});
-
-      // Do optimized store.
-      std::remove_const_t<remove_decoration_t<
-          typename std::iterator_traits<OutputIteratorT>::value_type>>
-          values[ElementsPerWorkItem];
-
-      for (int i = 0; i < ElementsPerWorkItem; ++i) {
-        // Including implicit conversion.
-        values[i] = in[i];
-      }
-
-      using block_op_type = typename block_info::block_op_type;
+          detail::deduce_AS<std::remove_cv_t<decltype(ptr)>>::value;
      if constexpr (deduced_address_space ==
                    access::address_space::generic_space) {
        if (auto local_ptr = detail::dynamic_address_cast<
                access::address_space::local_space>(ptr)) {
-          if constexpr (props.template has_property<
-                            detail::native_local_block_io_key>())
-            __spirv_SubgroupBlockWriteINTEL(
-                local_ptr, sycl::bit_cast<block_op_type>(values));
-          else
-            return group_store(g, in, out_ptr, use_naive{});
+          return group_store(g, in, local_ptr, props);
        } else if (auto global_ptr = detail::dynamic_address_cast<
                       access::address_space::global_space>(ptr)) {
-          __spirv_SubgroupBlockWriteINTEL(
-              global_ptr, sycl::bit_cast<block_op_type>(values));
+          return group_store(g, in, global_ptr, props);
        } else {
          return group_store(g, in, out_ptr, use_naive{});
        }
      } else {
+        using block_info = typename detail::BlockTypeInfo<
+            detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
+        using block_op_type = typename block_info::block_op_type;
+
+        // Alignment checks of the pointer.
+        constexpr int ReqAlign =
+            detail::RequiredAlignment<detail::operation_type::store,
+                                      deduced_address_space>::value;
+        if (!detail::is_aligned<ReqAlign>(out_ptr))
+          return group_store(g, in, out_ptr, use_naive{});
+
+        std::remove_const_t<remove_decoration_t<
+            typename std::iterator_traits<OutputIteratorT>::value_type>>
+            values[ElementsPerWorkItem];
+
+        for (int i = 0; i < ElementsPerWorkItem; ++i) {
+          // Including implicit conversion.
+          values[i] = in[i];
+        }
+
+        // We know the pointer is aligned and the address space is known. Do
+        // the optimized store.
        __spirv_SubgroupBlockWriteINTEL(ptr,
                                        sycl::bit_cast<block_op_type>(values));
      }
+    } else {
+      return group_store(g, in, out_ptr, use_naive{});
    }
  }
}
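And the store direction, mirroring the load sketch above (same assumed kernel scaffolding; the 16-byte write alignment is what the fast path checks for here):

    // Writes the per-work-item values back as one block write when the
    // alignment and address-space checks pass; otherwise falls back.
    sycl_exp::group_store(
        it.get_sub_group(), sycl::span<int, 4>{values}, out_ptr,
        sycl_exp::properties{sycl_exp::data_placement_blocked});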