Skip to content

Commit ced43bf

Browse files
committed
ENH: Array API dispatching
1 parent 404e8c0 commit ced43bf

File tree

2 files changed

+150
-27
lines changed

2 files changed

+150
-27
lines changed

sklearnex/dispatcher.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def get_patch_map_core(preview=False):
128128
from ._config import get_config as get_config_sklearnex
129129
from ._config import set_config as set_config_sklearnex
130130

131+
# TODO:
132+
# confirm the minimum scikit-learn version required by the check below.
133+
if sklearn_check_version("1.4"):
134+
import sklearn.utils._array_api as _array_api_module
135+
131136
if sklearn_check_version("1.2.1"):
132137
from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex
133138
else:
@@ -165,6 +170,15 @@ def get_patch_map_core(preview=False):
165170
from .svm import NuSVC as NuSVC_sklearnex
166171
from .svm import NuSVR as NuSVR_sklearnex
167172

173+
# TODO:
174+
# confirm the minimum scikit-learn version required by the check below.
175+
if sklearn_check_version("1.4"):
176+
from .utils._array_api import _convert_to_numpy as _convert_to_numpy_sklearnex
177+
from .utils._array_api import get_namespace as get_namespace_sklearnex
178+
from .utils._array_api import (
179+
yield_namespace_device_dtype_combinations as yield_namespace_device_dtype_combinations_sklearnex,
180+
)
181+
168182
# DBSCAN
169183
mapping.pop("dbscan")
170184
mapping["dbscan"] = [[(cluster_module, "DBSCAN", DBSCAN_sklearnex), None]]
@@ -440,6 +454,36 @@ def get_patch_map_core(preview=False):
440454
mapping["_funcwrapper"] = [
441455
[(parallel_module, "_FuncWrapper", _FuncWrapper_sklearnex), None]
442456
]
457+
# TODO:
458+
# confirm the minimum scikit-learn version required by the check below.
459+
if sklearn_check_version("1.4"):
460+
# Necessary for array_api support
461+
mapping["get_namespace"] = [
462+
[
463+
(
464+
_array_api_module,
465+
"get_namespace",
466+
get_namespace_sklearnex,
467+
),
468+
None,
469+
]
470+
]
471+
mapping["_convert_to_numpy"] = [
472+
[
473+
(_array_api_module, "_convert_to_numpy", _convert_to_numpy_sklearnex),
474+
None,
475+
]
476+
]
477+
mapping["yield_namespace_device_dtype_combinations"] = [
478+
[
479+
(
480+
_array_api_module,
481+
"yield_namespace_device_dtype_combinations",
482+
yield_namespace_device_dtype_combinations_sklearnex,
483+
),
484+
None,
485+
]
486+
]
443487
return mapping
444488

445489

sklearnex/utils/_array_api.py

Lines changed: 106 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,47 +16,124 @@
1616

1717
"""Tools to support array_api."""
1818

19+
import itertools
20+
1921
import numpy as np
2022

2123
from daal4py.sklearn._utils import sklearn_check_version
22-
from onedal.utils._array_api import _get_sycl_namespace
24+
from onedal.utils._array_api import _asarray, _get_sycl_namespace
2325

26+
# TODO:
27+
# confirm the minimum scikit-learn version required by the check below.
2428
if sklearn_check_version("1.2"):
2529
from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
30+
from sklearn.utils._array_api import _convert_to_numpy as _sklearn_convert_to_numpy
2631

32+
from onedal._device_offload import dpctl_available, dpnp_available
33+
34+
if dpctl_available:
35+
import dpctl.tensor as dpt
36+
37+
if dpnp_available:
38+
import dpnp
39+
40+
_NUMPY_NAMESPACE_NAMES = {"numpy", "array_api_compat.numpy"}
2741

