|
| 1 | +/*************************************************************************** |
| 2 | + * |
| 3 | + * Copyright (C) Codeplay Software Ltd. |
| 4 | + * |
| 5 | + * Part of the LLVM Project, under the Apache License v2.0 with LLVM |
| 6 | + * Exceptions. See https://llvm.org/LICENSE.txt for license information. |
| 7 | + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 8 | + * |
| 9 | + * Unless required by applicable law or agreed to in writing, software |
| 10 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | + * See the License for the specific language governing permissions and |
| 13 | + * limitations under the License. |
| 14 | + * |
| 15 | + * SYCLcompat |
| 16 | + * |
| 17 | + * max_active_work_groups_per_cu.cpp |
| 18 | + * |
| 19 | + * Description: |
| 20 | + * Test the syclcompat::max_active_work_groups_per_cu API |
| 21 | + **************************************************************************/ |
| 22 | +// RUN: %{build} -o %t.out |
| 23 | +// RUN: %{run} %t.out |
| 24 | + |
| 25 | +#include "sycl/accessor.hpp" |
| 26 | +#include <sycl/detail/core.hpp> |
| 27 | +#include <syclcompat/util.hpp> |
| 28 | + |
| 29 | +template <class T, size_t Dim> |
| 30 | +using sycl_global_accessor = |
| 31 | + sycl::accessor<T, Dim, sycl::access::mode::read_write, |
| 32 | + sycl::access::target::global_buffer>; |
| 33 | + |
| 34 | +using value_type = int; |
| 35 | + |
| 36 | +template <int RangeDim> struct MyKernel { |
| 37 | + MyKernel(sycl_global_accessor<value_type, RangeDim> acc) : acc_{acc} {} |
| 38 | + void operator()(sycl::nd_item<RangeDim> item) const { |
| 39 | + auto gid = item.get_global_id(); |
| 40 | + acc_[gid] = item.get_global_linear_id(); |
| 41 | + } |
| 42 | + sycl_global_accessor<value_type, RangeDim> acc_; |
| 43 | + static constexpr bool has_local_mem = false; |
| 44 | +}; |
| 45 | + |
| 46 | +template <int RangeDim> struct MyLocalMemKernel { |
| 47 | + MyLocalMemKernel(sycl_global_accessor<value_type, RangeDim> acc, |
| 48 | + sycl::local_accessor<value_type, RangeDim> lacc) |
| 49 | + : acc_{acc}, lacc_{lacc} {} |
| 50 | + void operator()(sycl::nd_item<RangeDim> item) const { |
| 51 | + auto gid = item.get_global_id(); |
| 52 | + acc_[gid] = item.get_global_linear_id(); |
| 53 | + auto lid = item.get_local_id(); |
| 54 | + lacc_[lid] = item.get_global_linear_id(); |
| 55 | + } |
| 56 | + sycl_global_accessor<value_type, RangeDim> acc_; |
| 57 | + sycl::local_accessor<value_type, RangeDim> lacc_; |
| 58 | + static constexpr bool has_local_mem = true; |
| 59 | +}; |
| 60 | + |
| 61 | +template <template <int> class KernelName, int RangeDim> |
| 62 | +void test_max_active_work_groups_per_cu(sycl::queue q, |
| 63 | + sycl::range<RangeDim> wg_range, |
| 64 | + size_t local_mem_size = 0) { |
| 65 | + if constexpr (!KernelName<RangeDim>::has_local_mem) |
| 66 | + assert(local_mem_size == 0 && "Bad test setup"); |
| 67 | + |
| 68 | + size_t max_per_cu = syclcompat::max_active_work_groups_per_cu<KernelName<RangeDim>>( |
| 69 | + wg_range, local_mem_size, q); |
| 70 | + |
| 71 | + // Check we get the same result passing equivalent dim3 |
| 72 | + syclcompat::dim3 wg_dim3{wg_range}; |
| 73 | + size_t max_per_cu_dim3 = syclcompat::max_active_work_groups_per_cu<KernelName<RangeDim>>( |
| 74 | + wg_dim3, local_mem_size, q); |
| 75 | + assert(max_per_cu == max_per_cu_dim3); |
| 76 | + |
| 77 | + // Compare w/ reference impl |
| 78 | + size_t max_compute_units = |
| 79 | + q.get_device().get_info<sycl::info::device::max_compute_units>(); |
| 80 | + namespace syclex = sycl::ext::oneapi::experimental; |
| 81 | + auto ctx = q.get_context(); |
| 82 | + auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx); |
| 83 | + auto kernel = bundle.template get_kernel<KernelName<RangeDim>>(); |
| 84 | + size_t max_wgs = kernel.template ext_oneapi_get_info< |
| 85 | + syclex::info::kernel_queue_specific::max_num_work_groups>( |
| 86 | + q, sycl::range<3>{syclcompat::dim3{wg_range}}, local_mem_size); |
| 87 | + assert(max_per_cu == max_wgs / max_compute_units); |
| 88 | + |
| 89 | + // We aren't interested in the launch, it's here to define the kernel |
| 90 | + if (false) { |
| 91 | + sycl::range<RangeDim> global_range = wg_range; |
| 92 | + if(max_per_cu > 0) |
| 93 | + global_range[0] = global_range[0] * max_per_cu * max_compute_units; |
| 94 | + sycl::nd_range<RangeDim> my_range{global_range, wg_range}; |
| 95 | + sycl::buffer<value_type, RangeDim> buf{global_range}; |
| 96 | + |
| 97 | + q.submit([&](sycl::handler &cgh) { |
| 98 | + auto acc = buf.template get_access<sycl::access::mode::read_write>(cgh); |
| 99 | + if constexpr (KernelName<RangeDim>::has_local_mem) { |
| 100 | + sycl::local_accessor<value_type, RangeDim> lacc( |
| 101 | + my_range.get_local_range(), cgh); |
| 102 | + cgh.parallel_for(my_range, KernelName<RangeDim>{acc, lacc}); |
| 103 | + } else { |
| 104 | + cgh.parallel_for(my_range, KernelName<RangeDim>{acc}); |
| 105 | + } |
| 106 | + }); |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +int main() { |
| 111 | + sycl::queue q{}; |
| 112 | + sycl::range<1> range_1d{32}; |
| 113 | + sycl::range<2> range_2d{1, 32}; |
| 114 | + sycl::range<3> range_3d{1, 1, 32}; |
| 115 | + syclcompat::dim3 wg_dim3{32, 1, 1}; |
| 116 | + |
| 117 | + size_t lmem_size_small = sizeof(value_type) * 32; |
| 118 | + size_t lmem_size_medium = lmem_size_small * 32; |
| 119 | + size_t lmem_size_large = lmem_size_medium * 32; |
| 120 | + |
| 121 | + test_max_active_work_groups_per_cu<MyKernel, 3>(q, range_3d); |
| 122 | + test_max_active_work_groups_per_cu<MyKernel, 2>(q, range_2d); |
| 123 | + test_max_active_work_groups_per_cu<MyKernel, 1>(q, range_1d); |
| 124 | + test_max_active_work_groups_per_cu<MyLocalMemKernel, 3>(q, range_3d, |
| 125 | + lmem_size_small); |
| 126 | + test_max_active_work_groups_per_cu<MyLocalMemKernel, 3>(q, range_3d, |
| 127 | + lmem_size_medium); |
| 128 | + test_max_active_work_groups_per_cu<MyLocalMemKernel, 3>(q, range_3d, |
| 129 | + lmem_size_large); |
| 130 | + test_max_active_work_groups_per_cu<MyLocalMemKernel, 1>(q, range_1d, |
| 131 | + lmem_size_large); |
| 132 | + return 0; |
| 133 | +} |
0 commit comments