@@ -6,46 +6,49 @@ static constexpr bool is_arithmetic_v() {
6
6
return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
7
7
}
8
8
}
9
+
9
10
template <typename TIn, typename TOut>
10
11
static inline std::enable_if_t <utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void >
11
12
convert (const char * src, char * dst) {
12
13
auto src_val = *reinterpret_cast <const TIn*>(src);
13
14
auto dst_val = sycl::vec<TIn, 1 >(src_val).template convert <TOut, sycl::rounding_mode::automatic>()[0 ];
14
- *reinterpret_cast <TOut*>(dst) = dst_val;;
15
+ *reinterpret_cast <TOut*>(dst) = dst_val;
15
16
}
16
17
17
18
template <typename TIn, typename TOut>
18
19
static void k_set_rows (
19
20
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
20
- const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
21
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
22
+ const int64_t ne11, const int64_t ne12,
21
23
const size_t nb01, const size_t nb02, const size_t nb03,
22
24
const size_t nb10, const size_t nb11, const size_t nb12,
23
25
const size_t nb1, const size_t nb2, const size_t nb3,
24
26
const size_t src_type_size, const size_t dst_type_size,
25
- const sycl::nd_item<3 > & item_ct1) {
26
-
27
- const int i03 = item_ct1.get_group (0 );
28
- const int i02 = item_ct1.get_group (1 );
29
- const int i01 = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) + item_ct1.get_local_id (1 ); // Row index
27
+ const int64_t total_elements,
28
+ const sycl::nd_item<1 > & item_ct1) {
30
29
31
- if (i01 >= ne01) {
30
+ const int64_t i = item_ct1.get_global_linear_id ();
31
+ if (i >= total_elements) {
32
32
return ;
33
33
}
34
34
35
- const int i12 = i03 % ne12;
36
- const int i11 = i02 % ne11;
37
- const int i10 = i01;
35
+ const int64_t i03 = i / (ne00 * ne01 * ne02);
36
+ const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
37
+ const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
38
+ const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
39
+
40
+ const int64_t i12 = i03 % ne12;
41
+ const int64_t i11 = i02 % ne11;
42
+ const int64_t i10 = i01;
38
43
39
44
const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3 >({nb10, nb11, nb12}, {i10, i11, i12}));
40
45
41
46
const char * src0_row = src0 + calculate_offset<3 >({nb01, nb02, nb03}, {i01, i02, i03});
42
- char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
47
+ const char * src_elem = src0_row + i00 * src_type_size;
48
+ char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
49
+ char * dst_elem = dst_row_ptr + i00 * dst_type_size;
43
50
44
- for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
45
- const char * src_elem = src0_row + col * src_type_size;
46
- char * dst_elem = dst_row_ptr + col * dst_type_size;
47
- convert<TIn, TOut>(src_elem, dst_elem);
48
- }
51
+ convert<TIn, TOut>(src_elem, dst_elem);
49
52
}
50
53
51
54
template <typename TIn, typename TOut>
@@ -58,32 +61,29 @@ static void set_rows_sycl(
58
61
const size_t src_type_size, const size_t dst_type_size,
59
62
queue_ptr stream) {
60
63
61
- constexpr int max_threads_per_row = 64 ; // KEEPING 64 for now
62
- const int threads_per_row = std::min ((int )ne00, max_threads_per_row);
63
-
64
- constexpr int max_threads_per_block = 64 ;
65
- const int rows_per_block = std::max (1 , max_threads_per_block / threads_per_row);
66
-
67
- const sycl::range<3 > block_size (1 , rows_per_block, threads_per_row);
68
- const sycl::range<3 > grid_size (ne03, ne02, (ne01 + rows_per_block - 1 ) / rows_per_block);
69
-
70
- sycl_parallel_for (
71
- stream,
72
- sycl::nd_range<3 >(grid_size * block_size, block_size),
73
- [=](sycl::nd_item<3 > item_ct1) {
74
- k_set_rows<TIn, TOut>(
75
- src0_d, src1_d, dst_d,
76
- ne00, ne01, ne11, ne12,
77
- nb01, nb02, nb03,
78
- nb10, nb11, nb12,
79
- nb1, nb2, nb3,
80
- src_type_size, dst_type_size,
81
- item_ct1
82
- );
83
- }
84
- );
85
- }
64
+ const int64_t total_elements = ne00 * ne01 * ne02 * ne03;
86
65
66
+ constexpr int block_size = 64 ;
67
+ const int64_t grid_size = ceil_div (total_elements, block_size);
68
+
69
+ sycl_parallel_for (
70
+ stream,
71
+ sycl::nd_range<1 >(grid_size * block_size, block_size),
72
+ [=](sycl::nd_item<1 > item_ct1) {
73
+ k_set_rows<TIn, TOut>(
74
+ src0_d, src1_d, dst_d,
75
+ ne00, ne01, ne02,
76
+ ne11, ne12,
77
+ nb01, nb02, nb03,
78
+ nb10, nb11, nb12,
79
+ nb1, nb2, nb3,
80
+ src_type_size, dst_type_size,
81
+ total_elements,
82
+ item_ct1
83
+ );
84
+ }
85
+ );
86
+ }
87
87
88
88
void ggml_sycl_op_set_rows (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
89
89
scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 2 );
@@ -122,7 +122,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
122
122
nb1, nb2, nb3,
123
123
sizeof (float ), sizeof (sycl::half),
124
124
stream
125
- );
125
+ );
126
126
break ;
127
127
default :
128
128
GGML_ABORT (" Unsupported tensor type!" );
0 commit comments