10
10
11
11
#pragma once
12
12
13
+ #include < sycl/ext/oneapi/experimental/annotated_ptr/annotated_ptr.hpp>
13
14
#include < sycl/ext/oneapi/properties/properties.hpp>
14
15
#include < sycl/group_barrier.hpp>
15
16
#include < sycl/sycl_span.hpp>
@@ -255,25 +256,29 @@ constexpr auto get_block_op_ptr(IteratorT iter,
255
256
}
256
257
}
257
258
258
- template <int RequiredAlign, typename IteratorType>
259
- bool is_aligned (IteratorType iter) {
259
+ template <int RequiredAlign, typename IteratorType, typename Properties >
260
+ bool is_aligned (IteratorType iter, [[maybe_unused]] Properties props ) {
260
261
using value_type = remove_decoration_t <
261
262
typename std::iterator_traits<IteratorType>::value_type>;
263
+
264
+ if constexpr (Properties::template has_property<alignment_key>()) {
265
+ if (Properties::template get_property<alignment_key>().value >=
266
+ RequiredAlign)
267
+ return true ;
268
+ }
269
+
262
270
return alignof (value_type) >= RequiredAlign ||
263
271
reinterpret_cast <uintptr_t >(&*iter) % RequiredAlign == 0 ;
264
272
}
265
273
266
- } // namespace detail
267
-
268
- // Load API span overload.
269
274
template <typename Group, typename InputIteratorT, typename OutputT,
270
275
std::size_t ElementsPerWorkItem,
271
276
typename Properties = decltype (properties())>
272
277
std::enable_if_t <detail::verify_load_types<InputIteratorT, OutputT> &&
273
278
detail::is_generic_group_v<Group> &&
274
279
is_property_list_v<Properties>>
275
- group_load (Group g, InputIteratorT in_ptr,
276
- span<OutputT, ElementsPerWorkItem> out, Properties props = {}) {
280
+ group_load_impl (Group g, InputIteratorT in_ptr,
281
+ span<OutputT, ElementsPerWorkItem> out, Properties props = {}) {
277
282
constexpr bool blocked = detail::isBlocked (props);
278
283
using use_naive =
279
284
detail::merged_properties_t <Properties,
@@ -286,7 +291,7 @@ group_load(Group g, InputIteratorT in_ptr,
286
291
group_barrier (g);
287
292
return ;
288
293
} else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
289
- return group_load (g, in_ptr, out, use_naive{});
294
+ return group_load_impl (g, in_ptr, out, use_naive{});
290
295
} else {
291
296
auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(in_ptr, props);
292
297
static constexpr auto deduced_address_space =
@@ -297,12 +302,12 @@ group_load(Group g, InputIteratorT in_ptr,
297
302
access::address_space::generic_space) {
298
303
if (auto local_ptr = detail::dynamic_address_cast<
299
304
access::address_space::local_space>(ptr)) {
300
- return group_load (g, local_ptr, out, props);
305
+ return group_load_impl (g, local_ptr, out, props);
301
306
} else if (auto global_ptr = detail::dynamic_address_cast<
302
307
access::address_space::global_space>(ptr)) {
303
- return group_load (g, global_ptr, out, props);
308
+ return group_load_impl (g, global_ptr, out, props);
304
309
} else {
305
- return group_load (g, in_ptr, out, use_naive{});
310
+ return group_load_impl (g, in_ptr, out, use_naive{});
306
311
}
307
312
} else {
308
313
using value_type = remove_decoration_t <
@@ -314,8 +319,8 @@ group_load(Group g, InputIteratorT in_ptr,
314
319
constexpr int ReqAlign =
315
320
detail::RequiredAlignment<detail::operation_type::load,
316
321
deduced_address_space>::value;
317
- if (!detail::is_aligned<ReqAlign>(in_ptr))
318
- return group_load (g, in_ptr, out, use_naive{});
322
+ if (!detail::is_aligned<ReqAlign>(in_ptr, props ))
323
+ return group_load_impl (g, in_ptr, out, use_naive{});
319
324
320
325
// We know the pointer is aligned and the address space is known. Do the
321
326
// optimized load.
@@ -353,20 +358,21 @@ group_load(Group g, InputIteratorT in_ptr,
353
358
}
354
359
}
355
360
} else {
356
- return group_load (g, in_ptr, out, use_naive{});
361
+ return group_load_impl (g, in_ptr, out, use_naive{});
357
362
}
363
+
364
+ return ;
358
365
}
359
366
}
360
367
361
- // Store API span overload.
362
368
template <typename Group, typename InputT, std::size_t ElementsPerWorkItem,
363
369
typename OutputIteratorT,
364
370
typename Properties = decltype (properties())>
365
371
std::enable_if_t <detail::verify_store_types<InputT, OutputIteratorT> &&
366
372
detail::is_generic_group_v<Group> &&
367
373
is_property_list_v<Properties>>
368
- group_store (Group g, const span<InputT, ElementsPerWorkItem> in,
369
- OutputIteratorT out_ptr, Properties props = {}) {
374
+ group_store_impl (Group g, const span<InputT, ElementsPerWorkItem> in,
375
+ OutputIteratorT out_ptr, Properties props = {}) {
370
376
constexpr bool blocked = detail::isBlocked (props);
371
377
using use_naive =
372
378
detail::merged_properties_t <Properties,
@@ -379,7 +385,7 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
379
385
group_barrier (g);
380
386
return ;
381
387
} else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
382
- return group_store (g, in, out_ptr, use_naive{});
388
+ return group_store_impl (g, in, out_ptr, use_naive{});
383
389
} else {
384
390
auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(out_ptr, props);
385
391
@@ -390,12 +396,12 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
390
396
access::address_space::generic_space) {
391
397
if (auto local_ptr = detail::dynamic_address_cast<
392
398
access::address_space::local_space>(ptr)) {
393
- return group_store (g, in, local_ptr, props);
399
+ return group_store_impl (g, in, local_ptr, props);
394
400
} else if (auto global_ptr = detail::dynamic_address_cast<
395
401
access::address_space::global_space>(ptr)) {
396
- return group_store (g, in, global_ptr, props);
402
+ return group_store_impl (g, in, global_ptr, props);
397
403
} else {
398
- return group_store (g, in, out_ptr, use_naive{});
404
+ return group_store_impl (g, in, out_ptr, use_naive{});
399
405
}
400
406
} else {
401
407
using block_info = typename detail::BlockTypeInfo<
@@ -406,8 +412,8 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
406
412
constexpr int ReqAlign =
407
413
detail::RequiredAlignment<detail::operation_type::store,
408
414
deduced_address_space>::value;
409
- if (!detail::is_aligned<ReqAlign>(out_ptr))
410
- return group_store (g, in, out_ptr, use_naive{});
415
+ if (!detail::is_aligned<ReqAlign>(out_ptr, props ))
416
+ return group_store_impl (g, in, out_ptr, use_naive{});
411
417
412
418
std::remove_const_t <remove_decoration_t <
413
419
typename std::iterator_traits<OutputIteratorT>::value_type>>
@@ -424,10 +430,41 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
424
430
sycl::bit_cast<block_op_type>(values));
425
431
}
426
432
} else {
427
- return group_store (g, in, out_ptr, use_naive{});
433
+ return group_store_impl (g, in, out_ptr, use_naive{});
428
434
}
429
435
}
430
436
}
437
+ } // namespace detail
438
+
439
+ // Load API span overload.
440
+ template <typename Group, typename InputIteratorT, typename OutputT,
441
+ std::size_t ElementsPerWorkItem,
442
+ typename Properties = decltype (properties())>
443
+ std::enable_if_t <detail::verify_load_types<InputIteratorT, OutputT> &&
444
+ detail::is_generic_group_v<Group> &&
445
+ is_property_list_v<Properties>>
446
+ group_load (Group g, InputIteratorT in_ptr,
447
+ span<OutputT, ElementsPerWorkItem> out, Properties props = {}) {
448
+ static_assert (std::is_pointer_v<InputIteratorT> ||
449
+ !Properties::template has_property<alignment_key>(),
450
+ " group_load requires a pointer if alignment property is used" );
451
+ detail::group_load_impl (g, in_ptr, out, props);
452
+ }
453
+
454
+ // Store API span overload.
455
+ template <typename Group, typename InputT, std::size_t ElementsPerWorkItem,
456
+ typename OutputIteratorT,
457
+ typename Properties = decltype (properties())>
458
+ std::enable_if_t <detail::verify_store_types<InputT, OutputIteratorT> &&
459
+ detail::is_generic_group_v<Group> &&
460
+ is_property_list_v<Properties>>
461
+ group_store (Group g, const span<InputT, ElementsPerWorkItem> in,
462
+ OutputIteratorT out_ptr, Properties props = {}) {
463
+ static_assert (std::is_pointer_v<OutputIteratorT> ||
464
+ !Properties::template has_property<alignment_key>(),
465
+ " group_store requires a pointer if alignment property is used" );
466
+ detail::group_store_impl (g, in, out_ptr, props);
467
+ }
431
468
432
469
// Load API scalar.
433
470
template <typename Group, typename InputIteratorT, typename OutputT,
@@ -437,7 +474,10 @@ std::enable_if_t<detail::verify_load_types<InputIteratorT, OutputT> &&
437
474
is_property_list_v<Properties>>
438
475
group_load (Group g, InputIteratorT in_ptr, OutputT &out,
439
476
Properties properties = {}) {
440
- group_load (g, in_ptr, span<OutputT, 1 >(&out, 1 ), properties);
477
+ static_assert (std::is_pointer_v<InputIteratorT> ||
478
+ !Properties::template has_property<alignment_key>(),
479
+ " group_load requires a pointer if alignment property is used" );
480
+ detail::group_load_impl (g, in_ptr, span<OutputT, 1 >(&out, 1 ), properties);
441
481
}
442
482
443
483
// Store API scalar.
@@ -448,7 +488,11 @@ std::enable_if_t<detail::verify_store_types<InputT, OutputIteratorT> &&
448
488
is_property_list_v<Properties>>
449
489
group_store (Group g, const InputT &in, OutputIteratorT out_ptr,
450
490
Properties properties = {}) {
451
- group_store (g, span<const InputT, 1 >(&in, 1 ), out_ptr, properties);
491
+ static_assert (std::is_pointer_v<OutputIteratorT> ||
492
+ !Properties::template has_property<alignment_key>(),
493
+ " group_store requires a pointer if alignment property is used" );
494
+ detail::group_store_impl (g, span<const InputT, 1 >(&in, 1 ), out_ptr,
495
+ properties);
452
496
}
453
497
454
498
// Load API sycl::vec overload.
@@ -459,7 +503,10 @@ std::enable_if_t<detail::verify_load_types<InputIteratorT, OutputT> &&
459
503
is_property_list_v<Properties>>
460
504
group_load (Group g, InputIteratorT in_ptr, sycl::vec<OutputT, N> &out,
461
505
Properties properties = {}) {
462
- group_load (g, in_ptr, span<OutputT, N>(&out[0 ], N), properties);
506
+ static_assert (std::is_pointer_v<InputIteratorT> ||
507
+ !Properties::template has_property<alignment_key>(),
508
+ " group_load requires a pointer if alignment property is used" );
509
+ detail::group_load_impl (g, in_ptr, span<OutputT, N>(&out[0 ], N), properties);
463
510
}
464
511
465
512
// Store API sycl::vec overload.
@@ -470,7 +517,11 @@ std::enable_if_t<detail::verify_store_types<InputT, OutputIteratorT> &&
470
517
is_property_list_v<Properties>>
471
518
group_store (Group g, const sycl::vec<InputT, N> &in, OutputIteratorT out_ptr,
472
519
Properties properties = {}) {
473
- group_store (g, span<const InputT, N>(&in[0 ], N), out_ptr, properties);
520
+ static_assert (std::is_pointer_v<OutputIteratorT> ||
521
+ !Properties::template has_property<alignment_key>(),
522
+ " group_store requires a pointer if alignment property is used" );
523
+ detail::group_store_impl (g, span<const InputT, N>(&in[0 ], N), out_ptr,
524
+ properties);
474
525
}
475
526
476
527
#else
0 commit comments