Skip to content

Commit 6ccb204

Browse files
authored
feat: support Intel GPUs in Array API testing (scikit-learn#31650)
1 parent ba954b7 commit 6ccb204

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

sklearn/utils/_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
8686
):
8787
if array_namespace == "torch":
8888
for device, dtype in itertools.product(
89-
("cpu", "cuda"), ("float64", "float32")
89+
("cpu", "cuda", "xpu"), ("float64", "float32")
9090
):
9191
yield array_namespace, device, dtype
9292
yield array_namespace, "mps", "float32"

sklearn/utils/_testing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,17 @@ def _array_api_for_tests(array_namespace, device):
13431343
"MPS is not available because the current PyTorch install was not "
13441344
"built with MPS enabled."
13451345
)
1346+
elif array_namespace == "torch" and device == "xpu": # pragma: nocover
1347+
if not hasattr(xp, "xpu"):
1348+
# skip xpu testing for PyTorch <2.4
1349+
raise SkipTest(
1350+
"XPU is not available because the current PyTorch install was not "
1351+
"built with XPU support."
1352+
)
1353+
if not xp.xpu.is_available():
1354+
raise SkipTest(
1355+
"Skipping XPU device test because no XPU device is available"
1356+
)
13461357
elif array_namespace == "cupy": # pragma: nocover
13471358
import cupy
13481359

0 commit comments

Comments
 (0)