Skip to content

Commit b4eafff

Browse files
committed
initial commit
1 parent 4e41318 commit b4eafff

File tree

1 file changed

+38
-18
lines changed

1 file changed

+38
-18
lines changed

dpctl/tensor/_print.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
import numpy as np
2121

22+
import dpctl
2223
import dpctl.tensor as dpt
24+
import dpctl.tensor._tensor_impl as ti
2325

2426
__doc__ = "Print functions for :class:`dpctl.tensor.usm_ndarray`."
2527

@@ -220,25 +222,43 @@ def print_options(*args, **kwargs):
220222
dpt.set_print_options(**options)
221223

222224

223-
def _nd_corners(x, edge_items, slices=()):
224-
axes_reduced = len(slices)
225-
if axes_reduced == x.ndim:
226-
return x[slices]
227-
228-
if x.shape[axes_reduced] > 2 * edge_items:
229-
return dpt.concat(
230-
(
231-
_nd_corners(
232-
x, edge_items, slices + (slice(None, edge_items, None),)
233-
),
234-
_nd_corners(
235-
x, edge_items, slices + (slice(-edge_items, None, None),)
236-
),
237-
),
238-
axis=axes_reduced,
225+
def _nd_corners(arr_in, edge_items):
226+
arr_ndim = arr_in.ndim
227+
res_shape = tuple(
228+
2 * edge_items if arr_in.shape[i] > 2 * edge_items else arr_in.shape[i]
229+
for i in range(arr_ndim)
230+
)
231+
232+
arr_out = dpt.empty(
233+
res_shape,
234+
dtype=arr_in.dtype,
235+
usm_type=arr_in.usm_type,
236+
sycl_queue=arr_in.sycl_queue,
237+
)
238+
239+
hev_list = []
240+
for corner in range(arr_ndim**2):
241+
slices = ()
242+
tmp = bin(corner).replace("0b", "").zfill(arr_ndim)
243+
244+
for dim in reversed(range(arr_ndim)):
245+
if arr_in.shape[dim] < 2 * edge_items:
246+
slices = (np.s_[:],) + slices
247+
else:
248+
ind = (-1) ** int(tmp[dim]) * edge_items
249+
if ind < 0:
250+
slices = (np.s_[-edge_items::],) + slices
251+
else:
252+
slices = (np.s_[:edge_items:],) + slices
253+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
254+
src=arr_in[slices],
255+
dst=arr_out[slices],
256+
sycl_queue=arr_in.sycl_queue,
239257
)
240-
else:
241-
return _nd_corners(x, edge_items, slices + (slice(None, None, None),))
258+
hev_list.append(hev)
259+
260+
dpctl.SyclEvent.wait_for(hev_list)
261+
return arr_out
242262

243263

244264
def usm_ndarray_str(

0 commit comments

Comments
 (0)