Skip to content

Commit c050404

Browse files
committed
factor common utilities for scalar arguments to a new file
1 parent 5355fb8 commit c050404

File tree

5 files changed

+130
-96
lines changed

5 files changed

+130
-96
lines changed

dpctl/tensor/_clip.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@
2323
_empty_like_pair_orderK,
2424
_empty_like_triple_orderK,
2525
)
26-
from dpctl.tensor._elementwise_common import (
26+
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
27+
from dpctl.tensor._type_utils import _can_cast
28+
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
29+
30+
from ._scalar_utils import (
2731
_get_dtype,
2832
_get_queue_usm_type,
2933
_get_shape,
3034
_validate_dtype,
3135
)
32-
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
33-
from dpctl.tensor._type_utils import _can_cast
34-
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
35-
3636
from ._type_utils import (
3737
_resolve_one_strong_one_weak_types,
3838
_resolve_one_strong_two_weak_types,

dpctl/tensor/_elementwise_common.py

Lines changed: 6 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,27 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import numbers
18-
19-
import numpy as np
20-
2117
import dpctl
22-
import dpctl.memory as dpm
2318
import dpctl.tensor as dpt
2419
import dpctl.tensor._tensor_impl as ti
2520
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
26-
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
2721
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
2822

2923
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
24+
from ._scalar_utils import (
25+
_get_dtype,
26+
_get_queue_usm_type,
27+
_get_shape,
28+
_validate_dtype,
29+
)
3030
from ._type_utils import (
31-
WeakBooleanType,
32-
WeakComplexType,
33-
WeakFloatingType,
34-
WeakIntegralType,
3531
_acceptance_fn_default_binary,
3632
_acceptance_fn_default_unary,
3733
_all_data_types,
3834
_find_buf_dtype,
3935
_find_buf_dtype2,
4036
_find_buf_dtype_in_place_op,
4137
_resolve_weak_types,
42-
_to_device_supported_dtype,
4338
)
4439

4540

@@ -289,78 +284,6 @@ def __call__(self, x, /, *, out=None, order="K"):
289284
return out
290285

291286

292-
def _get_queue_usm_type(o):
293-
"""Return SYCL device where object `o` allocated memory, or None."""
294-
if isinstance(o, dpt.usm_ndarray):
295-
return o.sycl_queue, o.usm_type
296-
elif hasattr(o, "__sycl_usm_array_interface__"):
297-
try:
298-
m = dpm.as_usm_memory(o)
299-
return m.sycl_queue, m.get_usm_type()
300-
except Exception:
301-
return None, None
302-
return None, None
303-
304-
305-
def _get_dtype(o, dev):
306-
if isinstance(o, dpt.usm_ndarray):
307-
return o.dtype
308-
if hasattr(o, "__sycl_usm_array_interface__"):
309-
return dpt.asarray(o).dtype
310-
if _is_buffer(o):
311-
host_dt = np.array(o).dtype
312-
dev_dt = _to_device_supported_dtype(host_dt, dev)
313-
return dev_dt
314-
if hasattr(o, "dtype"):
315-
dev_dt = _to_device_supported_dtype(o.dtype, dev)
316-
return dev_dt
317-
if isinstance(o, bool):
318-
return WeakBooleanType(o)
319-
if isinstance(o, int):
320-
return WeakIntegralType(o)
321-
if isinstance(o, float):
322-
return WeakFloatingType(o)
323-
if isinstance(o, complex):
324-
return WeakComplexType(o)
325-
return np.object_
326-
327-
328-
def _validate_dtype(dt) -> bool:
329-
return isinstance(
330-
dt,
331-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
332-
) or (
333-
isinstance(dt, dpt.dtype)
334-
and dt
335-
in [
336-
dpt.bool,
337-
dpt.int8,
338-
dpt.uint8,
339-
dpt.int16,
340-
dpt.uint16,
341-
dpt.int32,
342-
dpt.uint32,
343-
dpt.int64,
344-
dpt.uint64,
345-
dpt.float16,
346-
dpt.float32,
347-
dpt.float64,
348-
dpt.complex64,
349-
dpt.complex128,
350-
]
351-
)
352-
353-
354-
def _get_shape(o):
355-
if isinstance(o, dpt.usm_ndarray):
356-
return o.shape
357-
if _is_buffer(o):
358-
return memoryview(o).shape
359-
if isinstance(o, numbers.Number):
360-
return tuple()
361-
return getattr(o, "shape", tuple())
362-
363-
364287
class BinaryElementwiseFunc:
365288
"""
366289
Class that implements binary element-wise functions.

dpctl/tensor/_scalar_utils.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numbers
18+
19+
import numpy as np
20+
21+
import dpctl.memory as dpm
22+
import dpctl.tensor as dpt
23+
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
24+
25+
from ._type_utils import (
26+
WeakBooleanType,
27+
WeakComplexType,
28+
WeakFloatingType,
29+
WeakIntegralType,
30+
_to_device_supported_dtype,
31+
)
32+
33+
34+
def _get_queue_usm_type(o):
35+
"""Return SYCL device where object `o` allocated memory, or None."""
36+
if isinstance(o, dpt.usm_ndarray):
37+
return o.sycl_queue, o.usm_type
38+
elif hasattr(o, "__sycl_usm_array_interface__"):
39+
try:
40+
m = dpm.as_usm_memory(o)
41+
return m.sycl_queue, m.get_usm_type()
42+
except Exception:
43+
return None, None
44+
return None, None
45+
46+
47+
def _get_dtype(o, dev):
48+
if isinstance(o, dpt.usm_ndarray):
49+
return o.dtype
50+
if hasattr(o, "__sycl_usm_array_interface__"):
51+
return dpt.asarray(o).dtype
52+
if _is_buffer(o):
53+
host_dt = np.array(o).dtype
54+
dev_dt = _to_device_supported_dtype(host_dt, dev)
55+
return dev_dt
56+
if hasattr(o, "dtype"):
57+
dev_dt = _to_device_supported_dtype(o.dtype, dev)
58+
return dev_dt
59+
if isinstance(o, bool):
60+
return WeakBooleanType(o)
61+
if isinstance(o, int):
62+
return WeakIntegralType(o)
63+
if isinstance(o, float):
64+
return WeakFloatingType(o)
65+
if isinstance(o, complex):
66+
return WeakComplexType(o)
67+
return np.object_
68+
69+
70+
def _validate_dtype(dt) -> bool:
71+
return isinstance(
72+
dt,
73+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
74+
) or (
75+
isinstance(dt, dpt.dtype)
76+
and dt
77+
in [
78+
dpt.bool,
79+
dpt.int8,
80+
dpt.uint8,
81+
dpt.int16,
82+
dpt.uint16,
83+
dpt.int32,
84+
dpt.uint32,
85+
dpt.int64,
86+
dpt.uint64,
87+
dpt.float16,
88+
dpt.float32,
89+
dpt.float64,
90+
dpt.complex64,
91+
dpt.complex128,
92+
]
93+
)
94+
95+
96+
def _get_shape(o):
97+
if isinstance(o, dpt.usm_ndarray):
98+
return o.shape
99+
if _is_buffer(o):
100+
return memoryview(o).shape
101+
if isinstance(o, numbers.Number):
102+
return tuple()
103+
return getattr(o, "shape", tuple())
104+
105+
106+
__all__ = [
107+
"_get_dtype",
108+
"_get_queue_usm_type",
109+
"_get_shape",
110+
"_validate_dtype",
111+
]

dpctl/tensor/_search_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717
import dpctl
1818
import dpctl.tensor as dpt
1919
import dpctl.tensor._tensor_impl as ti
20-
from dpctl.tensor._elementwise_common import (
20+
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
21+
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
22+
23+
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
24+
from ._scalar_utils import (
2125
_get_dtype,
2226
_get_queue_usm_type,
2327
_get_shape,
2428
_validate_dtype,
2529
)
26-
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
27-
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
28-
29-
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
3030
from ._type_utils import (
3131
WeakBooleanType,
3232
WeakComplexType,

dpctl/tensor/_utility_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.tensor._tensor_reductions_impl as tri
2323
import dpctl.utils as du
24-
from dpctl.tensor._elementwise_common import (
24+
25+
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
26+
from ._scalar_utils import (
2527
_get_dtype,
2628
_get_queue_usm_type,
2729
_get_shape,
2830
_validate_dtype,
2931
)
30-
31-
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
3232
from ._type_utils import (
3333
_resolve_one_strong_one_weak_types,
3434
_resolve_one_strong_two_weak_types,

0 commit comments

Comments
 (0)