28-
def get_namespace(*arrays):
29-
"""Get namespace of arrays.
3042

31-
Introspect `arrays` arguments and return their common Array API
32-
compatible namespace object, if any. NumPy 1.22 and later can
33-
construct such containers using the `numpy.array_api` namespace
34-
for instance.
43+
def yield_namespaces(include_numpy_namespaces=True):
    """Generate the names of the Array API namespaces used in tests.

    This is meant to be used for testing purposes only.

    Parameters
    ----------
    include_numpy_namespaces : bool, default=True
        When False, names listed in ``_NUMPY_NAMESPACE_NAMES`` are skipped.

    Yields
    ------
    array_namespace : str
        The name of an Array API namespace.
    """
    candidate_names = (
        # Exercises the array_api_compat wrapper when array_api_dispatch is
        # enabled: the test arrays are regular numpy arrays without any
        # "device" attribute.
        "numpy",
        # Stricter NumPy-based Array API implementation whose Array
        # instances always carry a dummy "device" attribute.
        "array_api_strict",
        "dpctl.tensor",
        "cupy",
        "torch",
    )
    for name in candidate_names:
        # Keep the name unless NumPy namespaces were excluded and it is one.
        if include_numpy_namespaces or name not in _NUMPY_NAMESPACE_NAMES:
            yield name
74+
75+
def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
    """Generate (namespace, device, dtype) name tuples for testing.

    Use this to test that an estimator works with all combinations.

    Parameters
    ----------
    include_numpy_namespaces : bool, default=True
        Forwarded to ``yield_namespaces``; if True, numpy namespaces are
        also produced.

    Yields
    ------
    array_namespace : str
        The name of the Array API namespace.

    device : str or None
        The name of the device on which to allocate the arrays. None
        indicates that the default value should be used.

    dtype_name : str or None
        The name of the data type to use for arrays. None indicates that
        the default value should be used.
    """
    float_dtype_names = ("float64", "float32")
    # Namespaces that are expanded over explicit devices; every other
    # namespace is emitted once with default device and dtype.
    devices_by_namespace = {
        "torch": ("cpu", "cuda"),
        "dpctl.tensor": ("cpu", "gpu"),
    }

    namespace_names = yield_namespaces(
        include_numpy_namespaces=include_numpy_namespaces
    )
    for xp_name in namespace_names:
        device_names = devices_by_namespace.get(xp_name)
        if device_names is None:
            yield xp_name, None, None
            continue

        # Device-major, dtype-minor ordering.
        for device_name in device_names:
            for dtype_name in float_dtype_names:
                yield xp_name, device_name, dtype_name

        if xp_name == "torch":
            # "mps" is only paired with float32.
            yield xp_name, "mps", "float32"
115+
116+
def _convert_to_numpy(array, xp):
    """Convert `array` into a NumPy ndarray, picking the converter by backend.

    Parameters
    ----------
    array : array
        The array to convert.
    xp : module
        Array namespace of `array`. Its ``__name__`` is used to detect
        ``dpctl.tensor``; the namespace itself is forwarded to the fallback
        converters.

    Returns
    -------
    The value produced by the first applicable converter, tried in this
    order: ``dpt.to_numpy``, ``dpnp.asnumpy``, scikit-learn's
    ``_convert_to_numpy``, and finally ``_asarray``.
    """
    namespace_name = xp.__name__

    # dpctl arrays are recognized through the namespace name.
    if dpctl_available and namespace_name == "dpctl.tensor":
        return dpt.to_numpy(array)

    # dpnp arrays are recognized by their type.
    if dpnp_available and isinstance(array, dpnp.ndarray):
        return dpnp.asnumpy(array)

    # Otherwise defer to scikit-learn's converter when it was imported.
    if sklearn_check_version("1.2"):
        return _sklearn_convert_to_numpy(array, xp)

    return _asarray(array, xp)
50131

51-
or:
52132

53-
with sklearn.config_context(array_api_dispatch=True):
54-
# your code here
133+
def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
134+
"""Get namespace of arrays.
55135
56-
Otherwise an instance of the `_NumPyApiWrapper`
57-
compatibility wrapper is always returned irrespective of
58-
the fact that arrays implement the `__array_namespace__`
59-
protocol or not.
136+
TBD
60137
61138
Parameters
62139
----------
@@ -72,11 +149,13 @@ def get_namespace(*arrays):
72149
True if the arrays are containers that implement the Array API spec.
73150
"""
74151

75-
sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
152+
sycl_type, xp_sycl_namespace, is_array_api_compliant = _get_sycl_namespace(*arrays)
76153

77154
if sycl_type:
78-
return xp, is_array_api_compliant
155+
return xp_sycl_namespace, is_array_api_compliant
79156
elif sklearn_check_version("1.2"):
80-
return sklearn_get_namespace(*arrays)
157+
return sklearn_get_namespace(
158+
*arrays, remove_none=remove_none, remove_types=remove_types, xp=xp
159+
)
81160
else:
82161
return np, False

0 commit comments

Comments
 (0)