16
16
17
17
"""Tools to support array_api."""
18
18
19
+ import itertools
20
+
19
21
import numpy as np
20
22
21
23
from daal4py .sklearn ._utils import sklearn_check_version
22
- from onedal .utils ._array_api import _get_sycl_namespace
24
+ from onedal .utils ._array_api import _asarray , _get_sycl_namespace
23
25
26
+ # TODO:
27
+ # check the version of skl.
24
28
if sklearn_check_version ("1.2" ):
25
29
from sklearn .utils ._array_api import get_namespace as sklearn_get_namespace
30
+ from sklearn .utils ._array_api import _convert_to_numpy as _sklearn_convert_to_numpy
26
31
32
+ from onedal ._device_offload import dpctl_available , dpnp_available
33
+
34
+ if dpctl_available :
35
+ import dpctl .tensor as dpt
36
+
37
+ if dpnp_available :
38
+ import dpnp
39
+
40
+ _NUMPY_NAMESPACE_NAMES = {"numpy" , "array_api_compat.numpy" }
27
41
28
- def get_namespace (* arrays ):
29
- """Get namespace of arrays.
30
42
31
- Introspect `arrays` arguments and return their common Array API
32
- compatible namespace object, if any. NumPy 1.22 and later can
33
- construct such containers using the `numpy.array_api` namespace
34
- for instance.
43
+ def yield_namespaces (include_numpy_namespaces = True ):
44
+ """Yield supported namespace.
35
45
36
- This function will return the namespace of SYCL-related arrays
37
- which define the __sycl_usm_array_interface__ attribute
38
- regardless of array_api support, the configuration of
39
- array_api_dispatch, or scikit-learn version.
46
+ This is meant to be used for testing purposes only.
47
+
48
+ Parameters
49
+ ----------
50
+ include_numpy_namespaces : bool, default=True
51
+ If True, also yield numpy namespaces.
52
+
53
+ Returns
54
+ -------
55
+ array_namespace : str
56
+ The name of the Array API namespace.
57
+ """
58
+ for array_namespace in [
59
+ # The following is used to test the array_api_compat wrapper when
60
+ # array_api_dispatch is enabled: in particular, the arrays used in the
61
+ # tests are regular numpy arrays without any "device" attribute.
62
+ "numpy" ,
63
+ # Stricter NumPy-based Array API implementation. The
64
+ # array_api_strict.Array instances always have a dummy "device" attribute.
65
+ "array_api_strict" ,
66
+ "dpctl.tensor" ,
67
+ "cupy" ,
68
+ "torch" ,
69
+ ]:
70
+ if not include_numpy_namespaces and array_namespace in _NUMPY_NAMESPACE_NAMES :
71
+ continue
72
+ yield array_namespace
73
+
74
+
75
+ def yield_namespace_device_dtype_combinations (include_numpy_namespaces = True ):
76
+ """Yield supported namespace, device, dtype tuples for testing.
77
+
78
+ Use this to test that an estimator works with all combinations.
40
79
41
- See: https://numpy.org/neps/nep-0047-array-api-standard.html
80
+ Parameters
81
+ ----------
82
+ include_numpy_namespaces : bool, default=True
83
+ If True, also yield numpy namespaces.
42
84
43
- If `arrays` are regular numpy arrays, an instance of the
44
- `_NumPyApiWrapper` compatibility wrapper is returned instead.
85
+ Returns
86
+ -------
87
+ array_namespace : str
88
+ The name of the Array API namespace.
45
89
46
- Namespace support is not enabled by default. To enabled it
47
- call:
90
+ device : str
91
+ The name of the device on which to allocate the arrays. Can be None to
92
+ indicate that the default value should be used.
48
93
49
- sklearn.set_config(array_api_dispatch=True)
94
+ dtype_name : str
95
+ The name of the data type to use for arrays. Can be None to indicate
96
+ that the default value should be used.
97
+ """
98
+ for array_namespace in yield_namespaces (
99
+ include_numpy_namespaces = include_numpy_namespaces
100
+ ):
101
+ if array_namespace == "torch" :
102
+ for device , dtype in itertools .product (
103
+ ("cpu" , "cuda" ), ("float64" , "float32" )
104
+ ):
105
+ yield array_namespace , device , dtype
106
+ yield array_namespace , "mps" , "float32"
107
+ elif array_namespace == "dpctl.tensor" :
108
+ for device , dtype in itertools .product (
109
+ ("cpu" , "gpu" ), ("float64" , "float32" )
110
+ ):
111
+ yield array_namespace , device , dtype
112
+ else :
113
+ yield array_namespace , None , None
114
+
115
+
116
+ def _convert_to_numpy (array , xp ):
117
+ """Convert X into a NumPy ndarray on the CPU."""
118
+ xp_name = xp .__name__
119
+
120
+ # if dpctl_available and isinstance(array, dpctl.tensor):
121
+ if dpctl_available and xp_name in {
122
+ "dpctl.tensor" ,
123
+ }:
124
+ return dpt .to_numpy (array )
125
+ elif dpnp_available and isinstance (array , dpnp .ndarray ):
126
+ return dpnp .asnumpy (array )
127
+ elif sklearn_check_version ("1.2" ):
128
+ return _sklearn_convert_to_numpy (array , xp )
129
+ else :
130
+ return _asarray (array , xp )
50
131
51
- or:
52
132
53
- with sklearn.config_context(array_api_dispatch=True ):
54
- # your code here
133
+ def get_namespace ( * arrays , remove_none = True , remove_types = ( str ,), xp = None ):
134
+ """Get namespace of arrays.
55
135
56
- Otherwise an instance of the `_NumPyApiWrapper`
57
- compatibility wrapper is always returned irrespective of
58
- the fact that arrays implement the `__array_namespace__`
59
- protocol or not.
136
+ TBD
60
137
61
138
Parameters
62
139
----------
@@ -72,11 +149,13 @@ def get_namespace(*arrays):
72
149
True of the arrays are containers that implement the Array API spec.
73
150
"""
74
151
75
- sycl_type , xp , is_array_api_compliant = _get_sycl_namespace (* arrays )
152
+ sycl_type , xp_sycl_namespace , is_array_api_compliant = _get_sycl_namespace (* arrays )
76
153
77
154
if sycl_type :
78
- return xp , is_array_api_compliant
155
+ return xp_sycl_namespace , is_array_api_compliant
79
156
elif sklearn_check_version ("1.2" ):
80
- return sklearn_get_namespace (* arrays )
157
+ return sklearn_get_namespace (
158
+ * arrays , remove_none = remove_none , remove_types = remove_types , xp = xp
159
+ )
81
160
else :
82
161
return np , False
0 commit comments