|
19 | 19 |
|
20 | 20 | import numpy as np
|
21 | 21 |
|
| 22 | +import dpctl |
22 | 23 | import dpctl.tensor as dpt
|
| 24 | +import dpctl.tensor._tensor_impl as ti |
23 | 25 |
|
24 | 26 | __doc__ = "Print functions for :class:`dpctl.tensor.usm_ndarray`."
|
25 | 27 |
|
@@ -220,25 +222,43 @@ def print_options(*args, **kwargs):
|
220 | 222 | dpt.set_print_options(**options)
|
221 | 223 |
|
222 | 224 |
|
223 |
| -def _nd_corners(x, edge_items, slices=()): |
224 |
| - axes_reduced = len(slices) |
225 |
| - if axes_reduced == x.ndim: |
226 |
| - return x[slices] |
227 |
| - |
228 |
| - if x.shape[axes_reduced] > 2 * edge_items: |
229 |
| - return dpt.concat( |
230 |
| - ( |
231 |
| - _nd_corners( |
232 |
| - x, edge_items, slices + (slice(None, edge_items, None),) |
233 |
| - ), |
234 |
| - _nd_corners( |
235 |
| - x, edge_items, slices + (slice(-edge_items, None, None),) |
236 |
| - ), |
237 |
| - ), |
238 |
| - axis=axes_reduced, |
| 225 | +def _nd_corners(arr_in, edge_items): |
| 226 | + arr_ndim = arr_in.ndim |
| 227 | + res_shape = tuple( |
| 228 | + 2 * edge_items if arr_in.shape[i] > 2 * edge_items else arr_in.shape[i] |
| 229 | + for i in range(arr_ndim) |
| 230 | + ) |
| 231 | + |
| 232 | + arr_out = dpt.empty( |
| 233 | + res_shape, |
| 234 | + dtype=arr_in.dtype, |
| 235 | + usm_type=arr_in.usm_type, |
| 236 | + sycl_queue=arr_in.sycl_queue, |
| 237 | + ) |
| 238 | + |
| 239 | + hev_list = [] |
| 240 | + for corner in range(arr_ndim**2): |
| 241 | + slices = () |
| 242 | + tmp = bin(corner).replace("0b", "").zfill(arr_ndim) |
| 243 | + |
| 244 | + for dim in reversed(range(arr_ndim)): |
| 245 | + if arr_in.shape[dim] < 2 * edge_items: |
| 246 | + slices = (np.s_[:],) + slices |
| 247 | + else: |
| 248 | + ind = (-1) ** int(tmp[dim]) * edge_items |
| 249 | + if ind < 0: |
| 250 | + slices = (np.s_[-edge_items::],) + slices |
| 251 | + else: |
| 252 | + slices = (np.s_[:edge_items:],) + slices |
| 253 | + hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( |
| 254 | + src=arr_in[slices], |
| 255 | + dst=arr_out[slices], |
| 256 | + sycl_queue=arr_in.sycl_queue, |
239 | 257 | )
|
240 |
| - else: |
241 |
| - return _nd_corners(x, edge_items, slices + (slice(None, None, None),)) |
| 258 | + hev_list.append(hev) |
| 259 | + |
| 260 | + dpctl.SyclEvent.wait_for(hev_list) |
| 261 | + return arr_out |
242 | 262 |
|
243 | 263 |
|
244 | 264 | def usm_ndarray_str(
|
|
0 commit comments