|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
17 | 17 | import contextlib
|
| 18 | +import itertools |
18 | 19 | import operator
|
19 | 20 |
|
20 | 21 | import numpy as np
|
21 | 22 |
|
| 23 | +import dpctl |
22 | 24 | import dpctl.tensor as dpt
|
| 25 | +import dpctl.tensor._tensor_impl as ti |
23 | 26 |
|
24 | 27 | __doc__ = "Print functions for :class:`dpctl.tensor.usm_ndarray`."
|
25 | 28 |
|
@@ -220,25 +223,44 @@ def print_options(*args, **kwargs):
|
220 | 223 | dpt.set_print_options(**options)
|
221 | 224 |
|
222 | 225 |
|
223 |
| -def _nd_corners(x, edge_items, slices=()): |
224 |
| - axes_reduced = len(slices) |
225 |
| - if axes_reduced == x.ndim: |
226 |
| - return x[slices] |
227 |
| - |
228 |
| - if x.shape[axes_reduced] > 2 * edge_items: |
229 |
| - return dpt.concat( |
230 |
| - ( |
231 |
| - _nd_corners( |
232 |
| - x, edge_items, slices + (slice(None, edge_items, None),) |
233 |
| - ), |
234 |
| - _nd_corners( |
235 |
| - x, edge_items, slices + (slice(-edge_items, None, None),) |
236 |
| - ), |
237 |
| - ), |
238 |
| - axis=axes_reduced, |
| 226 | +def _nd_corners(arr_in, edge_items): |
| 227 | + _shape = arr_in.shape |
| 228 | + max_shape = 2 * edge_items + 1 |
| 229 | + if max(_shape) <= max_shape: |
| 230 | + return dpt.asnumpy(arr_in) |
| 231 | + res_shape = tuple( |
| 232 | + max_shape if _shape[i] > max_shape else _shape[i] |
| 233 | + for i in range(arr_in.ndim) |
| 234 | + ) |
| 235 | + |
| 236 | + arr_out = dpt.empty( |
| 237 | + res_shape, |
| 238 | + dtype=arr_in.dtype, |
| 239 | + usm_type=arr_in.usm_type, |
| 240 | + sycl_queue=arr_in.sycl_queue, |
| 241 | + ) |
| 242 | + |
| 243 | + blocks = [] |
| 244 | + for i in range(len(_shape)): |
| 245 | + if _shape[i] > max_shape: |
| 246 | + blocks.append( |
| 247 | + ( |
| 248 | + np.s_[:edge_items], |
| 249 | + np.s_[-edge_items:], |
| 250 | + ) |
| 251 | + ) |
| 252 | + else: |
| 253 | + blocks.append((np.s_[:],)) |
| 254 | + |
| 255 | + hev_list = [] |
| 256 | + for slc in itertools.product(*blocks): |
| 257 | + hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( |
| 258 | + src=arr_in[slc], dst=arr_out[slc], sycl_queue=arr_in.sycl_queue |
239 | 259 | )
|
240 |
| - else: |
241 |
| - return _nd_corners(x, edge_items, slices + (slice(None, None, None),)) |
| 260 | + hev_list.append(hev) |
| 261 | + |
| 262 | + dpctl.SyclEvent.wait_for(hev_list) |
| 263 | + return dpt.asnumpy(arr_out) |
242 | 264 |
|
243 | 265 |
|
244 | 266 | def usm_ndarray_str(
|
@@ -345,8 +367,7 @@ def usm_ndarray_str(
|
345 | 367 | edge_items = options["edgeitems"]
|
346 | 368 |
|
347 | 369 | if x.size > threshold:
|
348 |
| - # need edge_items + 1 elements for np.array2string to abbreviate |
349 |
| - data = dpt.asnumpy(_nd_corners(x, edge_items + 1)) |
| 370 | + data = _nd_corners(x, edge_items) |
350 | 371 | options["threshold"] = 0
|
351 | 372 | else:
|
352 | 373 | data = dpt.asnumpy(x)
|
|
0 commit comments