@@ -58,6 +58,13 @@ struct naive_key : detail::compile_time_property_key<detail::PropKind::Naive> {
58
58
using value_t = property_value<naive_key>;
59
59
};
60
60
inline constexpr naive_key::value_t naive;
61
+
62
+ struct native_local_block_io_key
63
+ : detail::compile_time_property_key<detail::PropKind::NativeLocalBlockIO> {
64
+ using value_t = property_value<native_local_block_io_key>;
65
+ };
66
+ inline constexpr native_local_block_io_key::value_t native_local_block_io;
67
+
61
68
using namespace sycl ::detail;
62
69
} // namespace detail
63
70
@@ -154,7 +161,6 @@ template <typename BlockInfoTy> struct BlockTypeInfo;
154
161
template <typename IteratorT, std::size_t ElementsPerWorkItem, bool Blocked>
155
162
struct BlockTypeInfo <BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
156
163
using BlockInfoTy = BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>;
157
- static_assert (BlockInfoTy::has_builtin);
158
164
159
165
using block_type = detail::fixed_width_unsigned<BlockInfoTy::block_size>;
160
166
@@ -163,15 +169,23 @@ struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
163
169
typename std::iterator_traits<IteratorT>::reference>>,
164
170
std::add_const_t <block_type>, block_type>;
165
171
166
- using block_pointer_type = typename detail::DecoratedType<
167
- block_pointer_elem_type, access::address_space::global_space>::type *;
172
+ static constexpr auto deduced_address_space =
173
+ detail::deduce_AS<std::remove_cv_t <IteratorT>>::value;
174
+
175
+ using block_pointer_type =
176
+ typename detail::DecoratedType<block_pointer_elem_type,
177
+ deduced_address_space>::type *;
178
+
168
179
using block_op_type = std::conditional_t <
169
180
BlockInfoTy::num_blocks == 1 , block_type,
170
181
detail::ConvertToOpenCLType_t<vec<block_type, BlockInfoTy::num_blocks>>>;
171
182
};
172
183
173
- // Returns either a pointer suitable to use in a block read/write builtin or
174
- // nullptr if some legality conditions aren't satisfied.
184
+ // Returns either a pointer decorated with the deduced address space, suitable
185
+ // to use in a block read/write builtin, or nullptr if some legality conditions
186
+ // aren't satisfied. If deduced address space is generic then returned pointer
187
+ // will have generic address space and has to be dynamically casted to global or
188
+ // local space before using in a builtin.
175
189
template <int RequiredAlign, std::size_t ElementsPerWorkItem,
176
190
typename IteratorT, typename Properties>
177
191
auto get_block_op_ptr (IteratorT iter, [[maybe_unused]] Properties props) {
@@ -211,16 +225,17 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
211
225
bool is_aligned = alignof (value_type) >= RequiredAlign ||
212
226
reinterpret_cast <uintptr_t >(iter) % RequiredAlign == 0 ;
213
227
214
- constexpr auto AS = detail::deduce_AS<iter_no_cv>::value;
215
228
using block_pointer_type =
216
229
typename BlockTypeInfo<BlkInfo>::block_pointer_type;
217
- if constexpr (AS == access::address_space::global_space) {
230
+
231
+ static constexpr auto deduced_address_space =
232
+ BlockTypeInfo<BlkInfo>::deduced_address_space;
233
+ if constexpr (deduced_address_space ==
234
+ access::address_space::generic_space ||
235
+ deduced_address_space ==
236
+ access::address_space::global_space ||
237
+ deduced_address_space == access::address_space::local_space) {
218
238
return is_aligned ? reinterpret_cast <block_pointer_type>(iter) : nullptr ;
219
- } else if constexpr (AS == access::address_space::generic_space) {
220
- return is_aligned ? reinterpret_cast <block_pointer_type>(
221
- detail::dynamic_address_cast<
222
- access::address_space::global_space>(iter))
223
- : nullptr ;
224
239
} else {
225
240
return nullptr ;
226
241
}
@@ -261,11 +276,37 @@ group_load(Group g, InputIteratorT in_ptr,
261
276
// Do optimized load.
262
277
using value_type = remove_decoration_t <
263
278
typename std::iterator_traits<InputIteratorT>::value_type>;
264
-
265
- auto load = __spirv_SubgroupBlockReadINTEL<
266
- typename detail::BlockTypeInfo<detail::BlockInfo<
267
- InputIteratorT, ElementsPerWorkItem, blocked>>::block_op_type>(
268
- ptr);
279
+ using block_info = typename detail::BlockTypeInfo<
280
+ detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
281
+ static constexpr auto deduced_address_space =
282
+ block_info::deduced_address_space;
283
+ using block_op_type = typename block_info::block_op_type;
284
+
285
+ if constexpr (deduced_address_space ==
286
+ access::address_space::local_space &&
287
+ !props.template has_property <
288
+ detail::native_local_block_io_key>())
289
+ return group_load (g, in_ptr, out, use_naive{});
290
+
291
+ block_op_type load;
292
+ if constexpr (deduced_address_space ==
293
+ access::address_space::generic_space) {
294
+ if (auto local_ptr = detail::dynamic_address_cast<
295
+ access::address_space::local_space>(ptr)) {
296
+ if constexpr (props.template has_property <
297
+ detail::native_local_block_io_key>())
298
+ load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
299
+ else
300
+ return group_load (g, in_ptr, out, use_naive{});
301
+ } else if (auto global_ptr = detail::dynamic_address_cast<
302
+ access::address_space::global_space>(ptr)) {
303
+ load = __spirv_SubgroupBlockReadINTEL<block_op_type>(global_ptr);
304
+ } else {
305
+ return group_load (g, in_ptr, out, use_naive{});
306
+ }
307
+ } else {
308
+ load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
309
+ }
269
310
270
311
// TODO: accessor_iterator's value_type is weird, so we need
271
312
// `std::remove_const_t` below:
@@ -331,6 +372,16 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
331
372
return group_store (g, in, out_ptr, use_naive{});
332
373
333
374
if constexpr (!std::is_same_v<std::nullptr_t , decltype (ptr)>) {
375
+ using block_info = typename detail::BlockTypeInfo<
376
+ detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
377
+ static constexpr auto deduced_address_space =
378
+ block_info::deduced_address_space;
379
+ if constexpr (deduced_address_space ==
380
+ access::address_space::local_space &&
381
+ !props.template has_property <
382
+ detail::native_local_block_io_key>())
383
+ return group_store (g, in, out_ptr, use_naive{});
384
+
334
385
// Do optimized store.
335
386
std::remove_const_t <remove_decoration_t <
336
387
typename std::iterator_traits<OutputIteratorT>::value_type>>
@@ -341,11 +392,28 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
341
392
values[i] = in[i];
342
393
}
343
394
344
- __spirv_SubgroupBlockWriteINTEL (
345
- ptr,
346
- sycl::bit_cast<typename detail::BlockTypeInfo<detail::BlockInfo<
347
- OutputIteratorT, ElementsPerWorkItem, blocked>>::block_op_type>(
348
- values));
395
+ using block_op_type = typename block_info::block_op_type;
396
+ if constexpr (deduced_address_space ==
397
+ access::address_space::generic_space) {
398
+ if (auto local_ptr = detail::dynamic_address_cast<
399
+ access::address_space::local_space>(ptr)) {
400
+ if constexpr (props.template has_property <
401
+ detail::native_local_block_io_key>())
402
+ __spirv_SubgroupBlockWriteINTEL (
403
+ local_ptr, sycl::bit_cast<block_op_type>(values));
404
+ else
405
+ return group_store (g, in, out_ptr, use_naive{});
406
+ } else if (auto global_ptr = detail::dynamic_address_cast<
407
+ access::address_space::global_space>(ptr)) {
408
+ __spirv_SubgroupBlockWriteINTEL (
409
+ global_ptr, sycl::bit_cast<block_op_type>(values));
410
+ } else {
411
+ return group_store (g, in, out_ptr, use_naive{});
412
+ }
413
+ } else {
414
+ __spirv_SubgroupBlockWriteINTEL (ptr,
415
+ sycl::bit_cast<block_op_type>(values));
416
+ }
349
417
}
350
418
}
351
419
}
0 commit comments