
Commit bc59e74

[SYCL] Simplify group load/store implementation (#16890)
Simplify handling of multiple address spaces and alignment checks. A further improvement to the alignment checks (performing the check at compile time instead of an expensive dynamic check) is in progress in #16882. This PR also fixes the alignment requirement for the local address space: 16-byte alignment is required for both load and store.
1 parent af6aa41 commit bc59e74
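
For orientation, here is a minimal usage sketch of the API this commit simplifies (illustrative only: the kernel setup is omitted, and `base`/`out_base` are assumed to point at this sub-group's contiguous chunk of data):

```cpp
#include <sycl/sycl.hpp>
namespace syclex = sycl::ext::oneapi::experimental;

// Each work-item handles 4 elements; when the pointer is suitably aligned
// and the address space is known, the implementation below lowers this to
// a single sub-group block read/write instead of per-element accesses.
void process(sycl::sub_group sg, const int *base, int *out_base) {
  int v[4];
  syclex::group_load(sg, base, sycl::span<int, 4>(v));
  for (int &x : v)
    x *= 2;
  syclex::group_store(sg, sycl::span<int, 4>(v), out_base);
}
```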

File tree

3 files changed: +501 -502 lines changed


sycl/include/sycl/ext/oneapi/experimental/group_load_store.hpp

Lines changed: 120 additions & 109 deletions
@@ -125,8 +125,9 @@ int get_mem_idx(GroupTy g, int vec_or_array_idx) {
 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_char.html
 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_long.html
 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_short.html
-// Reads require 4-byte alignment, writes 16-byte alignment. Supported
-// sizes:
+// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_local_block_io.html
+// Reads require 4-byte alignment for global pointers and 16-byte alignment for
+// local pointers, writes require 16-byte alignment. Supported sizes:
 //
 // +------------+-------------+
 // | block type | # of blocks |
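
A practical consequence of these requirements: over-aligning the allocation keeps both the read and the write fast paths available. A sketch using standard SYCL 2020 USM (the function name here is illustrative):

```cpp
#include <sycl/sycl.hpp>

// 16 bytes satisfies the strictest requirement above (writes, and reads
// through local pointers), so buffers allocated this way never fail the
// runtime alignment check on the global-memory path.
int *alloc_for_block_io(sycl::queue &q, size_t count) {
  return sycl::aligned_alloc_device<int>(/*alignment=*/16, count, q);
}
```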
@@ -156,6 +157,21 @@ struct BlockInfo {
       (num_blocks <= 8 || (num_blocks == 16 && block_size <= 2));
 };
 
+enum class operation_type { load, store };
+
+template <operation_type OpType, access::address_space Space>
+struct RequiredAlignment {};
+
+template <operation_type OpType>
+struct RequiredAlignment<OpType, access::address_space::global_space> {
+  static constexpr int value = (OpType == operation_type::load) ? 4 : 16;
+};
+
+template <operation_type OpType>
+struct RequiredAlignment<OpType, access::address_space::local_space> {
+  static constexpr int value = 16;
+};
+
 template <typename BlockInfoTy> struct BlockTypeInfo;
 
 template <typename IteratorT, std::size_t ElementsPerWorkItem, bool Blocked>
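
The `RequiredAlignment` trait added above centralizes the numbers from the comment block: 4 bytes for global loads, 16 bytes for global stores and for all local-memory block I/O. For illustration, the specializations resolve as follows (a compile-time sketch; namespace qualification of the implementation's detail namespace is abbreviated):

```cpp
using sycl::access::address_space;

// Global loads tolerate 4-byte alignment; everything else needs 16.
static_assert(RequiredAlignment<operation_type::load,
                                address_space::global_space>::value == 4);
static_assert(RequiredAlignment<operation_type::store,
                                address_space::global_space>::value == 16);
static_assert(RequiredAlignment<operation_type::load,
                                address_space::local_space>::value == 16);
static_assert(RequiredAlignment<operation_type::store,
                                address_space::local_space>::value == 16);
```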
@@ -186,11 +202,10 @@ struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
 // aren't satisfied. If deduced address space is generic then returned pointer
 // will have generic address space and has to be dynamically casted to global or
 // local space before using in a builtin.
-template <int RequiredAlign, std::size_t ElementsPerWorkItem,
-          typename IteratorT, typename Properties>
-auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
-  using value_type =
-      remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;
+template <std::size_t ElementsPerWorkItem, typename IteratorT,
+          typename Properties>
+constexpr auto get_block_op_ptr(IteratorT iter,
+                                [[maybe_unused]] Properties props) {
   using iter_no_cv = std::remove_cv_t<IteratorT>;
 
   constexpr bool blocked = detail::isBlocked(props);
@@ -208,39 +223,46 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
   } else if constexpr (!props.template has_property<full_group_key>()) {
     return nullptr;
   } else if constexpr (detail::is_multi_ptr_v<IteratorT>) {
-    return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
-        iter.get_decorated(), props);
+    return get_block_op_ptr<ElementsPerWorkItem>(iter.get_decorated(), props);
   } else if constexpr (!std::is_pointer_v<iter_no_cv>) {
     if constexpr (props.template has_property<contiguous_memory_key>())
-      return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(&*iter,
-                                                                  props);
+      return get_block_op_ptr<ElementsPerWorkItem>(&*iter, props);
     else
       return nullptr;
   } else {
     // Load/store to/from nullptr would be an UB, this assume allows the
     // compiler to optimize the IR further.
     __builtin_assume(iter != nullptr);
 
-    // No early return as that would mess up return type deduction.
-    bool is_aligned = alignof(value_type) >= RequiredAlign ||
-                      reinterpret_cast<uintptr_t>(iter) % RequiredAlign == 0;
-
     using block_pointer_type =
         typename BlockTypeInfo<BlkInfo>::block_pointer_type;
 
-    static constexpr auto deduced_address_space =
+    constexpr auto deduced_address_space =
         BlockTypeInfo<BlkInfo>::deduced_address_space;
+
     if constexpr (deduced_address_space ==
                       access::address_space::generic_space ||
                   deduced_address_space ==
                       access::address_space::global_space ||
-                  deduced_address_space == access::address_space::local_space) {
-      return is_aligned ? reinterpret_cast<block_pointer_type>(iter) : nullptr;
+                  (deduced_address_space ==
+                       access::address_space::local_space &&
+                   props.template has_property<
+                       detail::native_local_block_io_key>())) {
+      return reinterpret_cast<block_pointer_type>(iter);
     } else {
       return nullptr;
     }
   }
 }
+
+template <int RequiredAlign, typename IteratorType>
+bool is_aligned(IteratorType iter) {
+  using value_type = remove_decoration_t<
+      typename std::iterator_traits<IteratorType>::value_type>;
+  return alignof(value_type) >= RequiredAlign ||
+         reinterpret_cast<uintptr_t>(&*iter) % RequiredAlign == 0;
+}
+
 } // namespace detail
 
 // Load API span overload.
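
Note how the extracted `is_aligned` helper short-circuits: the first operand of `||` is a compile-time constant, so when the element type's natural alignment already meets the requirement, the runtime modulo disappears entirely. A standalone equivalent for illustration (hypothetical name, host-compilable):

```cpp
#include <cstdint>

template <int RequiredAlign, typename T>
bool is_aligned_sketch(const T *ptr) {
  // If alignof(T) already satisfies RequiredAlign, the right-hand side is
  // never evaluated and the whole function folds to `true` at compile time.
  return alignof(T) >= RequiredAlign ||
         reinterpret_cast<std::uintptr_t>(ptr) % RequiredAlign == 0;
}

// e.g. is_aligned_sketch<4>(some_int_ptr) is constant-folded to true,
// while is_aligned_sketch<16>(some_int_ptr) keeps the runtime check.
```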
@@ -266,78 +288,72 @@ group_load(Group g, InputIteratorT in_ptr,
   } else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
     return group_load(g, in_ptr, out, use_naive{});
   } else {
-    auto ptr =
-        detail::get_block_op_ptr<4 /* load align */, ElementsPerWorkItem>(
-            in_ptr, props);
-    if (!ptr)
-      return group_load(g, in_ptr, out, use_naive{});
+    auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(in_ptr, props);
+    static constexpr auto deduced_address_space =
+        detail::deduce_AS<std::remove_cv_t<decltype(ptr)>>::value;
 
     if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
-      // Do optimized load.
-      using value_type = remove_decoration_t<
-          typename std::iterator_traits<InputIteratorT>::value_type>;
-      using block_info = typename detail::BlockTypeInfo<
-          detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
-      static constexpr auto deduced_address_space =
-          block_info::deduced_address_space;
-      using block_op_type = typename block_info::block_op_type;
-
-      if constexpr (deduced_address_space ==
-                        access::address_space::local_space &&
-                    !props.template has_property<
-                        detail::native_local_block_io_key>())
-        return group_load(g, in_ptr, out, use_naive{});
-
-      block_op_type load;
       if constexpr (deduced_address_space ==
                     access::address_space::generic_space) {
         if (auto local_ptr = detail::dynamic_address_cast<
                 access::address_space::local_space>(ptr)) {
-          if constexpr (props.template has_property<
-                            detail::native_local_block_io_key>())
-            load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
-          else
-            return group_load(g, in_ptr, out, use_naive{});
+          return group_load(g, local_ptr, out, props);
         } else if (auto global_ptr = detail::dynamic_address_cast<
                        access::address_space::global_space>(ptr)) {
-          load = __spirv_SubgroupBlockReadINTEL<block_op_type>(global_ptr);
+          return group_load(g, global_ptr, out, props);
         } else {
          return group_load(g, in_ptr, out, use_naive{});
        }
      } else {
-        load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
-      }
+        using value_type = remove_decoration_t<
+            typename std::iterator_traits<InputIteratorT>::value_type>;
+        using block_info = typename detail::BlockTypeInfo<
+            detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
+        using block_op_type = typename block_info::block_op_type;
+        // Alignment checks of the pointer.
+        constexpr int ReqAlign =
+            detail::RequiredAlignment<detail::operation_type::load,
+                                      deduced_address_space>::value;
+        if (!detail::is_aligned<ReqAlign>(in_ptr))
+          return group_load(g, in_ptr, out, use_naive{});
 
-      // TODO: accessor_iterator's value_type is weird, so we need
-      // `std::remove_const_t` below:
-      //
-      // static_assert(
-      //     std::is_same_v<
-      //         typename std::iterator_traits<
-      //             sycl::detail::accessor_iterator<const int, 1>>::value_type,
-      //         const int>);
-      //
-      // yet
-      //
-      // static_assert(
-      //     std::is_same_v<
-      //         typename std::iterator_traits<const int *>::value_type, int>);
-
-      if constexpr (std::is_same_v<std::remove_const_t<value_type>, OutputT>) {
-        static_assert(sizeof(load) == out.size_bytes());
-        sycl::detail::memcpy_no_adl(out.begin(), &load, out.size_bytes());
-      } else {
-        std::remove_const_t<value_type> values[ElementsPerWorkItem];
-        static_assert(sizeof(load) == sizeof(values));
-        sycl::detail::memcpy_no_adl(values, &load, sizeof(values));
-
-        // Note: can't `memcpy` directly into `out` because that might bypass
-        // an implicit conversion required by the specification.
-        for (int i = 0; i < ElementsPerWorkItem; ++i)
-          out[i] = values[i];
+        // We know the pointer is aligned and the address space is known. Do the
+        // optimized load.
+        auto load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
+
+        // TODO: accessor_iterator's value_type is weird, so we need
+        // `std::remove_const_t` below:
+        //
+        // static_assert(
+        //     std::is_same_v<
+        //         typename std::iterator_traits<
+        //             sycl::detail::accessor_iterator<const int,
+        //             1>>::value_type,
+        //         const int>);
+        //
+        // yet
+        //
+        // static_assert(
+        //     std::is_same_v<
+        //         typename std::iterator_traits<const int *>::value_type,
+        //         int>);
+        if constexpr (std::is_same_v<std::remove_const_t<value_type>,
+                                     OutputT>) {
+          static_assert(sizeof(load) == out.size_bytes());
+          sycl::detail::memcpy_no_adl(out.begin(), &load, out.size_bytes());
+        } else {
+          std::remove_const_t<value_type> values[ElementsPerWorkItem];
+          static_assert(sizeof(load) == sizeof(values));
+          sycl::detail::memcpy_no_adl(values, &load, sizeof(values));
+
+          // Note: can't `memcpy` directly into `out` because that might bypass
+          // an implicit conversion required by the specification.
+          for (int i = 0; i < ElementsPerWorkItem; ++i)
+            out[i] = values[i];
+        }
       }
-
-      return;
+    } else {
+      return group_load(g, in_ptr, out, use_naive{});
     }
   }
 }
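
The key simplification in `group_load` is that the generic-pointer case no longer duplicates the block-read logic: after a successful `dynamic_address_cast`, it recurses into `group_load` with a pointer whose address space is known statically, so the recursive instantiation takes the `else` branch above. A simplified sketch of that runtime-to-compile-time dispatch pattern (standalone, illustrative names only):

```cpp
#include <cstdio>
#include <type_traits>

struct local_tag {};
struct global_tag {};

// The address space is now a compile-time property of Tag, so the body can
// branch with `if constexpr` and dead code is discarded per instantiation.
template <typename Tag> void do_block_io(Tag, const int *p) {
  if constexpr (std::is_same_v<Tag, local_tag>)
    std::printf("local-memory block I/O path: %p\n", (const void *)p);
  else
    std::printf("global-memory block I/O path: %p\n", (const void *)p);
}

// Generic entry point: resolve the space at runtime exactly once, then
// re-enter with the statically tagged overload (mirroring how group_load
// recurses with local_ptr/global_ptr).
void do_block_io_generic(const int *p, bool points_to_local) {
  if (points_to_local)
    do_block_io(local_tag{}, p);
  else
    do_block_io(global_tag{}, p);
}
```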
@@ -365,55 +381,50 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
   } else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
     return group_store(g, in, out_ptr, use_naive{});
   } else {
-    auto ptr =
-        detail::get_block_op_ptr<16 /* store align */, ElementsPerWorkItem>(
-            out_ptr, props);
-    if (!ptr)
-      return group_store(g, in, out_ptr, use_naive{});
+    auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(out_ptr, props);
 
     if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
-      using block_info = typename detail::BlockTypeInfo<
-          detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
       static constexpr auto deduced_address_space =
-          block_info::deduced_address_space;
-      if constexpr (deduced_address_space ==
-                        access::address_space::local_space &&
-                    !props.template has_property<
-                        detail::native_local_block_io_key>())
-        return group_store(g, in, out_ptr, use_naive{});
-
-      // Do optimized store.
-      std::remove_const_t<remove_decoration_t<
-          typename std::iterator_traits<OutputIteratorT>::value_type>>
-          values[ElementsPerWorkItem];
-
-      for (int i = 0; i < ElementsPerWorkItem; ++i) {
-        // Including implicit conversion.
-        values[i] = in[i];
-      }
-
-      using block_op_type = typename block_info::block_op_type;
+          detail::deduce_AS<std::remove_cv_t<decltype(ptr)>>::value;
       if constexpr (deduced_address_space ==
                     access::address_space::generic_space) {
         if (auto local_ptr = detail::dynamic_address_cast<
                 access::address_space::local_space>(ptr)) {
-          if constexpr (props.template has_property<
-                            detail::native_local_block_io_key>())
-            __spirv_SubgroupBlockWriteINTEL(
-                local_ptr, sycl::bit_cast<block_op_type>(values));
-          else
-            return group_store(g, in, out_ptr, use_naive{});
+          return group_store(g, in, local_ptr, props);
         } else if (auto global_ptr = detail::dynamic_address_cast<
                        access::address_space::global_space>(ptr)) {
-          __spirv_SubgroupBlockWriteINTEL(
-              global_ptr, sycl::bit_cast<block_op_type>(values));
+          return group_store(g, in, global_ptr, props);
        } else {
          return group_store(g, in, out_ptr, use_naive{});
        }
      } else {
+        using block_info = typename detail::BlockTypeInfo<
+            detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
+        using block_op_type = typename block_info::block_op_type;
+
+        // Alignment checks of the pointer.
+        constexpr int ReqAlign =
+            detail::RequiredAlignment<detail::operation_type::store,
+                                      deduced_address_space>::value;
+        if (!detail::is_aligned<ReqAlign>(out_ptr))
+          return group_store(g, in, out_ptr, use_naive{});
+
+        std::remove_const_t<remove_decoration_t<
+            typename std::iterator_traits<OutputIteratorT>::value_type>>
+            values[ElementsPerWorkItem];
+
+        for (int i = 0; i < ElementsPerWorkItem; ++i) {
+          // Including implicit conversion.
+          values[i] = in[i];
+        }
+
+        // We know the pointer is aligned and the address space is known. Do the
+        // optimized store.
        __spirv_SubgroupBlockWriteINTEL(ptr,
                                        sycl::bit_cast<block_op_type>(values));
      }
+    } else {
+      return group_store(g, in, out_ptr, use_naive{});
     }
   }
 }
