Skip to content

Commit 5707661

Browse files
authored
Merge pull request #1187 from IntelPython/print_corners
print_corners
2 parents e01e270 + e153d77 commit 5707661

File tree

1 file changed

+41
-20
lines changed

1 file changed

+41
-20
lines changed

dpctl/tensor/_print.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
# limitations under the License.
1616

1717
import contextlib
18+
import itertools
1819
import operator
1920

2021
import numpy as np
2122

23+
import dpctl
2224
import dpctl.tensor as dpt
25+
import dpctl.tensor._tensor_impl as ti
2326

2427
__doc__ = "Print functions for :class:`dpctl.tensor.usm_ndarray`."
2528

@@ -220,25 +223,44 @@ def print_options(*args, **kwargs):
220223
dpt.set_print_options(**options)
221224

222225

223-
def _nd_corners(x, edge_items, slices=()):
224-
axes_reduced = len(slices)
225-
if axes_reduced == x.ndim:
226-
return x[slices]
227-
228-
if x.shape[axes_reduced] > 2 * edge_items:
229-
return dpt.concat(
230-
(
231-
_nd_corners(
232-
x, edge_items, slices + (slice(None, edge_items, None),)
233-
),
234-
_nd_corners(
235-
x, edge_items, slices + (slice(-edge_items, None, None),)
236-
),
237-
),
238-
axis=axes_reduced,
226+
def _nd_corners(arr_in, edge_items):
227+
_shape = arr_in.shape
228+
max_shape = 2 * edge_items + 1
229+
if max(_shape) <= max_shape:
230+
return dpt.asnumpy(arr_in)
231+
res_shape = tuple(
232+
max_shape if _shape[i] > max_shape else _shape[i]
233+
for i in range(arr_in.ndim)
234+
)
235+
236+
arr_out = dpt.empty(
237+
res_shape,
238+
dtype=arr_in.dtype,
239+
usm_type=arr_in.usm_type,
240+
sycl_queue=arr_in.sycl_queue,
241+
)
242+
243+
blocks = []
244+
for i in range(len(_shape)):
245+
if _shape[i] > max_shape:
246+
blocks.append(
247+
(
248+
np.s_[:edge_items],
249+
np.s_[-edge_items:],
250+
)
251+
)
252+
else:
253+
blocks.append((np.s_[:],))
254+
255+
hev_list = []
256+
for slc in itertools.product(*blocks):
257+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
258+
src=arr_in[slc], dst=arr_out[slc], sycl_queue=arr_in.sycl_queue
239259
)
240-
else:
241-
return _nd_corners(x, edge_items, slices + (slice(None, None, None),))
260+
hev_list.append(hev)
261+
262+
dpctl.SyclEvent.wait_for(hev_list)
263+
return dpt.asnumpy(arr_out)
242264

243265

244266
def usm_ndarray_str(
@@ -345,8 +367,7 @@ def usm_ndarray_str(
345367
edge_items = options["edgeitems"]
346368

347369
if x.size > threshold:
348-
# need edge_items + 1 elements for np.array2string to abbreviate
349-
data = dpt.asnumpy(_nd_corners(x, edge_items + 1))
370+
data = _nd_corners(x, edge_items)
350371
options["threshold"] = 0
351372
else:
352373
data = dpt.asnumpy(x)

0 commit comments

Comments
 (0)