Skip to content

Commit 173488f

Browse files
Fixed deprecation warning about NumPy scalar construction
If the input Python object is out of range for the type being created, a warning is issued, e.g. `np.int16(512*1024)`. This commit implements the suggested work-around.
1 parent 456f46f commit 173488f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

dpctl/tensor/_ctors.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,12 @@ def _get_arange_length(start, stop, step):
785785
return _round_for_arange(tmp)
786786

787787

788+
def _to_scalar(obj, sc_ty):
789+
"A way to convert object to NumPy scalar type"
790+
zd_arr = np.asarray(obj).astype(sc_ty, casting="unsafe")
791+
return zd_arr[tuple()]
792+
793+
788794
def arange(
789795
start,
790796
/,
@@ -861,9 +867,9 @@ def arange(
861867
buffer_ctor_kwargs={"queue": sycl_queue},
862868
)
863869
sc_ty = dt.type
864-
_first = sc_ty(start)
870+
_first = _to_scalar(start, sc_ty)
865871
if sh > 1:
866-
_second = sc_ty(start + step)
872+
_second = _to_scalar(start + step, sc_ty)
867873
if dt in [dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64]:
868874
int64_ty = dpt.int64.type
869875
_step = int64_ty(_second) - int64_ty(_first)

0 commit comments

Comments
 (0)