@@ -52,15 +52,12 @@ namespace py_internal
52
52
namespace _ns = dpctl::tensor::detail;
53
53
54
54
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_1d_fn_ptr_t ;
55
- using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_2d_fn_ptr_t ;
56
55
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_generic_fn_ptr_t ;
57
56
58
57
static copy_and_cast_generic_fn_ptr_t
59
58
copy_and_cast_generic_dispatch_table[_ns::num_types][_ns::num_types];
60
59
static copy_and_cast_1d_fn_ptr_t
61
60
copy_and_cast_1d_dispatch_table[_ns::num_types][_ns::num_types];
62
- static copy_and_cast_2d_fn_ptr_t
63
- copy_and_cast_2d_dispatch_table[_ns::num_types][_ns::num_types];
64
61
65
62
namespace py = pybind11;
66
63
@@ -187,7 +184,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
187
184
simplified_shape, simplified_src_strides, simplified_dst_strides,
188
185
src_offset, dst_offset);
189
186
190
- if (nd < 3 ) {
187
+ if (nd < 2 ) {
191
188
if (nd == 1 ) {
192
189
std::array<py::ssize_t , 1 > shape_arr = {shape[0 ]};
193
190
// strides may be null
@@ -205,23 +202,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
205
202
keep_args_alive (exec_q, {src, dst}, {copy_and_cast_1d_event}),
206
203
copy_and_cast_1d_event);
207
204
}
208
- else if (nd == 2 ) {
209
- std::array<py::ssize_t , 2 > shape_arr = {shape[0 ], shape[1 ]};
210
- std::array<py::ssize_t , 2 > src_strides_arr = {src_strides[0 ],
211
- src_strides[1 ]};
212
- std::array<py::ssize_t , 2 > dst_strides_arr = {dst_strides[0 ],
213
- dst_strides[1 ]};
214
-
215
- auto fn = copy_and_cast_2d_dispatch_table[dst_type_id][src_type_id];
216
-
217
- sycl::event copy_and_cast_2d_event = fn (
218
- exec_q, src_nelems, shape_arr, src_strides_arr, dst_strides_arr,
219
- src_data, src_offset, dst_data, dst_offset, depends);
220
-
221
- return std::make_pair (
222
- keep_args_alive (exec_q, {src, dst}, {copy_and_cast_2d_event}),
223
- copy_and_cast_2d_event);
224
- }
225
205
else if (nd == 0 ) { // case of a scalar
226
206
assert (src_nelems == 1 );
227
207
std::array<py::ssize_t , 1 > shape_arr = {1 };
@@ -290,12 +270,6 @@ void init_copy_and_cast_usm_to_usm_dispatch_tables(void)
290
270
num_types>
291
271
dtb_1d;
292
272
dtb_1d.populate_dispatch_table (copy_and_cast_1d_dispatch_table);
293
-
294
- using dpctl::tensor::kernels::copy_and_cast::CopyAndCast2DFactory;
295
- DispatchTableBuilder<copy_and_cast_2d_fn_ptr_t , CopyAndCast2DFactory,
296
- num_types>
297
- dtb_2d;
298
- dtb_2d.populate_dispatch_table (copy_and_cast_2d_dispatch_table);
299
273
}
300
274
301
275
} // namespace py_internal
0 commit comments